U
    T?hL                     @   sx   d dl mZ d dlmZmZ d dlZd dlmZ d dl	m
Z
 d dlmZmZmZ d dlmZ eeZG dd	 d	eZdS )
    )	getLogger)TupleUnionN)Fusion)NumpyHelper)	NodeProtoTensorProtohelper)	OnnxModelc                
       s   e Zd ZdZeeeeeed fddZd&eeedddZ	d	d
 Z
d'eeeeeef dddZeeeeeeeeedf dddZeeeeeeeeedf dddZdd Zdd Zdd Zdd Zdd Zedd d!Zd"d# Zd$d% Z  ZS )(FusionAttentionUnetzB
    Fuse Attention subgraph of UNet into one Attention node.
    )modelhidden_size	num_headsis_cross_attentionenable_packed_qkvenable_packed_kvc                    sL   t  ||r|rdnddg || _|| _|| _|| _|| _d| _d| _d S )N	AttentionMultiHeadAttentionZLayerNormalizationT)	super__init__r   r   r   r   r   num_heads_warninghidden_size_warning)selfr   r   r   r   r   r   	__class__ `/var/www/html/venv/lib/python3.8/site-packages/onnxruntime/transformers/fusion_attention_unet.pyr      s    	zFusionAttentionUnet.__init__F)	reshape_q	is_torch2returnc                 C   s   d}|rj| j |d}|r|jdkrt|jdkr| j |jd }t|tjrt	|j
dgkrt|}n:| j |jd }t|tjrt	|j
dgkrt|d }t|tr|dkr|S dS )zDetect num_heads from a reshape node.

        Args:
            reshape_q (NodeProto): reshape node for Q
            is_torch2 (bool): graph pattern is from PyTorch 2.*
        Returns:
            int: num_heads, or 0 if not found
        r      Concat      )r   Z
get_parentop_typeleninputZget_constant_value
isinstancenpZndarraylistshapeint)r   r   r   r   Zreshape_parentZq_shape_valuer   r   r   get_num_heads3   s    	
z!FusionAttentionUnet.get_num_headsc                 C   s*   | j |jd }|r&t|jd S dS )zDetect hidden_size from LayerNormalization node.
        Args:
            layernorm_node (NodeProto): LayerNormalization node before Q, K and V
        Returns:
            int: hidden_size, or 0 if not found
        r#   r   )r   get_initializerr&   r   to_arrayr*   )r   layernorm_nodeZlayernorm_biasr   r   r   get_hidden_sizeO   s    z#FusionAttentionUnet.get_hidden_size)r   r/   r   r   c                 C   s   |  ||}|dkr| j}| jdkrT|| jkrT| jrTtd| j d| d d| _| |}|dkrl| j}| jdkr|| jkr| jrtd| j d| d d| _||fS )aF  Detect num_heads and hidden_size.

        Args:
            reshape_q (NodeProto): reshape node for Q
            is_torch2 (bool): graph pattern is from PyTorch 2.*
            layernorm_node (NodeProto): LayerNormalization node before Q, K, V
        Returns:
            Tuple[int, int]: num_heads and hidden_size
        r   z--num_heads is z. Detected value is z. Using detected value.Fz--hidden_size is )r,   r   r   loggerwarningr0   r   r   )r   r   r/   r   r   r   r   r   r   get_num_heads_and_hidden_size\   s"    
z1FusionAttentionUnet.get_num_heads_and_hidden_sizeN)q_matmulk_matmulv_matmulr   r   r&   outputr   c           %   
   C   sH  | j  }|r^|jd |ks6|jd |ks6|jd |krtd|jd |jd |jd  dS nV|jd |ks|jd |jd ks|jd |krtd|jd |jd |jd  dS |dkr|| dkrtd| d|  dS | j|jd }	| j|jd }
| j|jd }|	r*|
r*|s.dS |	j}t|	}t|
}t|}td|j	 d	|j	 d
|j	 d|  |r~|j	|j	ks|j	|j	krdS |j	d }|dkr||krt
d| d| dtt|j	dd }| jr8| jd}|}|}|| }t||||||||||||g||d | }| jjddd}| j|d ||j	d |j	d g|d tjd|jd |d g|d g|d}| j| j|j< | j|d tjdgdd|d|gdd tjd|d |d g|d g|d d}| j| j|j< | j||g | j|||g nBtj|||fdd}d| }| jd }| j|d! |||g|d n| jd}| jr$|j	|j	krdS |j	d }|j	d }||kst|j	d }|j	d }|j	d }||kr||kst|}|}|| }t||||||||g||d" | }| jjdd#d}| j|d ||j	d |j	d g|d tjd|jd |d g|d g|d}| j| j|j< | j|d tjdgdd|d"|gdd tjd|d |d g|d$ g|d d}| j| j|j< | j||g | j||g tj d|gtj!d%} d| }!| j|d& ||!g| d |r| jsz||d! |d& g}"n
|d g}"n>| js|j"d |j"d |j"d |d& g}"n|j"d |d$ g}"tj|r| jsd nd|"|g|d}#d'|#_#|#j$t%d(|g |r| jsd)n d*&| jr(d+n| jr4d,nd-}$| '|$ |#S ).  Create an Attention node.

        Args:
            q_matmul (NodeProto): MatMul node in fully connection for Q
            k_matmul (NodeProto): MatMul node in fully connection for K
            v_matmul (NodeProto): MatMul node in fully connection for V
            num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
            hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
            input (str): input name
            output (str): output name

        Returns:
            Union[NodeProto, None]: the node created or None if failed.
        r   RFor self attention, input hidden state for q and k/v shall be same. Got %s, %s, %sNXFor cross attention, input hidden state for q and k/v shall be different. Got %s, %s, %sinput hidden size # is not a multiple of num of heads r    qw= kw= vw= hidden_size=Input hidden size (,) is not same as weight dimension of q,k,v (:). Please provide a correct input hidden size or pass in 0r      MatMul
MatMul_QKVZname_prefix_weightname	data_typedimsvals_outinputsoutputsrJ   _reshape_shape   FrJ   rK   rL   rM   rawReshape
_qkv_input_reshape)axisr   Z_qkv_weightr#   	MatMul_KV	_kv_inputZdtype	_qkv_biascom.microsoftr   Attention (self attention)MultiHeadAttention ({})self attention with packed qkvcross attention with packed kvcross attention)(r   r&   r1   debugr   r-   rK   r   r.   r*   
ValueErrorr+   r(   prodr   create_node_namedstackreshapeadd_initializerr	   	make_nodethis_graph_namenode_name_to_graph_namerJ   r   INT64nodes_to_addextendnodes_to_removestackr   AssertionErrorzerosfloat32r7   domain	attributemake_attributeformatincrease_counter)%r   r4   r5   r6   r   r   r&   r7   is_self_attentionq_weightk_weightv_weightZ
float_typeqwkwvw
qw_in_sizeqw_out_sizeattention_node_namecnh
qkv_weightmatmul_node_namematmul_nodereshape_nodeZqkv_weight_dim
kw_in_size
vw_in_sizekw_out_sizevw_out_size	kv_weightqkv_biasqkv_bias_dimattention_inputsattention_nodecounter_namer   r   r   create_attention_node~   sj   *0


(
. 
	





2		
z)FusionAttentionUnet.create_attention_node)q_matmul_addk_matmul_addv_matmul_addr   r   r&   r7   r   c           F   
   C   s&  | j  }| j|dd}	| j|dd}
| j|dd}| |}|dkrNdS |\}}| |}|dkrldS |\}}| |}|dkrdS |\}}|r@|	jd |ks|
jd |ks|jd |krtd|	jd |
jd |jd  dS |jd |ks|jd |ks|jd |krtd|jd |jd |jd  dS n|	jd |ksv|
jd |jd ksv|
jd |krtd|	jd |
jd |jd  dS |jd |ks|jd |jd ks|
jd |krtd|jd |jd |jd  dS |dkr*|| dkr*td| d	|  dS | j|	jd
 }| j|
jd
 }| j|jd
 }|rr|rr|svdS |jdkrtd dS t	
|}t	
|}t	
|}td|j d|j d|j d|  |r|j|jks|j|jkrdS |jd }|dkr0||kr0td| d| dtt|jd
d }| jr| jd}|}|}|| } t|||| |||| |||| g||d |  }!| jjddd}"| j|"d tj|!jd |!jd
 g|!d tjd|
jd |"d g|"d g|"d}#| j| j|#j< |jd }$| j|$tjdgdd|| gdd | jjd d!d}%tjd |jd |$g|%d g|%d}&| j| j|&j< | jjd d"d}'tjd |jd |$g|'d g|'d}(| j| j|(j< | jjd d#d})tjd |jd |$g|)d g|)d}*| j| j|*j< | jjd$d%d}+tjd$|&jd |(jd |*jd g|+d g|+d},|,jt d&dg | j| j|,j< |,jd }-| j|-tjdgdd|d |  gdd | jjd d'd}.tjd |,jd |-g|.d g|.d}/| j| j|/j< | jjd(d)d}0tjd(|/jd |#jd g|0d g|0d}1| j| j|1j< |0d }2| j|2tjd*gdd|d| gdd tjd |1jd |2g|d+ g|0d, d}3| j| j|3j< | j!|#|&|(|*|,|/|1|3g | j"|	|
||||g ndS nz| jd}| j#
r,|j|jkrdS |jd }4|jd }5|4|5kst$|jd
 }|jd
 }6|jd
 }7||7kr0|6|7ks4t$|4}|}|6| } t|||| |||| g||d- |  }8| jjdd.d}"| j|"d tj|8jd |8jd
 g|8d tjd|
jd |"d g|"d g|"d}#| j| j|#j< |jd }9| j|9tjdgdd|| gdd | jjd d"d}'tjd |jd |9g|'d g|'d}(| j| j|(j< | jjd d#d})tjd |jd |9g|)d g|)d}*| j| j|*j< | jjd$d/d}:tjd$|(jd |*jd g|:d g|:d};|;jt d&dg | j| j|;j< |;jd }<| j|<tjdgdd|d- |  gdd | jjd d0d}=tjd |;jd |<g|=d g|=d}>| j| j|>j< | jjd(d1d}?tjd(|>jd |#jd g|?d g|?d}@| j| j|@j< |?d }2| j|2tjd*gdd|d-| gdd tjd |@jd |2g|d2 g|?d, d}3| j| j|3j< | j!|#|(|*|;|>|@|3g | j"|
|||g ndS tj%d|gtj&d3}Ad| }B| j|d4 tj|Bg|Ad |
r| j
sxdS |d+ g}Cn| j#
sdS |jd |d2 g}Ctj|
r| j
sd5nd|C|g|d}Dd6|D_'|Djt d7|g |
r| j
sd8n d9(| jrd:n| j#rd;nd<}E| )|E |DS )=r8   rE   r   Nr9   z_For self attention, input hidden state for LoRA q and k/v weights shall be same. Got %s, %s, %sr:   zeFor cross attention, input hidden state for LoRA q and k/v weights shall be different. Got %s, %s, %sr;   r<   r    
   zBweights are in fp16. Please run fp16 conversion after optimizationr=   r>   r?   r@   rA   rB   rC   r   rD   rF   rG   rH   rI   rN   rO   rR   r"   FrT   rV   ZReshape_LoRA_QZReshape_LoRA_KZReshape_LoRA_Vr!   ZConcat_LoRA_QKVrY   ZReshape_LoRA_QKVAddZAdd_Weights_QKVrS   rW   rX   r#   rZ   ZConcat_LoRA_KVZReshape_LoRA_KVZAdd_Weights_KVr[   r\   r]   r   r^   r   r_   r`   ra   rb   rc   )*r   r   match_parentmatch_lora_pathr&   r1   rd   r-   rK   r   r.   r*   re   r+   r(   rf   r   rg   rh   ri   rj   r   FLOATr	   rk   rl   rm   rJ   rn   r7   rw   rp   rx   ro   rq   r   rs   rt   ru   rv   ry   rz   )Fr   r   r   r   r   r   r&   r7   r{   r4   r5   r6   Zq_lora_nodesZq_lora_last_nodeZq_lora_matmul_1Zk_lora_nodesZk_lora_last_nodeZk_lora_matmul_1Zv_lora_nodesZv_lora_last_nodeZv_lora_matmul_1r|   r}   r~   r   r   r   r   r   r   r   r   r   r   r   r   Zlora_weight_shape_tensor_nameZq_lora_reshape_node_nameZq_lora_reshape_nodeZk_lora_reshape_node_nameZk_lora_reshape_nodeZv_lora_reshape_node_nameZv_lora_reshape_nodeZqkv_lora_concat_node_nameZqkv_lora_concat_nodeZ'reshaped_lora_weights_shape_tensor_nameZqkv_lora_reshaped_node_nameZqkv_lora_reshaped_nodeZadd_weights_node_nameZadd_weights_nodeZshape_tensor_namer   r   r   r   r   r   Z kv_lora_weight_shape_tensor_nameZkv_lora_concat_node_nameZkv_lora_concat_nodeZ*reshaped_kv_lora_weights_shape_tensor_nameZkv_lora_reshaped_node_nameZkv_lora_reshaped_nodeZadd_kv_weights_node_nameZadd_kv_weights_noder   r   r   r   r   r   r   r   create_attention_node_lora~  s   


