
    gTY                     D   d dl Z d dlZd dlZd dlmZ d dlmZmZmZ d dl	Z
d dlmZmZmZ d dlmZ dededed	ed
e
j$                  f
dZdededed
e
j$                  fdZdedefdZ	 d-dedededededed
e
j$                  fdZdedeee
j$                  f   fdZdededed	ededededededededefdZdededededededededededefdZd Z	 	 	 d.ded ee   d!ee   d"ee   d
eee
j$                     ee
j$                     ee
j$                     f   f
d#Z	 	 	 d.d$ed ee   d!ee   d"ee   d
eee
j$                     ee
j$                     ee
j$                     f   f
d%Zd& Zd'ed(edededededed ee   d!ee   d"ee   d)edededefd*Z d+ Z!e"d,k(  r e!        yy)/    N)Path)DictOptionalTuple)
ModelProtoTensorProtonumpy_helper)	OnnxModel	input_ids
batch_sizesequence_lengthdictionary_sizereturnc                 (   | j                   j                  j                  t        j                  t        j
                  t        j                  fv sJ t        j                  j                  |||ft        j                        }| j                   j                  j                  t        j                  k(  rt        j                  |      }|S | j                   j                  j                  t        j                  k(  rt        j                  |      }|S )a`  Create input tensor based on the graph input of input_ids

    Args:
        input_ids (TensorProto): graph input of the input_ids input tensor
        batch_size (int): batch size
        sequence_length (int): sequence length
        dictionary_size (int): vocabulary size of dictionary

    Returns:
        np.ndarray: the input tensor created
    )sizedtype)typetensor_type	elem_typer   FLOATINT32INT64nprandomrandintint32float32int64)r   r   r   r   datas        \/var/www/openai/venv/lib/python3.12/site-packages/onnxruntime/transformers/bert_test_data.pyfake_input_ids_datar!      s     >>%%//4    99_J3PXZX`X`aD~~!!++{/@/@@zz$ K 
	#	#	-	-1B1B	Bxx~K    segment_idsc                    | j                   j                  j                  t        j                  t        j
                  t        j                  fv sJ t        j                  ||ft        j                        }| j                   j                  j                  t        j                  k(  rt        j                  |      }|S | j                   j                  j                  t        j                  k(  rt        j                  |      }|S )a,  Create input tensor based on the graph input of segment_ids

    Args:
        segment_ids (TensorProto): graph input of the token_type_ids input tensor
        batch_size (int): batch size
        sequence_length (int): sequence length

    Returns:
        np.ndarray: the input tensor created
    r   )r   r   r   r   r   r   r   r   zerosr   r   r   )r#   r   r   r   s       r    fake_segment_ids_datar'   2   s     ''116    88Z1BD##--1B1BBzz$ K 
			%	%	/	/;3D3D	Dxx~Kr"   max_sequence_lengthaverage_sequence_lengthc                     |dk\  r|| k  sJ d|z  | kD  rt        j                  d|z  | z
  |       S t        j                  dd|z  dz
        S )N      )r   r   )r(   r)   s     r    get_random_lengthr-   M   sd    "a',CGZ,ZZZ 	""%88~~a"99<OOQdee~~a%<!<q!@AAr"   
input_maskrandom_sequence_length	mask_typec                    | j                   j                  j                  t        j                  t        j
                  t        j                  fv sJ |dk(  r_t        j                  |t        j                        }|r!t        |      D ]  }t        ||      ||<    nt        |      D ]  }|||<   	 n|dk(  rt        j                  ||ft        j                        }|r5t        |      D ]%  }t        ||      }t        |      D ]	  }	d|||	f<    ' n@t        j                  ||ft        j                        }
|
|d|
j                  d   d|
j                  d   f<   n|dk(  sJ t        j                  |dz  dz   t        j                        }|r{t        |      D ]  }t        ||      ||<    t        |dz         D ]J  }|dkD  r|||z   dz
     ||dz
     z   nd|||z   <   |dkD  r|||z   dz
     ||dz
     z   nd|d|z  dz   |z   <   L nDt        |      D ]  }|||<   	 t        |dz         D ]  }||z  |||z   <   ||z  |d|z  dz   |z   <     | j                   j                  j                  t        j                  k(  rt        j                  |      }|S | j                   j                  j                  t        j                  k(  rt        j                  |      }|S )a"  Create input tensor based on the graph input of segment_ids.

    Args:
        input_mask (TensorProto): graph input of the attention mask input tensor
        batch_size (int): batch size
        sequence_length (int): sequence length
        average_sequence_length (int): average sequence length excluding paddings
        random_sequence_length (bool): whether use uniform random number for sequence length
        mask_type (int): mask type - 1: mask index (sequence length excluding paddings). Shape is (batch_size).
                                     2: 2D attention mask. Shape is (batch_size, sequence_length).
                                     3: key len, cumulated lengths of query and key. Shape is (3 * batch_size + 2).

    Returns:
        np.ndarray: the input tensor created
    r+   r%   r,   Nr      )r   r   r   r   r   r   r   r   onesr   ranger-   r&   shaper   r   )r.   r   r   r)   r/   r0   r   iactual_seq_lenjtemps              r    fake_input_mask_datar:   W   s   0 ??&&005    A~ww
2884!:&+O=TUQ ' :&1Q '	axx_5RXXF!:&!2?D[!\~.A!"DAJ / '
 77J(?@QD59D4::a=/DJJqM/12A~~xxa!+BHH=!:&+O=TUQ ' :>*QRUVQVtJNQ,>'?$q1u+'M\]Z!^$YZ]^Y^tJNQ4F/G$qSTu+/UdeQ^a'!+, + :&1Q ':>*'(+B'BZ!^$/03J/JQ^a'!+, + "",,0A0AAzz$ K 
	$	$	.	.+2C2C	Cxx~Kr"   	directoryinputsc           	          t         j                  j                  |       s&	 t        j                  |        t	        d|  d       nt	        d|  d       t        |j                               D ]t  \  }\  }}t        j                  ||      }t        t         j                  j                  | d| d      d	      5 }|j                  |j                                d
d
d
       v y
# t
        $ r t	        d|  d       Y w xY w# 1 sw Y   xY w)zOutput input tensors of test data to a directory

    Args:
        directory (str): path of a directory
        inputs (Dict[str, np.ndarray]): map from input name to value
    z#Successfully created the directory  zCreation of the directory z failedzWarning: directory z$ existed. Files will be overwritten.input_.pbwbN)ospathexistsmkdirprintOSError	enumerateitemsr	   
from_arrayopenjoinwriteSerializeToString)r;   r<   indexnamer   tensorfiles          r    output_test_datarS      s     77>>)$	FHHY 7	{!DE#I;.RST(8|d((t4"'',,yF5'*=>EJJv//12 FE  9  	C.ykAB	C FEs   C& ; D&D DD	
test_casesverboserandom_seedc           	         |J t         j                  j                  |       t        j                  |       g }t        |      D ]  }t	        || ||      }|j
                  |i}|rt        || |      ||j
                  <   |rt        || ||	|
|      ||j
                  <   |rt        |      dk(  rt        d|       |j                  |        |S )a  Create given number of input data for testing

    Args:
        batch_size (int): batch size
        sequence_length (int): sequence length
        test_cases (int): number of test cases
        dictionary_size (int): vocabulary size of dictionary for input_ids
        verbose (bool): print more information or not
        random_seed (int): random seed
        input_ids (TensorProto): graph input of input IDs
        segment_ids (TensorProto): graph input of token type IDs
        input_mask (TensorProto): graph input of attention mask
        average_sequence_length (int): average sequence length excluding paddings
        random_sequence_length (bool): whether use uniform random number for sequence length
        mask_type (int): mask type 1 is mask index; 2 is 2D mask; 3 is key len, cumulated lengths of query and key

    Returns:
        List[Dict[str,numpy.ndarray]]: list of test cases, where each test case is a dictionary
                                       with input name as key and a tensor as value
    r   zExample inputs)r   r   seedr4   r!   rP   r'   r:   lenrF   append)r   r   rT   r   rU   rV   r   r#   r.   r)   r/   r0   
all_inputs
_test_caseinput_1r<   s                   r    fake_test_datar^      s    D    IINN;
KKJJ'
%i_o^..'*'<[*Ve'fF;##$&:J9PRhjs'F:??# s:!+"F+&! ( r"   rX   c                 h    d}t        | ||||||||||	|
      }t        |      |k7  rt        d       |S )a  Create given number of input data for testing

    Args:
        batch_size (int): batch size
        sequence_length (int): sequence length
        test_cases (int): number of test cases
        seed (int): random seed
        verbose (bool): print more information or not
        input_ids (TensorProto): graph input of input IDs
        segment_ids (TensorProto): graph input of token type IDs
        input_mask (TensorProto): graph input of attention mask
        average_sequence_length (int): average sequence length excluding paddings
        random_sequence_length (bool): whether use uniform random number for sequence length
        mask_type (int): mask type 1 is mask index; 2 is 2D mask; 3 is key len, cumulated lengths of query and key

    Returns:
        List[Dict[str,numpy.ndarray]]: list of test cases, where each test case is a dictionary
                                       with input name as key and a tensor as value
    i'  z$Failed to create test data for test.)r^   rY   rF   )r   r   rT   rX   rU   r   r#   r.   r)   r/   r0   r   r[   s                r    generate_test_datar`      sU    @ OJ :*$45r"   c                     |t        |j                        k\  ry |j                  |   }| j                  |      }|A| j                  ||      }|-|j                  dk(  r| j                  |j                  d         }|S )NCastr   )rY   inputfind_graph_input
