
    g                        d dl Z d dlZd dlZd dlZd dlmZ d dlmZ d dlm	Z	m
Z
mZmZ d dlmZ d dlmZmZmZmZmZmZmZmZmZ d dlmZmZ d dlmZ  ej<                  e      Z  G d	 d
      Z!y)    N)deque)Path)DictListOptionalTuple)convert_float_to_float16)	AttributeProto
GraphProto
ModelProto	NodeProtoTensorProtoValueInfoProtohelpernumpy_helper
save_model)load_external_data_for_tensoruses_external_data)SymbolicShapeInferenceHelperc                   4   e Zd Zd Zd Zd Zi dfdZdedZdedZd Z	d	 Z
d
 Zd Zd Zd Zd Zd Zd Zd Zd ZdfdZdfdZdfdZdfdZed        Zd Zed        Zd Zd Zd ZdfdZdfdZ dfd Z!g fd!Z"ddg dfd"Z#d# Z$d$ Z%	 	 	 dgd%Z&dhd&Z'ddg fd'Z(dhd(Z)d) Z*d* Z+did+Z,d, Z-did-Z.dfd.Z/d/ Z0dfd0e1d1e2e3   fd2Z4dfd0e1d1e2e3   fd3Z5ed4e6d5e1fd6       Z7d7 Z8d8 Z9djd9Z:djd:Z;dfd;Z<d< Z=d= Z>dfd>Z?ded?Z@ed@        ZAdA ZBdB ZCdC ZDdhdDZEdkdEZFdF ZGededG       ZHdedHZIe	 	 	 	 dldI       ZJdmdJZKdK ZLdL ZMdedMZNedndNeOdOe1dPePfdQ       ZQe	 	 dodReOdSeOdTe2eR   dUe2eR   dPeSf