*6	



(
. 


	






2

	
	
z.FusionAttentionUnet.create_attention_node_lorac              	   C   s  |  |||rd S | j|dd}|d kr@| js@| j|dd}|d krLd S |jd }|| }d }|D ]}|jdkrf|} q~qf|d krd S | ||p| ||}	|	d k	r|	\}
}}}}}}|}| |||
\}}|dkrt	
d d S | j||||||jd |jd d}|d krd S n| ||p6| ||}	|	d krFd S |	\}
}}}}}}|}| |||
\}}|dkrt	
d d S | j||||||jd |jd d}|d krd S | |||
\}}|dkrt	
d d S | j| | j| j|j< | j||g d| _d S )Nr   r   rV   *fuse_attention: failed to detect num_headsr&   r7   T)fuse_a1111_fp16r   r   r   r7   r$   match_qkv_torch1match_qkv_torch2r3   r1   rd   r   match_qkv_torch1_loramatch_qkv_torch2_lorar   ro   appendrl   rm   rJ   rq   rp   prune_graph)r   normalize_nodeinput_name_to_nodesoutput_name_to_nodenode_before_layernorm
root_inputchildren_nodesskip_addnode	match_qkvr   reshape_qkvtranspose_qkvr   matmul_qmatmul_kmatmul_vattention_last_nodeq_num_headsq_hidden_sizenew_nodematmul_add_qmatmul_add_kmatmul_add_vr   r   r   fuseR  s    



	
 


	