get_parentop_type)
onnx_model
embed_nodeinput_indexrc   graph_inputparent_nodes         r    get_graph_input_from_embed_noderl   %  s    c***++[)E--e4K ++JD"{':':f'D$55k6G6G6JKKr"   rg   input_ids_namesegment_ids_nameinput_mask_namec                 p   | j                         }|| j                  |      }|t        d|       d}|r!| j                  |      }|t        d|       d}|r!| j                  |      }|t        d|       d|rdndz   |rdndz   }t        |      |k7  rt        d| dt        |             |||fS t        |      dk7  rt        dt        |             | j	                  d	      }	t        |	      dk(  rh|	d   }
t        | |
d      }t        | |
d      }t        | |
d
      }|(|D ]#  }|j                  j                         }d|v s"|}% |t        d      |||fS d}d}d}|D ]0  }|j                  j                         }d|v r|}$d|v sd|v r|}/|}2 |r	|r|r|||fS t        d      )a  Find graph inputs for BERT model.
    First, we will deduce inputs from EmbedLayerNormalization node.
    If not found, we will guess the meaning of graph inputs based on naming.

    Args:
        onnx_model (OnnxModel): onnx model object
        input_ids_name (str, optional): Name of graph input for input IDs. Defaults to None.
        segment_ids_name (str, optional): Name of graph input for segment IDs. Defaults to None.
        input_mask_name (str, optional): Name of graph input for attention mask. Defaults to None.

    Raises:
        ValueError: Graph does not have input named of input_ids_name or segment_ids_name or input_mask_name
        ValueError: Expected graph input number does not match with specified input_ids_name, segment_ids_name
                    and input_mask_name

    Returns:
        Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]: input tensors of input_ids,
                                                                                 segment_ids and input_mask
    Nz Graph does not have input named r+   r   zExpect the graph to have z inputs. Got r2   z'Expect the graph to have 3 inputs. Got EmbedLayerNormalization   maskz#Failed to find attention mask inputtokensegmentz?Fail to assign 3 inputs. You might try rename the graph inputs.)'get_graph_inputs_excluding_initializersrd   