dV       ZTdfdWe2eR   fdXZUdYe1fdZZVd[ ZWd\ ZXd]eYd^ePfd_ZZd`eYd^ePfdaZ[dbe1dce1fddZ\y)p	OnnxModelc                 &    | j                  |       y N)
initializeselfmodels     X/var/www/openai/venv/lib/python3.12/site-packages/onnxruntime/transformers/onnx_model.py__init__zOnnxModel.__init__!   s        c                 f    || _         i | _        d | _        d| _        d | _        d | _        d | _        y NT)r   _node_name_suffixshape_infer_helperenable_shape_infer
all_graphs_dtype_dict_shape_dictr   s     r   r   zOnnxModel.initialize$   s:    !&
13@D(,6:
 6:6:r    c                     d| _         y )NF)r%   r   s    r   disable_shape_inferencez!OnnxModel.disable_shape_inference1   s
    "'r    Fc                 .   | j                   rR| j                  |rt        | j                        | _        	 | j                  j	                  |      r| j                  S 	 y y # t
        $ r+ d| _         t        dt        j                         d          Y y w xY w)NFzfailed in shape inferencer   )	r%   r$   r   r   infer	Exceptionprintsysexc_info)r   dynamic_axis_mappingupdates      r   infer_runtime_shapezOnnxModel.infer_runtime_shape4   s    ""&&.&*Ftzz*R'F**001EF222 G 	  F*/'13<<>!3DE	Fs   &A   1BBc                     i }|s| j                         n| j                  j                  j                  }|D ]5  }|j                  D ]$  }|s||vr|g||<   ||   j                  |       & 7 |S r   )nodesr   graphnodeinputappend)r   exclude_subgraphsinput_name_to_nodesnodes_to_searchr8   
input_names         r   r<   zOnnxModel.input_name_to_nodesB   ss     .?$**,TZZEUEUEZEZ#D"jj
!)<<;?&+J7+J7>>tD ) $ #"r    c                     i }|s| j                         n| j                  j                  j                  }|D ]  }|j                  D ]
  }|s|||<     |S r   )r6   r   r7   r8   output)r   r;   output_name_to_noder=   r8   output_names         r   rA   zOnnxModel.output_name_to_nodeN   sU     .?$**,TZZEUEUEZEZ#D#{{7;'4  + $ #"r    c                 F    t        | j                  j                        g}|S r   )listr   	functions)r   all_functionss     r   rE   zOnnxModel.functionsW   s    djj2234r    c                 x    g }| j                         D ]$  }|j                  D ]  }|j                  |        & |S r   )graphsr8   r:   )r   	all_nodesr7   r8   s       r   r6   zOnnxModel.nodes[   s;    	[[]E

  & # # r    c                 .    | j                   j                  S r   )r   r7   r*   s    r   r7   zOnnxModel.graphb   s    zzr    c                    | j                   | j                   S g | _         | j                  j                  g}|r|j                  d      }| j                   j	                  |       |j
                  D ]  }|j                  D ]  }|j                  t        j                  j                  k(  r7t        |j                  t              sJ |j	                  |j                         |j                  t        j                  j                  k(  s|j                  D ]%  }t        |t              sJ |j	                  |       '   |r| j                   S Nr   )r&   r   r7   popr:   r8   	attributetyper
   AttributeTypeGRAPH
isinstancegr   GRAPHSrH   )r   graph_queuer7   r8   attrrS   s         r   rH   zOnnxModel.graphse   s   ??&??"zz''(OOA&EOO""5)

 NNDyyN$@$@$F$FF)$&&*===#**4662yyN$@$@$G$GG!%A#-a#<<#<'..q1 "- + #  r    c                     g }| j                         D ].  }|j                  D ]  }|j                  |j                          0 |S r   )rH   r9   r:   name)r   input_namesr7   r9   s       r   get_graphs_input_namesz OnnxModel.get_graphs_input_namesx   s?    [[]E""5::. % # r    c                     g }| j                         D ].  }|j                  D ]  }|j                  |j                          0 |S r   )rH   r@   r:   rX   )r   output_namesr7   r@   s       r   get_graphs_output_namesz!OnnxModel.get_graphs_output_names   s?    [[]E,,##FKK0 ' # r    c                 R    | j                         D ]  }||j                  v s|c S  y r   )rH   r8   r   r8   r7   s      r   get_graph_by_nodezOnnxModel.get_graph_by_node   s(    [[]Euzz! # r    c                 T    | j                         D ]  }||j                  k(  s|c S  y r   )rH   rX   )r   
graph_namer7   s      r   get_graph_by_namezOnnxModel.get_graph_by_name   s(    [[]EUZZ' # r    c                     t        |j                        D ]   \  }}|j                  D ]  }||v s|c c S  " t        |j                        S r   )	enumerater8   r9   len)r   r7   outputsidxr8   r9   s         r   get_topological_insert_idz#OnnxModel.get_topological_insert_id   sD    "5::.ICG#J $ / 5::r    c                     | j                         D ]-  }||j                  v s|j                  j                  |        y  t        j	                  d|       y )NzFailed to remove node %s)rH   r8   removeloggerwarningr_   s      r   remove_nodezOnnxModel.remove_node   sD    [[]Euzz!

!!$' # 	148r    c                 4    |D ]  }| j                  |        y r   )rn   )r   nodes_to_remover8   s      r   remove_nodeszOnnxModel.remove_nodes   s    #DT" $r    Nc                 B   |#|| j                   j                  j                  k(  r1| j                   j                  j                  j	                  |g       y | j                  |      }| j                  ||j                        }|j                  j                  ||       y r   )	r   r7   rX   r8   extendrc   ri   r@   insert)r   r8   rb   r7   
insert_idxs        r   add_nodezOnnxModel.add_node   sy    tzz/?/?/D/D!DJJ!!(($0**:6E77t{{KJJJj$/r    c                     |0| j                   j                  j                  j                  |       y |D ]#  }||j                     }| j                  ||       % y r   )r   r7   r8   rs   rX   rv   )r   nodes_to_addnode_name_to_graph_namer8   rb   s        r   	add_nodeszOnnxModel.add_nodes   sL    "*JJ!!((6$4TYY?
dJ/ %r    c                 
   |#|| j                   j                  j                  k(  r1| j                   j                  j                  j	                  |g       y | j                  |      }|j                  j	                  |g       y r   )r   r7   rX   initializerrs   rc   )r   tensorrb   r7   s       r   add_initializerzOnnxModel.add_initializer   se    tzz/?/?/D/D!DJJ((//9**:6E$$fX.r    c                 
   |#|| j                   j                  j                  k(  r1| j                   j                  j                  j	                  |g       y | j                  |      }|j                  j	                  |g       y r   )r   r7   rX   r9   rs   rc   )r   r9   rb   r7   s       r   	add_inputzOnnxModel.add_input   sc    tzz/?/?/D/D!DJJ""))5'2**:6EKKw'r    c                     t        |t              rt        |t              sJ t        t        | j                              D ]$  }| j                  |   |k(  s|| j                  |<   & y r   )rR   strrangerf   r9   )r8   old_input_namenew_input_namejs       r   replace_node_inputzOnnxModel.replace_node_input   sO    .#.:nc3RRRs4::'Azz!}. .

1 (r    c                 |    | j                   j                  j                  D ]  }t        j	                  |||        y r   )r   r7   r8   r   r   )r   r   r   r8   s       r   replace_input_of_all_nodesz$OnnxModel.replace_input_of_all_nodes   s.    JJ$$))D((~~N *r    c                     t        |t              rt        |t              sJ t        t        | j                              D ]$  }| j                  |   |k(  s|| j                  |<   & y r   )rR   r   r   rf   r@   )r8   old_output_namenew_output_namer   s       r   replace_node_outputzOnnxModel.replace_node_output   sQ    /3/JPS4TTTs4;;'(A{{1~0!0A )r    c                 |    | j                   j                  j                  D ]  }t        j	                  |||        y r   )r   r7   r8   r   r   )r   r   r   r8   s       r   replace_output_of_all_nodesz%OnnxModel.replace_output_of_all_nodes   s0     JJ$$))D))$Q *r    c                 z    | j                         D ](  }|j                  D ]  }|j                  |k(  s|c c S  * y r   )rH   r|   rX   )r   rX   r7   r}   s       r   get_initializerzOnnxModel.get_initializer   s8    [[]E++;;$&!M , # r    c                 v    g }| j                         D ]#  }|j                  |k(  s|j                  |       % |S r   )r6   op_typer:   )r   r   r6   r8   s       r   get_nodes_by_op_typezOnnxModel.get_nodes_by_op_type   s6    JJLD||w&T" ! r    c                     || j                         }g }|j                  D ]"  }||v s||   D ]  }|j                  |        $ |S r   )r<   r@   r:   )r   r8   r<   childrenr@   s        r   get_childrenzOnnxModel.get_children   sU    &"&":":"<kkF,,/7DOOD) 8 " r    c                     || j                         }g }|j                  D ]  }||v s|j                  ||           |S r   )rA   r9   r:   )r   r8   rA   parentsr9   s        r   get_parentszOnnxModel.get_parents   sL    &"&":":"<ZZE++259:   r    c                     || j                         }t        |j                        |k  ry |j                  |   }||vry ||   S r   )rA   rf   r9   )r   r8   irA   r9   s        r   
get_parentzOnnxModel.get_parent  sN    &"&":":"<tzz?a

1++"5))r    c                     t        |j                        D ]M  \  }}||v s||   }|j                  |k(  r
||vr||fc S t        j	                  d| d|j                          O y)a  
        Find parent node based on constraints on op_type.

        Args:
            node (str): current node name.
            parent_op_type (str): constraint of parent node op_type.
            output_name_to_node (dict): dictionary with output name as key, and node as value.
            exclude (list): list of nodes that are excluded (not allowed to match as parent).

        Returns:
            parent: The matched parent node. None if not found.
            index: The input index of matched parent node. None if not found.
        zTo find first z
, current NN)re   r9   r   rl   debug)r   r8   parent_op_typerA   excluder   r9   parents           r   match_first_parentzOnnxModel.match_first_parent  so     "$**-HAu++,U3>>^3g8M!19$LL>.1AFNNK[!\] . r    c                    |J ||dk\  sJ || j                         }|,| j                  ||||      \  }}||j                  |       |S |t        |j                        k\  r/t
        j                  d| dt        |j                                y| j                  |||      }||j                  |k(  r||vr|S |%t
        j                  d| d|j                          y)a*  
        Find parent node based on constraints on op_type and index.
        When input_index is None, we will find the first parent node based on constraints,
        and return_indice will be appended the corresponding input index.

        Args:
            node (str): current node name.
            parent_op_type (str): constraint of parent node op_type.
            input_index (int or None): only check the parent given input index of current node.
            output_name_to_node (dict): dictionary with output name as key, and node as value.
            exclude (list): list of nodes that are excluded (not allowed to match as parent).
            return_indice (list): a list to append the input index when input_index is None.

        Returns:
            parent: The matched parent node.
        Nr   zinput_index z >= node inputs zExpect z, Got )	rA   r   r:   rf   r9   rl   r   r   r   )	r   r8   r   input_indexrA   r   return_indicer   indexs	            r   match_parentzOnnxModel.match_parent%  s   2 "kQ&666&"&":":"< 33D.J]_fgMFE($$U+M#djj/)LL<}4DS_DUVW{4GH&..N"BvU\G\MLL7>"2&8HIJr    c                     t        |      D ]C  \  }}t        |t        t        f      sJ g }| j	                  ||d   |d   ||      }|s>|||fc S  y)Nr      )NN)re   rR   r   r   match_parent_path)r   r8   pathsrA   r   pathr   matcheds           r   match_parent_pathszOnnxModel.match_parent_pathsW  se     'GAtdT5M222M,,T47DGEXZghG'=00 ( r    c                    g g g }}}t        |      D ]p  \  }}t        |t        t        f      sJ g }	| j	                  ||d   |d   ||	      }
|
s>|j                  |       |j                  |
       |j                  |	       r |||fS )Nr   r   )re   rR   r   r   r   r:   )r   r8   r   rA   match_imatchesreturn_indicesr   r   r   r   s              r   match_parent_paths_allz OnnxModel.match_parent_paths_all`  s    +-r2. 'GAtdT5M222M,,T47DGEXZghGq!w'%%m4 ( //r    c           	         |t        |      t        |      k(  sJ || j                         }|}g }t        |      D ]~  \  }}	| j                  ||	|||   nd|g |      }
|
F|%t        j                  d| d||    d|	 d        yt        j                  d| d|	 d        y|j                  |
       |
} |S )aJ  
        Find a sequence of input edges based on constraints on parent op_type and index.
        When input_index is None, we will find the first parent node based on constraints,
        and return_indice will be appended the corresponding input index.

        Args:
            node (str): current node name.
            parent_op_types (str): constraint of parent node op_type of each input edge.
            parent_input_index (list): constraint of input index of each input edge. None means no constraint.
            output_name_to_node (dict): dictionary with output name as key, and node as value.
            return_indice (list): a list to append the input index
                                  When there is no constraint on input index of an edge.

        Returns:
            parents: a list of matched parent node.
        N)r   r   Failed to match index=z parent_input_index=	 op_type=T