zFusionAttentionUnet.fusec              
   C   s  |j d |krdnd}| j|ddddddg|dddddg}|dkrJdS |\}}}}}}| j|ddddgddddg}	|	dkrtd dS |	\}}}}
| j|d	d
dgdddg}|dk	r|\}}}nF| j|d	dd
dgddddg}|dk	r|\}}}}ntd dS | j|ddddgddddg}|dkrJtd dS |\}}}}| j|dddddgdddddg}|dkrtd dS |\}}}}}d||||||
fS )z.Match Q, K and V paths exported by PyTorch 1.*r   r    r   rE   rV   	TransposeN&fuse_attention: failed to match v pathSoftmaxMul'fuse_attention: failed to match qk path&fuse_attention: failed to match q path&fuse_attention: failed to match k pathFr&   r   match_parent_pathr1   rd   )r   r   r   another_input	qkv_nodes_r   r   
matmul_qkvv_nodesr   qk_nodes_softmax_qk_mul_qk	matmul_qk	_add_zeroq_nodes_transpose_qr   r   k_nodesr   r   r   r   r     sJ     
 

 

  

z$FusionAttentionUnet.match_qkv_torch1c                 C   s  |j d |krdnd}| j|dddddg|ddddg}|dkrFdS |\}}}}}| j|dddgdddg}	|	dkrtd dS |	\}}}
| j|d	dgddg}|dk	r|\}}ntd
 dS | j|ddddgddddg}|dkrtd dS |\}}}}| j|ddddgddddg}|dkrBtd dS |\}}}}| j|ddddddddgddddddddg}|dks|d |krtd dS d||||||
