U
    T?hm                    @   s   d dl Z d dlmZmZ d dlmZ d dlmZ d dlm	Z	m
Z
mZmZmZ d dlmZ e eZG dd deZG d	d
 d
eZdS )    N)OptionalUnion)FusionAttention)Fusion)FunctionProto	NodeProtoTensorProtohelpernumpy_helper)	OnnxModelc                       sv   e Zd ZdZeeed fddZdeeeeeeeeeeee	e
 eedf ddd	Zd
d Zdd Zdd Z  ZS )FusionRotaryAttentionze
    Fuse Attention subgraph with rotary positional embeddings into one MultiHeadAttention node.
    )modelhidden_size	num_headsc              
      s$   t  j|||ddddddgd d S )NTZSimplifiedLayerNormalization SkipSimplifiedLayerNormalizationZLayerNormalizationSkipLayerNormalizationAdd)Zuse_multi_head_attentionZsearch_op_types)super__init__)selfr   r   r   	__class__ b/var/www/html/venv/lib/python3.8/site-packages/onnxruntime/transformers/fusion_rotary_attention.pyr      s    zFusionRotaryAttention.__init__ N)inputoutputq_rotaryk_rotaryv_matmul	attn_maskadd_qkpast_kpast_v	present_k	present_vscalereturnc                 C   s  | j dkst| jdkrF| j| j  dkrFtd| j d| j   d S | jd}|jd |jd |jd d||||	g}|g}|
r|r||
|g t	j
d|||d}d|_|jt	d| j g |d k	r|jt	d	|g | jd k	r
|jt	d
t| jg | d |S )Nr   z)fuse_rotary_attention: input hidden size z# is not a multiple of num of heads ZMultiHeadAttentionr   inputsoutputsnamecom.microsoftr   r&   mask_filter_value)r   AssertionErrorr   loggerdebugr   create_node_namer   extendr	   	make_nodedomain	attributemake_attributer-   floatincrease_counter)r   r   r   r   r   r   r    r!   r"   r#   r$   r%   r&   Zmha_node_nameZ
mha_inputsZmha_outputsZmha_noder   r   r   create_mha_node)   sB    
z%FusionRotaryAttention.create_mha_nodec	           1      C   sx  | j |dgdg}	| j |dgdg}
|	d ks8|
d kr<dS |	d |
d  }}| j |dddgdddg}| j |dddgdddg}| j |dddgdddg}| j |dddgdddg}|d ks|d ks|d ks|d krdS |\}}}|\}}}|jd |ks|jd |krdS |d j|jks>|d j|jkrBdS | j |dgdg}| j |dgdg}|d ks~|d krdS |d |d  }}| j |dd	ddgddddg}| j |dd
ddgddddg}| j |dddgdddg}| j |dddgdddg}|d ks4|d ks4|d ks4|d kr8dS |d j|jks|d j|jks|d j|jks|d j|jkrdS | j |dgdg}|d krdS |d }| j |dd	ddgddddg} | j |dd
ddgddddg}!| d ks|!d krdS | d j|jks*|!d j|jkr.dS | j |dgdg}"|"d krPdS |"d }#| j |#dd	ddgddddg}$| j |#dddgdddg}%|$d ks|%d krdS |$d j|jks|%d j|jkrdS |$d }&| d }'|d }(|jd })|&jd |)ks&|'jd |)ks&|(jd |)kr*dS | j |dddgdddg}*| j |ddddgddddg}+|*d k	r||*\}},}-n|+d k	r|+\}}},}-ndS |-jd dkrdS | j |,dd
ddgddddg}.| j |-dd
ddgddddg}/| j |-dgdg}0|.d ks|/d ks|0d kr"dS |.d j|/d jksN|.d j|/d jkrRdS |/d jd |0d jd krtdS dS )NConcat   Fr   	UnsqueezeGatherShape   Mulr   SliceCast>   attention_maskr    T)r   match_parent_pathr   r+   r   )1r   reshape_qkv_2reshape_qkv_1reshape_q_2reshape_k_2reshape_v_2reshape_v_1r!   
root_inputZconcat_qkv_2_pathZconcat_qkv_1_pathZconcat_qkv_2Zconcat_qkv_1Zreshape_qkv_2_path_1Zreshape_qkv_2_path_2Zreshape_qkv_1_path_1Zreshape_qkv_1_path_2_gather_1shape_1gather_2shape_2Zconcat_v_2_pathZconcat_v_1_pathZ
concat_v_2Z
concat_v_1Zreshape_v_2_path_1Zreshape_v_2_path_2Zreshape_v_1_path_1Zreshape_v_1_path_2Zconcat_k_2_pathZ
concat_k_2Zreshape_k_2_path_1Zreshape_k_2_path_2Zconcat_q_2_pathZ
concat_q_2Zreshape_q_2_path_1Zreshape_q_2_path_2Zmul_qZmul_kZmul_vZgather_1_outZattn_mask_path_1Zattn_mask_path_2Z
slice_qk_2Z
slice_qk_1Zslice_qk_2_pathZslice_qk_1_path_1Zslice_qk_1_path_2r   r   r   &check_runtime_shape_paths_for_functiona   s    

 $ 
 
 
 

 
 
 
 