stack_info)rf   rA   re   r   rl   r   r:   )r   r8   parent_op_typesparent_input_indexrA   r   current_nodematched_parentsr   r   matched_parents              r   r   zOnnxModel.match_parent_pathl  s   0 ))*c/.BBBB&"&":":"<#O4JAw!..);)G"1%T#+ / N %%1LL03GHZ[\H]G^^ghogpq#' !   LL#9!IgY!O\`La"">2)L) 5, r    c                    | j                  ||      }t        |      }t        |      dkD  r\|j                         }|j                  |k(  r|S |r*| j                  ||      }|D ]  }|j                  |        t        |      dkD  r\y rL   )r   r   rf   rM   r   
appendleft)	r   r8   
child_typer<   	recursiver   dqr   childs	            r   find_first_child_by_typez"OnnxModel.find_first_child_by_type  s    $$T+>?8_"gk668L##z1##,,\;NO%EMM%( & "gk r    c           
         |t        |      t        |      k(  sJ |}g }t        |      D ]  \  }}	d}
| j                  |      }t        |      D ]L  \  }}|j                  |	k(  s||vs|.||   |k7  r&t        j                  d| d||    d|	 d         y|}
N |
t        j                  d|	 d        y|j                  |
       |
} |S )a  
        Find a sequence of input edges based on constraints on parent op_type and index.
        When input_index is None, we will find the first parent node based on constraints,
        and return_indice will be appended the corresponding input index.

        Args:
            node (str): current node name.
            child_op_types (str): constraint of child node op_type of each input edge.
            child_output_index (list): constraint of input index of each input edge. None means no constraint.
            return_indice (list): a list to append the input index
                                  When there is no constraint on input index of an edge.

        Returns:
            children: a list of matched children node.
        Nr   z child_output_index=r   Tr   zFailed to match child op_type=)rf   re   r   r   rl   r   r:   )r   r8   child_op_typeschild_output_indexr   r   r   matched_childrenr   r   matched_childnode_childrenchild_ir   s                 r   match_child_pathzOnnxModel.match_child_path  s   . ))*c..AAAA#N3JAw M --l;M"+M":==G+W0D)5:LQ:OSZ:Z4QC7KL^_`LaKbbklsktu'+ %   $$)M #; $=gYGTXY##M2(L# 4$  r    c                 6   || j                         }| j                  ||      }t        |      }t        |      dkD  r\|j	                         }|j
                  |k(  r|S |r*| j                  ||      }|D ]  }|j                  |        t        |      dkD  r\y rL   )rA   r   r   rf   rM   r   r   )	r   r8   parent_typerA   r   r   r   r   r   s	            r   find_first_parent_by_typez#OnnxModel.find_first_parent_by_type  s    &"&":":"<""4)<=7^"gk668L##{2##**<9LM%FMM&) & "gk r    c                 ,   | j                  d      D ]X  }|j                  d   |k(  s|j                  D ]4  }|j                  dk(  st	        j
                  |j                        c c S  Z | j                  |      }|t	        j
                  |      S y )NConstantr   value)r   r@   rN   rX   r   to_arraytr   )r   rB   r8   attr|   s        r   get_constant_valuezOnnxModel.get_constant_value  s    --j9D{{1~,>>Cxx7*+44SUU;; * : **;7"((55r    c                 p    t        |j                        D ]  \  }}| j                  |      }|||fc S  y)Nr   )re   r9   r   )r   r8   r   r9   r   s        r   get_constant_inputzOnnxModel.get_constant_input  s=    !$**-HAu++E2E %x .
 r    c                 t    | j                  |      \  }}|"|j                  dk(  rt        ||z
        |k  r|S y)Nr   r   )r   sizeabs)r   r8   expected_valuedeltar   r   s         r   find_constant_inputzOnnxModel.find_constant_input  s@    **405qS9O5PSX5XHr    c           	          | j                  |      }|t        j                  | d| d       yt        |j                        |k7  r+t        j                  | d| d| d|j                          yy)N z is not initializer.Fz shall have z dimensions. Got shape T)r   rl   r   rf   shape)r   rB   
dimensionsdescriptionr   s        r   $is_constant_with_specified_dimensionz.OnnxModel.is_constant_with_specified_dimension  sw    ''4=LLK=+6JKLu{{z)LLK=+l:,Nefkfqfqerstr    c                 .    | j                  |||      dk\  S rL   )r   )r   r8   r   r   s       r   has_constant_inputzOnnxModel.has_constant_input  s    ''neDIIr    c                 \   || j                         }||j                  d      }g }t        |      }t        |      dkD  rl|j	                         }||v r#||vrD|j                  |       |j                  D ]$  }||v s||   }|D ]  }	|j                  |	        & t        |      dkD  rl|S rL   )r<   r@   r   rf   rM   r:   r   )
r   	root_node
stop_nodesr<   r   unique_nodesr   r   r@   r   s
             r   get_children_subgraph_nodesz%OnnxModel.get_children_subgraph_nodes!  s    &"&":":"<&y'7'7':;8_"gk668Lz)</##L1*11F!44#6v#>%-EMM%0 &. 2 "gk r    c                    g }|j                   j                  D ]m  }|j                  d      r|j                  |j                         0|j                  d      r|j                  |j
                         ]|j                  d       o |S )zConvert tensor shape to list	dim_value	dim_param?)r   dimHasFieldr:   r   r   )r   tensor_type
shape_listds       r   tensor_shape_to_listzOnnxModel.tensor_shape_to_list:  so    
""&&Azz+&!!!++.K(!!!++.!!#& ' r    rX   symbolic_shape_helperc                     | j                   i | _         t        j                  | j                  j                  j
                  | j                  j                  j                  | j                  j                  j                        D ]9  }|j                  j                  j                  | j                   |j                  <   ; | j                  j                  j                  D ]>  }|j                  | j                   vs|j                  | j                   |j                  <   @ || j                   v r| j                   |   S |=||j                  v r/|j                  |   }|j                  j                  j                  S y)zXTry get data type given a name (could be initializer, input or output of graph or node).N)r'   	itertoolschainr   r7   