ValueErrorrY   get_nodes_by_op_typerl   rP   lower)rg   rm   rn   ro   graph_inputsr   r#   r.   expected_inputsembed_nodesrh   rc   input_name_lowers                r    find_bert_inputsr~   2  s>   4 EEGL!//?	??OPQQ$556FGK" #CDTCU!VWW
#44_EJ! #COCT!UVVKqQ7
1PQR|/88IWZ[gWhVijkk+z11
<AB3|CTBUVWW112KLK
;1 ^
3J
AN	5j*aP4ZQO
%#(::#3#3#5 --!&J & BCC+z11 IKJ ::++-%%J''98H+HKI  [Z+z11
V
WWr"   	onnx_filec                     t               }t        | d      5 }|j                  |j                                ddd       t	        |      }t        ||||      S # 1 sw Y   "xY w)a  Find graph inputs for BERT model.
    First, we will deduce inputs from EmbedLayerNormalization node.
    If not found, we will guess the meaning of graph inputs based on naming.

    Args:
        onnx_file (str): onnx model path
        input_ids_name (str, optional): Name of graph input for input IDs. Defaults to None.
        segment_ids_name (str, optional): Name of graph input for segment IDs. Defaults to None.
        input_mask_name (str, optional): Name of graph input for attention mask. Defaults to None.

    Returns:
        Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]: input tensors of input_ids,
                                                                                 segment_ids and input_mask
    rbN)r   rK   ParseFromStringreadr
   r~   )r   rm   rn   ro   modelrR   rg   s          r    get_bert_inputsr     sW    ( LE	i	$diik* 
 5!JJ8H/ZZ	 
	s    AA!c                  t   t        j                         } | j                  ddt        d       | j                  ddt        d d       | j                  d	dt        d
d       | j                  ddt        dd       | j                  ddt        d d       | j                  ddt        d d       | j                  ddt        d d       | j                  ddt        d
d       | j                  ddt        dd       | j                  dddd       | j                  d       | j                  dddd        | j                  d!       | j                  d"d#d$t        d%&       | j                  d'd(ddd)       | j                  d*       | j                  d+dt        d,d-       | j                         }|S ).Nz--modelTzbert onnx model path.)requiredr   helpz--output_dirFz4output test data path. Default is current directory.)r   r   defaultr   z--batch_sizer+   zbatch size of inputz--sequence_length   z maximum sequence length of inputz--input_ids_namezinput name for input idsz--segment_ids_namezinput name for segment idsz--input_mask_namezinput name for attention maskz	--samplesz$number of test cases to be generatedz--seedr2   zrandom seedz	--verbose
store_truezprint verbose information)r   actionr   )rU   z--only_input_tensorsz-only save input tensors and no output tensors)only_input_tensorsz-az--average_sequence_lengthz)average sequence length excluding padding)r   r   r   z-rz--random_sequence_lengthz3use uniform random instead of fixed sequence length)r/   z--mask_typer,   z^mask type: (1: mask index, 2: raw 2D mask, 3: key lengths, cumulated lengths of query and key))argparseArgumentParseradd_argumentstrintset_defaults
parse_args)parserargss     r    parse_argumentsr     s'   $$&F
	DsAXY