fS )z.Match Q, K and V paths exported by PyTorch 2.*r   r    r   rE   rV   r   Nr   r   r   r   r   r   SqrtDivCastSliceShapez*fuse_attention: failed to match mul_q pathTr   )r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   mul_qr   r   r   r   _mul_kr   mul_q_nodesr   r   r   r     sL    



 
 


z$FusionAttentionUnet.match_qkv_torch2c                 C   s  |j d |krdnd}| j|dddddddg|ddddddg}|dkrNdS |\}}}}}}}| j|ddddgddddg}	|	dkrtd dS |	\}}}}
| j|d	d
dgdddg}|dk	r|\}}}nF| j|d	dd
dgddddg}|dk	r
|\}}}}ntd dS | j|ddddgddddg}|dkrPtd dS |\}}}}| j|dddddgdddddg}|dkrtd dS |\}}}}}d||||||
fS )zJMatch Q, K and V paths exported by PyTorch 1 that contains LoRA patterns.*r   r    r   rE   rV   r   N+fuse_attention: failed to match LoRA v pathr   r   ,fuse_attention: failed to match LoRA qk path+fuse_attention: failed to match LoRA q path+fuse_attention: failed to match LoRA k pathFr   )r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r     sJ     
 

 

  

z)FusionAttentionUnet.match_qkv_torch1_lorac                 C   s  |j d |krdnd}| j|ddddddg|dddddg}|dkrJdS |\}}}}}}| j|dddgdddg}	|	dkrtd dS |	\}}}
| j|d	dgddg}|dk	r|\}}ntd
 dS | j|ddddgddddg}|dkrtd dS |\}}}}| j|ddddgddddg}|dkrJtd dS |\}}}}| j|ddddddddgddddddddg}|dks|d |krtd dS d||||||