value_infor9   r@   rO   r   	elem_typerX   r|   	data_type	known_vi_)r   rX   r   r   r|   s        r   	get_dtypezOnnxModel.get_dtypeF  s4    #!D'oo

  ++

  &&

  ''

 5?OO4O4O4Y4Y  1  $zz//;;##4+;+;;9D9N9ND$$[%5%56  < 4#####D)) ,9N9X9X1X.88>J??..888r    c                    | j                   i | _         t        j                  | j                  j                  j
                  | j                  j                  j                  | j                  j                  j                        D ]  }|j                  j                  j                  d      s)g }|j                  j                  j                  j                  D ]E  }|j                  r|j                  |j                         +|j                  |j                         G || j                   |j                   <    | j                  j                  j"                  D ]>  }|j                   | j                   vs|j$                  | j                   |j                   <   @ || j                   v r| j                   |   S |=||j&                  v r/|j&                  |   }|j                  j                  j(                  S y)zTTry get shape given a name (could be initializer, input or output of graph or node).Nr   )r(   r   r   r   r7   r   r9   r@   rO   r   r   r   r   r   r:   r   rX   r|   dimsr   r   )r   rX   r   r   r   r   r|   s          r   	get_shapezOnnxModel.get_shape_  s    #!D'oo

  ++

  &&

  ''

 ??..77@E)::@@DD==!LL7!LL7	  E
 9>D$$Z__5  $zz//;;##4+;+;;9D9I9ID$$[%5%56  < 4#####D)) ,9N9X9X1X.88>J??..888r    r8   attribute_namec                 v    | j                   D ]*  }|j                  |k(  st        j                  |      }|c S  y r   )rN   rX   r   get_attribute_value)r8   r  rV   r   s       r   get_node_attributezOnnxModel.get_node_attribute  s6    NNDyyN*2248 # r    c                 Z   | j                         }d}| j                         D ]Z  }|j                  dk(  s| j                  |d|      }|s*|j                  dk(  s:|j                  d   |j                  d<   |dz  }\ |dkD  r't
        j                  d|       | j                          yy)av  Remove Cast node that are followed by another Cast node like  --> Cast --> Cast -->
        Note that this shall be used carefully since it might introduce semantic change.
        For example, float -> int -> float could get different value than the original float value.
        So, it is recommended to used only in post-processing of mixed precision conversion.
        r   Cast)rA   r   zRemoved %d cascaded Cast nodesN)rA   r6   r   r   r9   rl   infoprune_graph)r   rA   removed_countr8   r   s        r   remove_cascaded_cast_nodesz$OnnxModel.remove_cascaded_cast_nodes  s     #668JJLD||v%qFYZfnn6$*LLODJJqM!Q&M ! 1KK8-H r    c                    | j                  d      }| j                  r|t        j                  d       g }| j	                         D ]j  }|j
                  dk(  s| j                  |j                  d   |      }| j                  |j                  d   |      }|sT||k(  sZ|j                  |       l |r2t        | j                               }t        | j                               }|D ]  }t        t        |j                        |z        r{t        t        |j                        |z        sYt        | j                         |j                  d            dk(  r-| j!                  |j                  d   |j                  d          n-| j#                  |j                  d   |j                  d          | j%                  |        t        j'                  dt        |             yy)	zKRemove cast nodes that are not needed: input and output has same data type.T)r3   NzFshape inference failed which might impact useless cast node detection.r
  r   r   z4Removed %d Cast nodes with output type same as input)r4   r%   rl   rm   r6   r   r  r9   r@   r:   setrZ   r]   boolrf   r<   r   r   rn   r  )r   shape_inferrp   r8   input_dtypeoutput_dtypegraph_input_namesgraph_output_namess           r   remove_useless_cast_nodesz#OnnxModel.remove_useless_cast_nodes  s   ..d.;""{':NNcdJJLD||v%"nnTZZ]KH#~~dkk!nkJ;,#>#**40 !  #D$?$?$A B!$T%A%A%C!D'DKK(+==> TZZ3D!DE30024::a=ALL 88ATUW 33DKKNDJJqMR  & ( KKFO$ r    c                 T    t         j                  d       | j                  d|       y )NzbThe function convert_model_float32_to_float16 is deprecated. Use convert_float_to_float16 instead!T)use_symbolic_shape_inferkeep_io_types)rl   rm   r	   )r   cast_input_outputs     r    convert_model_float32_to_float16z*OnnxModel.convert_model_float32_to_float16  s'    p	
 	%%tSd%er    c                    d|vrd|d<   | j                   }|rt        |      }	 |j                  |dd      }|ti }|j                  j                  D ]  }t        |j                  d      st        |j                  j                  d      s;|j                  j                  j                  t        j                  k7  sm|j                  szt               }|j                  |       t        |j                  j                  d      r%|j                  j                  j                  d       |||j                  <    |j                  j                  D ]  }|j                  |v s||j                  =   |j                         D ]'  }|j                  j                  j!                  |       ) d
|i}	|	j)                  dD 
ci c]  }
|
|v r|
||
    c}
       t+        |fi |	}| j-                  |       | j/                          | j1                          y# t"        $ r t$        j'                  d	       Y w xY wc c}
w )a	  Convert a model to half (default) or mixed precision.
           To use mixed precision, user need specify which graph inputs, outputs, operator type
           or list of nodes shall keep in float32.

           Note that the conversion might not proceed without type information for the whole graph.

           By default, we use symbolic shape inference to get type information. The benefit of symbolic shape inference
           is that it could handle fused operators in com.microsoft domain. Those operators cannot be handled in onnx shape
           inference so symbolic shape inference is recommended for optimized model.

           When symbolic shape inference is used (even if it failed), ONNX shape inference will be disabled.

           Note that onnx shape inference will fail for model larger than 2GB. For large model, you have to enable
           symbolic shape inference. If your model is not optimized, you can also use model path to call
           convert_float_to_float16 in float16.py (see https://github.com/microsoft/onnxruntime/pull/15067) to
           avoid the 2GB limit.

        Args:
            use_symbolic_shape_infer (bool, optional): use symbolic shape inference instead of onnx shape inference.
                                                       Defaults to True.
            keep_io_types (Union[bool, List[str]], optional): boolean or a list of float32 input/output names.
                                                              If True, model inputs/outputs should be left as float32.
                                                              Defaults to True.
            op_block_list (List[str], optional): List of operator types to leave as float32.
                                                 Defaults to None, which will use `float16.DEFAULT_OP_BLOCK_LIST`.
            node_block_list (List[str], optional): List of node names to leave as float32. Defaults to None.
            force_fp16_initializers(bool): force converting all float initializers to float16.
                                           Default to false.
            min_positive_val (float, optional): minimal positive value. Defaults to 1e-7.
            max_finite_val (float, optional): maximal finite value. Defaults to 1e4.
            force_fp16_inputs(Dict[str, List[int]]): Force the conversion of the inputs of some operators to float16, even if
                                                     this script's preference it to keep them in float32.
        r  TF)