C   S!Rgh
/   '   )   ,   3   5sAMZ
(	   &
<	   51
#8   "B   u5
m   DKr"   r   
output_dirr   c                    t        | |||	      \  }}}t        |||||||||||      }t        |      D ]=  \  }}t        j                  j                  |dt        |      z         }t        ||       ? |
ryddl}d|j                         v rddgndg}|j                  | |      }|j                         D cg c]  }|j                   }}t        |      D ]  \  }}t        j                  j                  |dt        |      z         }|j                  ||      }t        |      D ]  \  }}t        j                  t!        j"                  ||         |      }t%        t        j                  j                  |d| d      d	      5 }|j'                  |j)                                ddd         yc c}w # 1 sw Y   xY w)
aI  Create test data for a model, and save test data to a directory.

    Args:
        model (str): path of ONNX bert model
        output_dir (str): output directory
        batch_size (int): batch size
        sequence_length (int): sequence length
        test_cases (int): number of test cases
        seed (int): random seed
        verbose (bool): whether print more information
        input_ids_name (str): graph input name of input_ids
        segment_ids_name (str): graph input name of segment_ids
        input_mask_name (str): graph input name of input_mask
        only_input_tensors (bool): only save input tensors,
        average_sequence_length (int): average sequence length excluding paddings
        random_sequence_length (bool): whether use uniform random number for sequence length
        mask_type(int): mask type
    test_data_set_Nr   CUDAExecutionProviderCPUExecutionProvider)	providersoutput_r@   rA   )r   r`   rH   rB   rC   rL   r   rS   onnxruntimeget_available_providersInferenceSessionget_outputsrP   runr	   rJ   r   asarrayrK   rM   rN   )r   r   r   r   rT   rX   rU   rm   rn   ro   r   r)   r/   r0   r   r#   r.   r[   r6   r<   r;   r   r   sessionoutputoutput_namesresultoutput_nametensor_resultrR   s                                 r    create_and_save_test_datar     s   D *9P`bq)r&I{J#J z*	6GGLL-=A-FG	F+ +  #k&I&I&KK 
!"89$% 
 **5I*FG.5.A.A.CD.CFFKK.CLDz*	6GGLL-=A-FG	\62'5NA{(33BJJvay4I;WMbggll9s#.>?F$

=::<= GF 6 + E GFs   8F4 F99Gc                     t               } | j                  dk  r| j                  | _        | j                  }|Yt	        | j
                        }t        j                  j                  |j                  d| j                   d| j                         }|t	        |      }|j                  dd       nt        d       t        | j
                  || j                  | j                  | j                  | j                  | j                   | j"                  | j$                  | j&                  | j(                  | j                  | j*                  | j,                         t        d|       y )Nr   batch__seq_T)parentsexist_okz7Directory existed. test data files will be overwritten.z Test data is saved to directory:)r   r)   r   r   r   r   rB   rC   rL   parentr   rE   rF   r   samplesrX   rU   rm   rn   ro   r   r/   r0   )r   r   prC   s       r    mainr   Z  s   D##q('+';';$JWW\\!((fT__4EU4K_K_J`,ab
J

4$
/GH

		$$##" 

,j9r"   __main__)r,   )NNN)#r   rB   r   pathlibr   typingr   r   r   numpyr   onnxr   r   r	   rg   r
   r   ndarrayr!   r'   r-   boolr:   r   rS   r^   r`   rl   r~   r   r   r   r   __name__ r"   r    <module>r      s    	   ( (  6 6  (+>ATWZZ<{  VY ^`^h^h 6B3 B B  FFF F !	F
 !F F ZZFR3 3T#rzz/-B 3.777 7 	7
 7 7 7 7 7 !7 !7 7t111 1 	1
 1 1 1 1 !1 !1 1h
 %)&*%)	YXYXSMYX smYX c]	YX
 8BJJ"**!5x

7KKLYX| %)&*%)	[[SM[ sm[ c]	[
 8BJJ"**!5x

7KKL[8aHI>I>I> I> 	I>
 I> I> I> SMI> smI> c]I> I> !I> !I> I>X$:N zF r"   