$
 
 
$
0 

 
 
 
 
,z<FusionRotaryAttention.check_runtime_shape_paths_for_functionc                 C   s  | j |dgdg}|d kr dS |d }| j |dddgdddg}| j |dddgdddg}	|d ksp|	d krtdS |\}
}}|	\}
}}|jd |ks|jd |krdS | j |dgdg}|d krdS |d }| j |dddgdddg}| j |dddgdddg}|d ks|d kr dS |d j|jksD|d j|jkrHdS | j |dgdg}|d krjdS |d }| j |dddgdddg}| j |dddgdddg}|d ks|d krdS |d j|jks|d j|jkrdS | j |dgdg}|d krdS |d }| j |dddgdddg}| j |dddgdddg}|d ks`|d krddS |d j|jks|d j|jkrdS dS )	Nr:   r;   Fr   r<   r=   r>   T)r   rD   r   r+   )r   reshape_qkv	reshape_q	reshape_k	reshape_vrK   Zconcat_qkv_pathZ
concat_qkvZreshape_qkv_path_1Zreshape_qkv_path_2rL   rM   rN   rO   rP   concat_v_pathconcat_vZreshape_v_path_1Zreshape_v_path_2concat_k_pathconcat_kZreshape_k_path_1Zreshape_k_path_2Zconcat_q_pathZconcat_qZreshape_q_path_1Zreshape_q_path_2r   r   r   #check_runtime_shape_paths_for_nodes   sV    	

$
$
$z9FusionRotaryAttention.check_runtime_shape_paths_for_nodesc           W         st  |j dkrd S d } j|dddddgdddddg} j|ddddgddddg} j|dddddgdddddg}|d k	r|\}}	}}
}|}nD|d k	r|\}}}}|}n*|d k	r|\}}}}}|}ntd d S d	\}}}d }d } j|ddd
dddgddddddg} j|d
dddgddddg} j|dddgdddg} jj|dddd
dddgdddddddgfdddddd
dddd
dddgdddddddddddddgfddddddddd
dddd
dddgddddddddddddddddgfddddddd
dddd
dddgddddddddddddddgfddddd
dddd
dddgddddddddddddgfdd
dddd
dddg	dddddddddg	fdd
ddddd
dddg
ddddddddddg
fdd
dddd
dddg	dddddddddg	fdd
dddd
dddg	dddddddddg	fg	d d\}}} j|d
ddddgdddddg}|d k	r|\}}}}}}|} j|ddgddg}|d krtd d S |d jd }|d jd }|jd }n|d k	r|\}}}}|}|jd }|jd }n|d k	r|\}}}|}|jd }n|d k	rbt|dkrb|d dd  \}}}}|}|jd }|jd }nD|d k	r|\}}}}}|}|}|jd }|jd }ntd d S  j|ddddgddddg}d \}} |d k	r|\}}}} ntd! d S d"\}!}" j|d
ddgdddg}# j|d#d
ddgddddg}$ j|ddd$d#dddgdddddddg}% j|dd$d#dddgddddddg}& j|dddd$d#dddgddddddddg}' j|ddd$d#dddgdddddddg}( j|dd#dd#d$d#dddg	dddddddddg	})|#d k	r<|#\}}*}+|*jd }!n|$d k	r^|$\}}}*}+|*jd }!n|%d k	r~ 	|%d jd }"n|&d k	r 	|&d jd }"nb|'d k	r|'d jd }"nH|(d k	r|(d jd }"n.|)d k	r 	|)d jd }"ntd% d S d"\},}-d }.d }/d }0 j| ddd
dd&dgddddddg}1 j| dd&dddgdddddg}2 j| dd
d&dddgddddddg}3 jj| ddddd
d&dddg	dddddddddg	fddddddd
dddd
d&dddgdddddddddddddddgfdddddddddd
dddd
d&dddgddddddddddddddddddgfdddddddd
dddd
d&dddgddddddddddddddddgfdddddd
dddd
d&dddgddddddddddddddgfddd
dddd
d&dddgdddddddddddgfddd
ddddd
d&dddgddddddddddddgfddd
dddd
d&dddgdddddddddddgfddd
dddd
d&dddgdddddddddddgfg	d d\}}4} j| dd
d
d&dddddg	dddddddddg	}5|1d k	
rb|1\}6}}7}}8}9|1}. j|7ddgddg}:|:d k
r,td' d S |:d jd },|:d jd };|7jd }-||;ks`t
n|2d k	
r|2\}}8}}<}9|2}.|8jd }-n|3d k	
r|3\}}7}8}}<}9|3}.|7jd },|7jd }-n|4d k	rt|4dkr|4d d(d  \}<}9|4d d)d* \}7}8|4}.|7jd },|7jd }-nH|5d k	rR|5\	}}7}0}8}/}}<}}9|5}.|7jd },|7jd }-ntd+ d S d }=d }>d }? j| ddd&dgddddg}@ j| d&dddgddddg}A j| d
d&dddddgdddddddg}B|@d k	r|@\}C}}D}E|@}=nL|Ad k	r|A\}D}}F}E|A}=n0|Bd k	r2|B\}?}D}>}}F}}E|B}=ntd, d S |Ejd |9jd krz|9jd |jd krztd- d S d.}G||kr |	|
|C|6||||Ejd std/ d S |	jd }Gn|||fkrr ||F|<||Ejd std/ d S |jd }G|>r|>jd n|Ejd |Djd< |/r2|/jd n|9jd |8jd< |?d kr\|8jd0 |8jd< ||krr|dd  } fd1d2}H|?r@|0r@ jd}I|Id0 }Jtjd|0jd g|Jg|Id3}K|Kjtd4ddddgg  jd}L|Ld0 }Mtjd|?jd g|Mg|Ld3}N|Njtd4ddddgg |H|<}O|Od krFtd5 d S  jjdd6d7}Ptjd|Kjd |Ojd g|Pd0 g|Pd3}Q jjdd8d7}Rtjd|Njd |Ojd g|Rd0 g|Rd3}S|Q}8|S}D j|O  j|K  j|N  j|Q  j|S  j j|Oj<  j j|Kj<  j j|Nj<  j j|Qj<  j j|Sj<  |Ejd |G|D|8||!|"|,||-|}T|Td kr|td9 d S  j|T  j j|Tj<  j|dd   ||kr j|d kr|d d n
|d d(  n&|d d g}U|D ]}V |V|U q j| |.|1kr2 j|.d d(  n|.|2krn j|.d   j|.d   j|.d  n|.|3kr j|.d   j|.d   j|.d   j|.d  nf|.|5kr j|.d   j|.d  n:|.|4kr |.d d |.d d g}U|.D ]}V |V|U q|=|@kr@ j|=d d(  n*|=|Akrj j|=d   j|=d  d: _d S );N>   r   r   r   MatMulReshape	Transposer;   r   Z	AllReducez0fuse_rotary_attention: failed to match qkv nodes)r   r   r   r:   ZExpandr<   ZWhereZEqualr=   r>   r@   ZConstantOfShape   r?      )output_name_to_noder   rA   zDfuse_rotary_attention: failed to match past/present concat in v path	   z-fuse_rotary_attention: failed to match v pathZSoftmaxDivNNz/fuse_rotary_attention: failed to match qk nodes)r   r   rB   Subz;fuse_rotary_attention: failed to match attention mask nodesRotaryEmbeddingzDfuse_rotary_attention: failed to match past/present concat in k pathz.fuse_rotary_attention: failed to match k nodesz.fuse_rotary_attention: failed to match q nodeszKfuse_rotary_attention: failed to find the same root_input for q, k, v pathsr   z;fuse_rotary_attention: failed to verify runtime shape pathsZ	_output_0c           
         s   j | dd}|dkr&td dS  j |jd } j |jd }|dksZ|dkrhtd dS |d }|d }|| } j jd	d