auto_mergeguess_output_rankNr   r   r   ziFailed to run symbolic shape inference. Please file an issue in https://github.com/microsoft/onnxruntime.disable_shape_infer)r  min_positive_valmax_finite_valop_block_listnode_block_listforce_fp16_initializersforce_fp16_inputs#use_bfloat16_as_blocked_nodes_dtype)r   r   infer_shapesr7   r   hasattrrO   r   r   r   	UNDEFINEDrX   r   CopyFrom
ClearFieldvaluesr:   r.   rl   rm   r3   r	   r   r  r  )r   r  kwargsr   r$   model_with_shapename_vivivi_copy
parameterskey
fp16_models               r   r	   z"OnnxModel.convert_float_to_float16  s
   D &(&*F?#

# ">e!D#5#B#B5UYmr#B#s 
 $/ G.44??#BGG]; '(;(;[ I " 3 3 = =AVAV V "&4&6G#,,R0&w||'?'?I ' 8 8 C CG L/6GBGG, @ $kk4477g- ' 0 5 &nn...55b9 / ,-EF
		C &= VC[ 		
" .eBzB

#'')&&(9  s8   AH . H 1H H BH 'A	H H=H:9H:c                    |r|j                  d      r|n|dz   }n|dz   }d}|| j                  v r| j                  |   dz   }no| j                         D ]\  }|j                  s|j                  j	                  |      s,	 t        |j                  t        |      d       }t        |dz   |      }^ || j                  |<   |t        |      z   S # t        $ r Y w xY w)ar  Create a unique node name that starts with a prefix (default is operator type).
           The name will not be duplicated with any name that generated or existed in current graphs.
        Args:
            op_type (str): operator type
            name_prefix (str, optional): prefix of node name. Defaults to None.

        Returns:
            str: node name
        _r   r   N)
endswithr#   r6   rX   
startswithintrf   max
ValueErrorr   )r   r   name_prefixprefixsuffixr8   r   s          r   create_node_namezOnnxModel.create_node_name#  s     $/$8$8$=[KRUDUFs]FT+++++F3a7F 

99!5!5f!=! #DIIc&km$< =!$UQY!7	 % *0v&F## & ! !s   0C	CCc                 t    | j                   j                  j                  D ]  }|j                  |k(  s|c S  y r   )r   r7   r9   rX   )r   r>   r9   s      r   find_graph_inputzOnnxModel.find_graph_inputF  s2    ZZ%%++EzzZ' , r    c                 t    | j                   j                  j                  D ]  }|j                  |k(  s|c S  y r   )r   r7   r@   rX   )r   rB   r@   s      r   find_graph_outputzOnnxModel.find_graph_outputL  s2    jj&&--F{{k) . r    c                 J   || j                         }g }| j                  ||      }t        |      }t        |      dkD  rc|j	                         }||v r#||vr;|j                  |       |j                  D ]  }||v s|j                  ||           t        |      dkD  rc|S rL   )rA   r   r   rf   rM   r:   r9   r   )	r   r8   r   rA   r   r   r   r   r9   s	            r   get_parent_subgraph_nodesz#OnnxModel.get_parent_subgraph_nodesR  s    &"&":":"<""4)<=7^"gk668Lz)</##L1)//E 33&9%&@A 0 "gk r    c                 $   g }|j                   D ]*  }| j                  |      s||vs|j                  |       , |rR| j                  |g       }|D ];  }|j                   D ]*  }| j                  |      s||vs|j                  |       , = |S )z@
        Find graph inputs that linked to current node.
        )r9   rB  r:   rF  )r   r   r   graph_inputsr9   parent_nodesr8   s          r   get_graph_inputszOnnxModel.get_graph_inputsh  s     !''E$$U+\0I##E* ( 99,KL$!ZZE,,U3\8Q$++E2 ( % r    c                 P    t        |j                        D ]  \  }}|| k(  s|c S  y)Nr   )re   r9   )node_output
child_noder   r9   s       r   r   zOnnxModel.input_indexy  s-    %j&6&67LE5# 8 r    c                 >   | j                         }g }| j                         }|D ]5  }|j                  dk(  s|j                  d   |vs%|j	                  |       7 | j                  |       t        |      dkD  r"t        j                  dt        |              y y )Nr   r   zRemoved unused constant nodes: )	r<   r6   r   r@   r:   rq   rf   rl   r   )r   r<   unused_nodesr6   r8   s        r   remove_unused_constantz OnnxModel.remove_unused_constant  s    "668 