fS )zJMatch Q, K and V paths exported by PyTorch 2 that contains LoRA patterns.*r   r    r   rE   rV   r   Nr   r   r   r   r   r   r   r   r   r   r   r   z/fuse_attention: failed to match LoRA mul_q pathTr   )r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   A  sL    



 

 


z)FusionAttentionUnet.match_qkv_torch2_lora)add_nodec                 C   s   | j |ddgddg}|d k	r0|\}}||fS | j |dddgdddg}|d k	rf|\}}}||fS | j |ddddgddddg}|d k	r|\}}}}||fS d S )NrE   r    r   r   )r   r   )r   r   Z
lora_nodesZlora_matmul_2_nodeZlora_matmul_1_nodeZlora_mul_noder   r   r   r   r   t  s2    



z#FusionAttentionUnet.match_lora_pathc              	   C   s  | j |ddgddg}|dkrD| j |ddgddg}|dkrDdS |\}}|jd }|| }d}	|D ]}
|
jdkrf|
}	 q~qf|	dkrdS | ||	}|dkrdS |\}}}}}}| j |dd}| j |dd}| j |dd}|dk	r|dk	r| js
||krn
||kr||ks"dS |jd |jd kr<dS |}| |dpX| |d}|dkrrt	
d dS | |}| j||||||jd |jd d	}|dkrdS | j| | j| j|j< | j||g d| _dS )
zPFuse attention of fp16 UNet exported in A1111 (stable diffusion webui) extensionr   r   r   NrV   FTr   r   )r   r   r7   r$   match_qkv_a1111r   r   r&   r,   r1   rd   r0   r   ro   r   rl   rm   rJ   rq   rp   r   )r   r   r   r   
entry_pathZ_castr   r   r   r   r   r   r   r   r   r   r   r   Zcast_qZcast_kZcast_vr   r   r   r   r   r   r   r     s    

	


	
z#FusionAttentionUnet.fuse_a1111_fp16c              
   C   s~  |j d |krdnd}| j|ddddddg|dddddg}|dkrJdS |\}}}}}}	| j|	ddddgddddg}
|
dkrtd	 dS |
\}}}}| j|	d
d
dddgdddddg}|dk	r|\}}}}}ntd dS | j|ddddgddddg}|dkrtd dS |\}}}}| j|ddddgddddg}|dkrbtd dS |\}}}}||||||fS )zKMatch Q, K and V paths exported by A1111 (stable diffusion webui) extensionr   r    r   rE   rV   r   ZEinsumNr   r   r   r   r   r   r   r   )r   r   r   r   r   r   r   r   Zreshape_einsumZ
einsum_qkvr   r   r   r   Z	einsum_qkr   r   r   r   r   r   r   r   r   r     sD     
  
 

 

z#FusionAttentionUnet.match_qkv_a1111)F)F)__name__
__module____qualname____doc__r
   r+   boolr   r   r,   r0   r   r3   strr   r   r   r   r   r   r   r   r   r   r   __classcell__r   r   r   r   r      sb      
$
  
   WZ1405,Nr   )loggingr   typingr   r   numpyr(   Zfusion_baser   Zfusion_utilsr   Zonnxr   r   r	   Z
onnx_modelr
   r   r1   r   r   r   r   r   <module>   s   