d} j |dkr j|t	j
dg|gdd  j jddd}tjd|jd |jd |g|d g|d}	|	jtddg |	S )zDetect num_heads and hidden_size for ONNX model from phi-2
            Args:
                reshape_q (NodeProto): reshape node for q
            Returns:
                hidden_size_concat_node(NodeProto): Concat node to be used by reshape
            r:   r;   NzEfuse_rotary_attention: failed to trace the concat node from reshape_qr?   r^   zMfuse_rotary_attention: failed to get constant nodes of num_heads or head_sizer   ZInitializerr   Zname_prefixF)r+   	data_typedimsvalsrawZhidden_size_concatZoutput_0r(   Zaxis)r   Zmatch_parentr/   r0   Zget_constant_valuer   r1   get_initializeradd_initializerr   ZINT64r	   r3   r5   r2   r6   )
rS   concatZnum_head_constant_nodeZhead_size_constant_nodeZnum_head_valueZhead_size_valuer   Zhidden_size_initilizerZhidden_size_reshape_node_namehidden_size_concat_noder   r   r   create_hidden_size_concat_node  sB    


zBFusionRotaryAttention.fuse.<locals>.create_hidden_size_concat_noder(   permz?fuse_rotary_attention: failed to create hidden_size_concat_nodeconcat_k_halfrk   concat_q_halfzSfuse_rotary_attention: failed to create multi-head attention with rotary embeddingsT)op_typer   rD   r/   r0   Zmatch_parent_paths_allr   r   lenZreshape_add_qkr.   rQ   rZ   r+   r1   r	   r3   r5   r2   r6   nodes_to_addappendthis_graph_namenode_name_to_graph_namer9   nodes_to_removeZ&add_nodes_to_remove_with_nodes_to_keepprune_graph)Wr   Znormalize_nodeinput_name_to_nodesr`   Z	qkv_nodesZqkv_nodes_1Zqkv_nodes_2Zqkv_nodes_3rL   rE   rF   Z
matmul_qkvrR   r#   r%   Zpast_seq_lenZv_nodesZadd_vZ	v_nodes_1Z	v_nodes_2Z	v_nodes_3Z	v_nodes_4Z	v_nodes_5rI   rW   rJ   Zmatmul_vrV   Ztranspose_vrU   Zqk_nodesr!   Z	matmul_qkr    Z
add_qk_strZattn_mask_nodes_1Zattn_mask_nodes_2Zattn_mask_nodes_3Zattn_mask_nodes_4Zattn_mask_nodes_5Zattn_mask_nodes_6Zattn_mask_nodes_7Zslice_mask_1Zslice_mask_2r"   r$   Zk_nodesZslice_krw   Z	k_nodes_1Z	k_nodes_2Z	k_nodes_3Z	k_nodes_4Z	k_nodes_5rH   rY   Zrotary_kZmatmul_krX   Zshared_past_seq_lenrT   Zq_nodesZslice_qrx   Z	q_nodes_1Z	q_nodes_2Z	q_nodes_3rG   Zrotary_qZmatmul_qrS   Zroot_outputru   Zk_transpose_node_nameZk_tranpose_output_nameZk_transpose_nodeZq_transpose_node_nameZq_tranpose_output_nameZq_transpose_noders   Zconcat_k_reshape_node_nameZconcat_k_reshape_nodeZconcat_q_reshape_node_nameZconcat_q_reshape_nodenew_nodeZnodes_to_keepZ	temp_pathr   rt   r   fuseF  s   






"lp

























 &"  %  )


















,





  

5






,






zFusionRotaryAttention.fuse)r   r   r   r   r   r   N)__name__
__module____qualname____doc__r   intr   strr   r   r7   r   r9   rQ   rZ   r   __classcell__r   r   r   r   r      s>          
8 Ir   c                       s^   e Zd Zed fddZeedddZeddd	Ze	e	e	e	e	d
ddZ
dd Z  ZS )FusionRotaryEmbeddings)r   c                    s*   d| _ t || j | j | j d dg d S )Nrg   z.1r   )	base_namer   r   )r   r   r   r   r   r   U  s    zFusionRotaryEmbeddings.__init__)rot_emb_nodefunctionc                    s   g g  }}|j D ]X}|jdkr|jg kr|jd |jkr|| t|j|jd }||j|  qg }|D ]6}|jd j}	| j	
d|	_| j	|	 ||	j qrt||D ]>\ }
tt fdd| j	j	jj }|D ]}t| |
 qq|S )NConstantr   c                    s
    | j kS N)r   )entryZextra_outputr   r   <lambda>o      z?FusionRotaryEmbeddings.reassign_extra_outputs.<locals>.<lambda>)nodery   r   r   r|   listindexr5   tr   r1   r+   rq   zipfiltergraphr   Zreplace_node_input)r   r   r   Zextra_constantsextra_outputsZfn_nodeZoutput_indexZextra_initializersZextra_constantZconstant_tensorprotoZextra_initializerZnodes_to_updateZnode_to_updater   r   r   reassign_extra_outputs\  s"    

$
z-FusionRotaryEmbeddings.reassign_extra_outputsr   c                    sB  | j | j}| j ddgddg}|d k	r8|\}}ntd d S |jd jd g}tt	fdd| j j j
j}tt	fdd| j j j
j}d	\}	}
t|dkrt|dkr| j |	d kr| j |
d krt|d jd j }t|d jd j }tj|	tjt|j|  d
}| j || j tj|
tjt|j|  d
}| j || j | j|d |d g ||	|
g j}t|dkrtt	fdd| j j j}t|dkst|  |d  tt	 fdd|}t|dksttj!| j|||dd}d|_"| j#| |S )Nr\   r[   r   z.fuse_rotary_embeddings: failed to match MatMulr;   c                    s   | j d  jd kS )Nr   r?   r   r   Zconstantr   r   r   r     r   zOFusionRotaryEmbeddings.create_rotary_embeddings_from_function.<locals>.<lambda>c                    s   | j d  jd kS )Nr   r^   r   r   r   r   r   r     r   	cos_cache	sin_cacher+   rl   rm   rn   c                    s   | j  jkS r   )r+   ry   )fnr   r   r   r     r   c                    s   |  kS r   r   )Zoutput_name)r   r   r   r     r   r)   r*   r+   Zinterleavedr,   )$r   r1   r   rD   r/   r0   r   r   r   r   r   r   rz   rp   r
   to_arrayr5   r   squeezer	   make_tensorr   FLOATshapeflattentolistrq   r}   r   r2   Z	functionsr.   r   r3   r4   r|   )r   r   rotary_emb_node_nameZmatmul_pathZreshape_nodeZmatmul_nodeZrotary_emb_inputscos_cache_nodesin_cache_nodecos_cache_namesin_cache_namer   r   cos_cache_tensorsin_cache_tensorZrotary_emb_outputsfuncrotary_emb_noder   )r   r   r   &create_rotary_embeddings_from_functionu  sv    





z=FusionRotaryEmbeddings.create_rotary_embeddings_from_function)rK   position_ids	cos_slice	sin_slicer   c                    s  | j | j}tt fdd| j j jj}ttfdd| j j jj}d\}	}
t|dkr|t|dkr|| j |	d kr|| j |
d kr|t	
|d jd j }t	
|d jd j }|jd }|d d d |d f }|d d d |d f }tj|	tjt|j|  d}| j || j tj|
tjt|j|  d}| j || j | j|d |d g tj| j|||	|
g|g|dd	}d
|_|S )Nc                    s   | j d  kS Nr   r   r   )r   r   r   r     r   zLFusionRotaryEmbeddings.create_rotary_embeddings_from_nodes.<locals>.<lambda>c                    s   | j d  kS r   r   r   )r   r   r   r     r   r   r;   r   r?   r   r   r,   )r   r1   r   r   r   r   r   rz   rp   r
   r   r5   r   r   r   r	   r   r   r   r   r   rq   r}   r   r2   r3   r4   )r   rK   r   r   r   r   r   r   r   r   r   r   r   Z	head_sizer   r   r   r   )r   r   r   #create_rotary_embeddings_from_nodes  sR    





z:FusionRotaryEmbeddings.create_rotary_embeddings_from_nodesc           %         s*	  | j |jkr|jdkrd S d  |jdkrt|jdksD|jd dkrRtd d S | |  d krrtd d S | j| t	t
 fdd| jjjj}t|dkst| jjjj|d	  n4| j|d
ddddgdd	d	d	d	g}| j|d
ddddgdd	d	d	d	g}|p|}| j|d
ddddddddg	dd	d	d	dd	d	d	d	g	}| j|d
ddddddddg	dd	d	d	dd	d	d	d	g	}	|p|	}
|d ks|
d krtd d S | j|d
dddgdd	dd	g}| j|d
dddgdd	dd	g}|p|}| j|d
dddddddgdd	ddd	d	d	d	g}| j|d
dddddddgdd	ddd	d	d	d	g}|p^|}|d kst|d krtd d S |d j|d jks|d j|
d jks|d j|d jks|d j|
d jkrtd d S | j|d
dgd	d	g}| j|d
dgd	d	g}|p |}|d kr:td d S d\}}}| j|d
ddddddddg	ddd	d	d	d	dd	d	g	}| j|d
dddddddgddd	d	d	d	dd	g}| j|d
ddddddgddd	d	dd	d	g}| j|d
dddddgddd	d	dd	g}|d k	r|}|d jd	 }n|d k	r8|}|d jd	 }nf|d k	rd|}|d jd	 }|d jd }n:|d k	r|}|d jd	 }|d jd }ntd d S d\}}| j|d
ddddddddg	d	dd	d	d	d	dd	d	g	}| j|d
dddddddgd	dd	d	d	d	dd	g}| j|d
ddddddgd	dd	d	dd	d	g}| j|d
dddddgd	dd	d	dd	g} |d k	r||}|d jd	 }n|d k	r|}|d jd	 }nf|d k	r|}|d jd	 }|d jd }n:| d k	r| }|d jd	 }|d jd }ntd d S |dkr| j|d d gdg}!| j|d d gdg}"|!d ksd|"d ksd|!d	 j|"d	 jkrrtd! d S |"d	 jd	 }ng }!g }"d"\}#}$||kr||ks||kr||kr|d# j|d# jks|d j|d jkrtd$ d S n||kr
||ks||kr|| kr|d j|d jkrBtd% d S | j|d ddgdd	g}#| j|d dddgd	d	d	g}$|#d ks|$d ks| j|#d jd	 d ks|$d jdkrtd& d S n
td' | |d jd	 ||||jd	   d krtd d S | |g | |d d  | |d d  | |d d  | |
d d  | |d d  | | | | | |!d d  | |"d d  |#d k	rt| j|#d	 dkr| |# |$d k	r| |$d d  | | j  | j| j j< | j  d(| _d S ))Nr   >   r_      r;   >   pos_idsr   position_idpos_idposzLfuse_rotary_embeddings: failed to verify inputs for RotaryEmbedding functionz=fuse_rotary_embeddings: failed to create RotaryEmbedding nodec                    s   | j  jd kS r   )r+   r   r   r   r   r   r     r   z-FusionRotaryEmbeddings.fuse.<locals>.<lambda>r   r@   r:   ZNegrA   r]   r<   rd   r=   r>   z9fuse_rotary_embeddings: failed to match x2 in rotate_halfr?   z9fuse_rotary_embeddings: failed to match x1 in rotate_halfra   zCfuse_rotary_embeddings: failed to match common input in rotate_halfz8fuse_rotary_embeddings: failed to match x in rotate_half)Nr   r   ZSqueezerc   rj   z>fuse_rotary_embeddings: failed to match sin path in apply_rope)Nr   r   r\   zGfuse_rotary_embeddings: failed to match position ids path in apply_ropere   rh   zdfuse_rotary_embeddings: failed to match common Gather node and Shape node in sin cache and cos cachezRfuse_rotary_embeddings: failed to match common Add node in sin cache and cos cachezKfuse_rotary_embeddings: failed to match past_seq_len and curr_seq_len pathsz:fuse_rotary_embeddings: failed to match common cache pathsT)r   ry   rz   r   r/   r0   r   r   r|   r   r   r   r   Z
value_infor.   removerD   r+   Zfind_graph_inputr   r   Zadd_nodes_to_removeZget_childrenr8   r}   r~   r{   r   )%r   r   r   r`   Zold_shape_inferZrotate_half_x2_path_1_1Zrotate_half_x2_path_1_2Zrotate_half_x2_path_1Zrotate_half_x2_path_2_1Zrotate_half_x2_path_2_2Zrotate_half_x2_path_2Zrotate_half_x1_path_1_1Zrotate_half_x1_path_1_2Zrotate_half_x1_path_1Zrotate_half_x1_path_2_1Zrotate_half_x1_path_2_2Zrotate_half_x1_path_2Zx_path_1Zx_path_2Zx_pathZsin_pathr   r   Z
sin_path_1Z
sin_path_2Z
sin_path_3Z
sin_path_4Zcos_pathr   Z
cos_path_1Z
cos_path_2Z
cos_path_3Z
cos_path_4Zposition_ids_from_sin_pathZposition_ids_from_cos_pathZpast_seq_len_pathZcurr_seq_len_pathr   r   r   r     s   






























,






$

zFusionRotaryEmbeddings.fuse)r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   T  s   L8r   )loggingtypingr   r   Zfusion_attentionr   Zfusion_baser   Zonnxr   r   r   r	   r
   Z
onnx_modelr   	getLoggerr   r/   r   r   r   r   r   r   <module>   s    
        L