D||z)dkk!nDW.W##D)  	,'|q LL:3|;L:MNO !r    c                     t               }|j                  D ]X  }|j                  t        j                  k(  s!|j
                  j                  }|D ]  }|j                  |j                          Z |S )zD
        Get inputs to all nodes in all subgraphs of a node
        )	r  rN   rO   r
   rQ   rS   r8   r3   r9   )r   r8   subgraph_nodes_inputsrV   child_nodesrM  s         r   _get_subgraph_inputs_of_nodez&OnnxModel._get_subgraph_inputs_of_node  s_    
 !$NNDyyN000"ffkk"-J)001A1AB #. #
 %$r    c                     t        t        fd| j                  j                  j                              }t               }|D ]$  }| j                  |      }|j                  |       & ||fS )z
        Get input names to all nodes in all subgraphs where subgraphs are
        graph attributes of a node in the main graph
        c                      | j                   v S r   )r   )r8   ops_with_graph_attrss    r   <lambda>z:OnnxModel._get_subgraph_nodes_and_inputs.<locals>.<lambda>  s    $,,BV2Vr    )rD   filterr   r7   r8   r  rT  r3   )r   rW  subgraph_nodesrR  parent_nodesubgraph_inputs_of_parent_nodes    `    r   _get_subgraph_nodes_and_inputsz(OnnxModel._get_subgraph_nodes_and_inputs  sj    
 f%VX\XbXbXhXhXmXmno #)K-1-N-N{-[*!(()GH * 444r    c                 D   |8| j                   j                  j                  D cg c]  }|j                   c}n|}| j	                  d      }| j                         }d }t        | j                               dkD  r| j                  h d      \  }}	t        |      dk(  rt        j                  d	       y| j                   j                  j                  D ](  }
|
|v r|
j                  D ]  }||	v s||vs||gz  } * i }t               }|D ]  }||v s|j                  ||           t        |      dkD  rp|j                         }
 ||
      }|rG||vrC|
||<   |
j                  D ]/  }t        |      dkD  s||v s||vs|j!                  ||          1 t        |      dkD  rpg }d}| j                   j                  j                  D ]R  }
 ||
      }|j#                  |      }|r0|j$                  |
j$                  k(  r||
k(  r|j                  |
       N|dz  }T | j                   j                  j'                  d
       | j                   j                  j                  j)                  |       g }|{| j                   j                  j                  D ]"  }|j                  |vs|j                  |       $ |D ]1  }| j                   j                  j                  j+                  |       3 g }|r| j	                         }| j                   j                  j                  D cg c]  }|j                  |vs| }}|D ]1  }| j                   j                  j                  j+                  |       3 |s|s|dkD  r~g }|r|j                  t        |       d       |r|j                  t        |       d       |dkD  r|j                  | d       t        j-                  ddj/                  |             | j1                          yc c}w c c}w )a  
        Prune graph to keep only required outputs. It removes unnecessary nodes that are not linked
        (directly or indirectly) to any required output.

        There is also an option to remove graph inputs that are not used to generate any required output.

        Args:
            outputs (list): a list of graph outputs to retain. If it is None, all graph outputs will be kept.
            allow_remove_graph_inputs (bool): allow remove graph inputs.
        NT)r;   c                     | j                   d   r| j                   d   S t        t        | j                   D cg c]  }|s|	 c}      d       S c c}w rL   )r@   nextiter)r8   os     r   get_first_outputz/OnnxModel.prune_graph.<locals>.get_first_output  sE    {{1~{{1~%:Aa:;TBB:s
   AAr   >   IfLoopScan)rW  r   z)Skip prune_graph since graph has subgraphr8   z inputsz outputsz nodesz
Removed %sz, )r   r7   r@   rX   r<   rA   rf   rH   r]  rl   r   r8   r   r:   rM   r9   r   getr   r,  rs   rk   r  joinupdate_graph)r   rg   allow_remove_graph_inputsr@   keep_outputs"input_name_to_nodes_for_main_graphrA   rc  rZ  rR  r8   output_to_noder   first_outputrX   nodes_to_keepnum_nodes_removed	kept_nodeoutput_to_removeinput_to_remover<   r9   removeds                          r   r  zOnnxModel.prune_graph  s    OVo$**2B2B2I2IJ2I2IJcj-1-E-EX\-E-]*"668	C
 t{{}!484W4W%; 5X 51N1 >"a'HI 

((-- >) #kkF!666Ik;k$0 * .  W"F,,		-f56 # "gk668D+D1L^!C/3|, JJD4y1}$2E*EDXfLf&9$&?@ ' "gk JJ$$))D+D1L&**<8I Y..$,,>9PTCT$$T*!Q&! * 	

##F+

$$]3 ****11;;g-$++F3 2 +

  ''..v6 + $"&":":"<26**2B2B2H2Hr2HEJJ^qLqu2HOr'

  &&--d3 ( .2Ca2GG#o"6!7w?@#&6"7!8AB 1$"3!4F;<KKdii&89{ K^ ss   P0PPc                 t   | j                   j                  }t               }|j                  D ]]  }|j                  dv r"| j                  |      }|j                  |       |j                  dk7  sC|j                  |j                         _ |rt        j                  d|        g }|rS|j                  D ]"  }|j                  |vs|j                  |       $ |D ]  }|j                  j                  |        |D cg c]  }|j                   }	}t        j                  dt        |       d|	        g }
g }|j                  D ]X  }|j                  |vr-| j                  |j                        s|
j                  |       >|j                  |j                         Z |
D ]  }|j                  j                  |        |
D cg c]  }|j                   }	}t        j                  dt        |
       d|	        |rt        j                  d|        | j!                          y c c}w c c}w )N)re  rf  rd  r   zremaining input names: zremove z unused inputs: z unused initializers: zremaining initializers:)r   r7   r  r8   r   rT  r3   r9   rl   r   rX   r:   rk   rf   r|   rD  rP  )r   verboserj  r7   remaining_input_namesr8   subgraph_inputs_of_nodeinputs_to_remover9   names_to_removeweights_to_removeweights_to_keepr|   s                r   ri  zOnnxModel.update_graph  s   

   #JJD||55*.*K*KD*Q'%,,-DE||z)%,,TZZ8  LL23H2IJK $::%::$++E2 % *""5) * 4DD3C%5::3CDws#3455EoEVWX  ,,K'<<TE[E[\g\l\lEm!((5&&{'7'78	 -
 -K$$[1 - @QQ?P;++?PQws#4566L_L]^_LL2?2CDE##%' E Rs   H0H5c                     |D ]F  }|j                   D ]5  }||v r||v s||   D ]!  }||vst        j                  d||          y 7 H y)Nz<it is not safe to remove nodes since output %s is used by %sFT)r@   rl   r   )r   rp   rk  r<   rA   node_to_removerr  impacted_nodes           r   is_safe_to_fuse_nodeszOnnxModel.is_safe_to_fuse_nodes@  si    -N$2$9$9 #|3#'::)<=M)N(?"LL ^ 0 -
 $) *O %: . r    c                    t               }t               }g }| j                  D cg c]  }|j                   }}| j                  D cg c]  }|j                   }}||z   }	|r|	j	                          |	D ]  }
|j                  |
        d}|s| j                  nt        | j                  d       }d }t        |      t        |      k7  r(t        |      |k(  rnt        |      }t        |      D ]  \  }}||v rt        d |j                  D              }|dk(  rH|j                  |       |j                  |       |j                  D ]  }|s|j                  |        td}|j                  D ]  }
|
s|
|vsd}|j                  } |sH|j                  |       |j                  |       |j                  D ]  }|s|j                  |         t        |      t        |      k7  r(t        |      t        | j                        k7  r0t        dt        |       d	t        | j                         d
|       | j                  d       | j                  j                  |       y c c}w c c}w )Nr   c                     | j                   S r   )rX   )xs    r   rX  z2OnnxModel.graph_topological_sort.<locals>.<lambda>b  s    _`_e_er    r4  c              3   &   K   | ]	  }|sd   yw)r   N ).0r7  s     r   	<genexpr>z3OnnxModel.graph_topological_sort.<locals>.<genexpr>l  s     !=Z1!Zs   r   FTz)Graph is not a DAG: len(sorted_node_set)=z, len(graph.node)=z, failed at node r8   )r  r|   rX   r9   sortaddr8   sortedrf   re   sumr:   r@   RuntimeErrorr,  rs   )r7   is_deterministicdeps_setsorted_node_setsorted_nodesinitinitializer_namesr9   r  rY   r>   sorted_node_set_lengraph_nodeslast_node_namenode_idxr8   input_countr@   faileds                      r   graph_topological_sortz OnnxModel.graph_topological_sortQ  s^   5%383D3DE3D4TYY3DE5:[[A[EUZZ[A'*;;%JLL$ & !(8ejjfUZZUe>f/"c+&66?#'::"%o"6"+K"8$.!!=TZZ!==!# ''-#''1"&++!$LL0 #. "&**J!j&@!%)- #-  ''-#''1"&++!$LL0 #. / #9	 /"c+&66: 3uzz?2;C<P;QQcdghmhrhrdsct  uF  GU  FV  W  	 

,'c FAs   I4I9c                 X    t         j                  | j                  j                  |       y r   )r   r  r   r7   )r   r  s     r   topological_sortzOnnxModel.topological_sort  s     	(()9)9;KLr    c           	         t        |      j                  j                  dd       | j                  D cg c]  }|j                  dk(  s| }}| j
                  j                  D cg c]  }|j                  dk(  s| }	}|	r*|s(| j                  j                         }d|_        d|_        |rt        |      j                  }
|
j                  dd       |dz   }|rt        |      j                  nd }t        j                  j                  |      r-t        j                  d|        t        j                  |       |rMt        j                  j                  |      rRt        j                  d|        t        j                  |       n$t        j                   |
      rt#        d|
 d	      t%        | |d||||
       y t%        | |       y c c}w c c}w )NT)r   exist_okzcom.microsoftr   z.datazDelete the existing onnx file: z(Delete the existing external data file: zOutput directory (z!) for external data is not empty.)save_as_external_dataall_tensors_to_one_filelocationsize_thresholdconvert_attribute)r   r   mkdiropset_importdomainr7   r8   r  versionrX   osr   existsrl   r  rk   listdirr  r   )r   output_pathr  r  r  r  opsetms_opsetr8   ms_node
output_direxternal_data_pathr  s                r   savezOnnxModel.save  s    	[  &&td&C (-'9'9]'9eU\\_=\E'9] %*KK$4$4W$4D8V4$4W8&&**,EEM*EL k*11JTD9!,w!68Ot./44UYHww~~k*=k]KL		+&&77>>"45KK"JK]J^ _`II01::j)&);J<Gh'ijj&*(?!-"3 uk*O ^ Xs   G
G(G=Gc                     t         j                  d       | j                          t        j	                  | j
                  |||       t         j                  d|        y )Nz Sort graphs in topological orderzModel saved to )rl   r  r  r   r  r   )r   r  use_external_data_formatr  s       r   save_model_to_filezOnnxModel.save_model_to_file  sH    67 	tzz;0HJabok]34r    c                     g }| j                   j                  j                  D ]/  }| j                  |j                        |j                  |       1 |S )z[
        Returns real graph inputs (excluding initializers from older onnx model).
        )r   r7   r9   r   rX   r:   )r   rH  r9   s      r   'get_graph_inputs_excluding_initializersz1OnnxModel.get_graph_inputs_excluding_initializers  sN     ZZ%%++E##EJJ/7##E* , r    c                     | j                   j                  D ]  }|j                  dv s|j                  c S  t	        d      )zGet opset version of onnx domain

        Raises:
            RuntimeError: ONNX model has no opset for default domain.

        Returns:
            int: opset version of onnx domain.
        ) zai.onnxz*ONNX model has no opset for default domain)r   r  r  r  r  )r   r  s     r   get_opset_versionzOnnxModel.get_opset_version  s<     ZZ,,E||.}}$ - GHHr    c                    i }| j                         D ]?  }|r|j                  r|j                  dz   nd|j                  z   }||vrdn||   dz   ||<   A t        j	                  dt        |j                         d               |S )z2
        Returns node count of operators.
        :r  r   z
Operators:c                     | d    | d   fS )Nr   r   r  )kvs    r   rX  z3OnnxModel.get_operator_statistics.<locals>.<lambda>  s    "Q%QSTUQVr    r  )r6   r  r   rl   r  r  items)r   include_domainop_countr8   ops        r   get_operator_statisticsz!OnnxModel.get_operator_statistics  s     JJLD'5$++$++#2QUQ]Q]]B "( 21"9IHRL !
 	j(8>W!X YZ[r    r}   base_dirreturnc                 *   | j                  d      rt        d      | j                  t        j                  k(  rt        d      | j                  }t        j                  |      }| j                  t        j                  k(  r't        | |      }t        t        d |D                    S t        |       rt        | |       | j                  d      rt        | j                        S t        j                   |       }t        |j#                               S )a  Converts a tensor def object to a hash for data comparison purposes.
        Args:
            tensor: a TensorProto object.
            base_dir: if external tensor exists, base_dir can help to find the path to it
        Returns:
            hash: a hash of the data.
        segmentz*Currently not supporting loading segments.z4The element type in the input tensor is not defined.c              3   >   K   | ]  }|j                  d         yw)zutf-8N)decode)r  ss     r   r  z)OnnxModel.to_data_hash.<locals>.<genexpr>  s     FAahhw/s   raw_data)r   r<  r   r   r*  	TypeErrorr   tensor_dtype_to_fieldSTRINGgetattrhashtupler   r   r  r   r   tobytes)r}   r  tensor_dtypestorage_fieldutf8_stringsnp_datas         r   to_data_hashzOnnxModel.to_data_hash  s     ??9%IJJ{444RSS''44\B{111"6=9LFFFGGf%)&(;??:&(("++F3G)**r    tensor1tensor2signature_cache1signature_cache2c                    |r| j                   |v r|| j                      nt        j                  |       }|r|j                   |v r||j                      nt        j                  |      }|||| j                   <   ||||j                   <   ||k(  rk| j                  |j                  k(  rR| j                  |j                  k(  r9t        j                  |       t        j                  |      k(  j                         S y)a  Returns True when two tensors have same value.
           Note that name can be different.

        Args:
            tensor1 (TensorProto): initializer 1
            tensor2 (TensorProto): initializer 2
            signature_cache1 (dict): Optional dictionary to store data signatures of tensor1 in order to speed up comparison.
            signature_cache2 (dict): Optional dictionary to store data signatures of tensor2 in order to speed up comparison.
        Returns:
            bool: True when two initializers has same value.
        F)rX   r   r  r   r  r   r   all)r  r  r  r  sig1sig2s         r   has_same_valuezOnnxModel.has_same_value  s    (  GLL4D$D W\\*''0 	  GLL4D$D W\\*''0 	
 '-1W\\*'-1W\\*4<G--1B1BBw||W^WcWcGc ))'2l6K6KG6TTYY[[r    cachec                 D   t        | j                               dkD  rt        j                  d       t        | j                  j
                  j                        }dg|z  }t        |dz
        D ]  }||   dk\  rt        |dz   |      D ]b  }t        j                  | j                  j
                  j                  |   | j                  j
                  j                  |   ||      s^|||<   d  d}t        |      D ]{  }||   dk\  s|dz  }| j                  | j                  j
                  j                  |   j                  | j                  j
                  j                  ||      j                         } |dkD  r | j                          t        d| d       yy)a;  Remove initializers with duplicated values, and only keep the first one.
        It could help reduce size of models (like ALBert) with shared weights.
        If require_raw_data passed, method will only compare raw_data initializers to speed runtime
        Note: this function does not process subgraph.
        r   z9remove_duplicated_initializer does not process subgraphs.r   r   zRemoved z# initializers with duplicated valueN)rf   rH   rl   rm   r   r7   r|   r   r   r  r   rX   ri  r/   )r   r  initializer_countsamer   r   counts          r   remove_duplicated_initializerz'OnnxModel.remove_duplicated_initializer8  sy    t{{}!NNVW

 0 0 < <=t''(1,-AAw!|1q5"34++JJ$$003JJ$$003	  DG 5 . ()AAw!|
//JJ$$00388JJ$$00a9>> * 19HUG#FGH r    r>  c                    t        | j                               dkD  rt        j                  d       | j                  j
                  j                  D cg c]  }|j                   c}| j                  j
                  j                  D cg c]  }|j                   c}z   dgz   }| j                  j
                  j                  D ]7  }|j                  |vs||j                  z   |vs$||j                  z   |_        9 | j                  j
                  j                  D ]  }t        t        |j                              D ]H  }|j                  |   |vs||j                  |   z   |vs*||j                  |   z   |j                  |<   J t        t        |j                              D ]H  }|j                  |   |vs||j                  |   z   |vs*||j                  |   z   |j                  |<   J  | j                  j
                  j                  D ]%  }|j                  |vs||j                  z   |_        ' yc c}w c c}w )zAdd prefix to initializer or intermediate outputs in graph. Main graph inputs and outputs are excluded.
        It could help avoid conflicting in name of node_args when merging two graphs.
        Note: this function does not process subgraph.
        r   z/add_prefix_to_names does not process subgraphs.r  N)rf   rH   rl   rm   r   r7   r9   rX   r@   r|   r8   r   r   )	r   r>  r   rb  excludedr|   r8   r   r   s	            r   add_prefix_to_nameszOnnxModel.add_prefix_to_names]  s   
 t{{}!NNLM %)JJ$4$4$:$:;$:qAFF$:;tzzO_O_OfOf>gOf!qvvOf>ggkmjnn::++77Kx/K,,,H<'-0@0@'@K$ 8
 JJ$$))D3tzz?+::a=0

1-X=(.A(>

1 , 3t{{+,;;q>1A.h>)/$++a.)@A - * ****55Jh."(:??":
 6) <>gs   IIc                 N    | j                   j                  j                  d       y )Nr   )r   r7   r,  r*   s    r   clean_shape_inferzOnnxModel.clean_shape_infer  s    

##L1r    c                 R   g }|j                  | j                  j                         |r}g }|D ]o  }t        |t              st        j                  |j                  |j                  |j                        D ]  }|j                  j                  j                  t        j                  k(  r  y|j                  j                  d      sS|j                  j                   j                  j                  j                  t        j                  k(  s  y |j"                  D ]"  }|j$                  t        j                  k(  s!  y |j&                  D ]H  }|j(                  dk(  rB|j*                  D ]3  }|j,                  dk(  s|j.                  t        j                  k(  s1   y |j*                  D ]  }|j                  t0        j2                  k(  r|j                  |j4                         |j6                  D ]  }|j                  |        t        |j8                  t              r+|j8                  j$                  t        j                  k(  r   y|j:                  D ]5  }t        |t              s|j$                  t        j                  k(  s2    y  K r |}|r}y)z$Check whether the model uses float16Tsequence_typer
  toF)r:   r   r7   rR   r   r   r   r9   r@   r   rO   r   r   r   FLOAT16r   r  r|   r   r8   r   rN   rX   r   r
   rQ   rS   rH   r   tensors)	r   queue
sub_graphsr7   vr   r8   rV   rS   s	            r   use_float16zOnnxModel.use_float16  s   TZZ%%&J!%4"ellEDTDTUAvv))33{7J7JJ#vv766//99EEOOS^SfSff#' V **A{{k&9&99# + "JJD||v-$(NND#yyD0TVV{?R?R5R'+ %3 !%99(<(<<&--dff5!%A&--a0 "- &dffk:tvv?O?OS^SfSf?f#'!%A)![9akk[M`M`>`'+ ". !/ ' F EK N r    graph_inputnew_typec                    t        |t              sJ | j                  |j                        sJ |j                  j
                  j                  t        |      k(  rdg fS | j                         }d}g }| j                         }|j                  |v r||j                     }|D cg c]  }|j                  dk7  s| }	}|	r| j                  d      }
|
dz   |j                  z   }|j                  j                         }|j                  |       ||_        t        j                   d|j                  g|gt        |j                  j
                  j                        |
      }|j"                  j%                  |g       |	D ]#  }t&        j)                  ||j                  |       % |D cg c]  }|j                  dk(  s| }}|D ]}  }t&        j+                  |d      t        |      k(  r)| j-                  |j.                  d   |j                         | j1                  |j.                  d         rm|j3                  |        |r| j5                  |       t        |      |j                  j
                  _        ||fS c c}w c c}w )aq  Change graph input type, and add Cast node if needed.

        Args:
            graph_input (ValueInfoProto): input of the graph
            new_type (int): new data type like TensorProto.INT32.

        Returns:
            NodeProto: a new Cast node that added. None if Cast node is not added.
            List[NodeProto]: Cast nodes that have been removed.
        Nr
  r7  r  rX   r  r   )rR   r   rB  rX   rO   r   r   r:  r7   r<   r   r@  r   r  r+  r   	make_noder8   rs   r   r   r  r   r@   rD  r:   rq   )r   r  r  r7   new_cast_noderp   r<   r6   r8   nodes_not_cast	node_namerB   new_value_info
nodes_casts                 r   change_graph_input_typez!OnnxModel.change_graph_input_type  sK    +~666$$[%5%5666''11S]B8O

"66822'(8(89E 05Out8NduNO 11&9	'#o0@0@@!&!1!1!5!5!7''4&1# & 0 0 %%& M;++77AAB"! 

!!=/2*D00{7G7GU +
 ,1K54DLLF4J$5JK"//d;s8}L33DKKNKDTDTU--dkk!n=#**40	 #
 !!/214X$$.o--? P* Ls   &J;J-J
J
graph_outputc                    t        |t              sJ | j                  |j                        sJ |j                  j
                  j                  t        |      k(  ryd}| j                         }| j                  d      }|dz   |j                  z   }| j                  |j                  |       |j                  j                         }|j                  |       ||_        t        j                  d|g|j                  gt        |      |      }|j                   j#                  |g       t        |      |j                  j
                  _        |S )a!  Change graph input type, and add Cast node if needed.

        Args:
            graph_input (str | ValueInfoProto): output of the graph
            new_type (int): new data type.

        Returns:
            NodeProto: a new Cast node that added. None if Cast node is not added.
        Nr
  r7  r  )rR   r   rD  rX   rO   r   r   r:  r7   r@  r   r   r  r+  r   r  r8   rs   )r   r  r  	cast_noder7   r  r>   r  s           r   change_graph_output_typez"OnnxModel.change_graph_output_type  s'    ,777%%l&7&7888((22c(mC	

 ))&1	_|'8'88
''(9(9:F))--/-($$L8}
	 	

9+&25h-%%/r    old_namenew_namec                 $   || j                         v rt        d      | j                         }|j                  D ]T  }|j                  |k(  st
        j                  d||       | j                  ||       | j                  ||       ||_        V y )Nz{new_name} exists in graphz!replace output name from %s to %s)	rA   r  r7   r@   rX   rl   r   r   r   )r   r  r  r7   r@   s        r   rename_graph_outputzOnnxModel.rename_graph_output  s|    t//11;<<

llF{{h&@(HU//(C008D& #r    )Fr   )NNNr"   )gư>)T)FF)FTi   F)FT)r  r   )]__name__
__module____qualname__r   r   r+   r4   r<   rA   rE   r6   r7   rH   rZ   r]   r`   rc   ri   rn   rq   rv   rz   r~   r   staticmethodr   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r  r  r   r  r  r  r  r	   r@  rB  rD  rF  rJ  r   rP  rT  r]  r  ri  r  r  r  r  r  r  r  r  r   r:  r  dictr  r  r  r  r  r  r   r  r  r  r  r    r   r   r       s3   ;( 8:% 
## &9#00/( / /O 1 1R	* UW 6  0d
0    6p&  . `$
J2
c (C_:` 2c (C_:` @  C  & Df_)B!$F,"  P%
5iV+&Z" 6( 6(pM  $ $2+ 2+h	5I +[ +C + + +6  ,0+/	### #4.# #4.	#
 
# #J#I8D> #IJ ;#  ;D2+Z=.#=. =.~'$' 'R
'C 
'3 
'r    r   )"r   loggingr  r0   collectionsr   pathlibr   typingr   r   r   r   float16r	   onnxr
   r   r   r   r   r   r   r   r   onnx.external_data_helperr   r   r$   r   	getLoggerr  rl   r   r  r    r   <module>r     sX      	 
   . . ,
 
 
 X ;			8	$A' A'r    