U
    Mâh¾Ù  ã                9   @  s˜  d Z ddlmZ ddlZddlZddlZddlmZmZ ddl	Z	ddl	m
Z
 ddlmZ ddlmZmZmZmZmZmZ ddlmZmZmZ d	d
dddddddddddddddddddddd d!d"d#d$d%d&d'd(d)d*d+d,d-d.d/d0d1d2d3d4d5d6d7d8d9d:d;d<d=d>d?d@dAg9ZejejdBdCZedDƒe dE¡e  dFdGdG¡ej!dHdIdJdJdKœdLd„ƒƒƒƒZ"edMƒej!dHdNœdOd„ƒƒZ#edPƒe  dFdF¡ej!dHdNœdQd„ƒƒƒZ$edRƒe  dFdF¡ej!dHdNœdSd„ƒƒƒZ%edTƒej!dHdNœdUd1„ƒƒZ&edVƒe dE¡e  dFdWdF¡ej!dHdNœdXd6„ƒƒƒƒZ'edYƒej!ddHdNœd[d"„ƒƒZ(ed\ƒe  dFdW¡ej!dHdNœd]d-„ƒƒƒZ)ed^e *d_d`da¡gdbedce *dddeda¡gdbedfe *dgdhda¡gdbedie *djd`dk¡gdbedle *dmdedk¡gdbedne *dodhdk¡gdbedpe *dqdedr¡gdbej!dsdtdsduœdvdw„ƒƒƒƒƒƒƒƒZ+edxƒe dEdZdZdZdZdZdZ¡ej!dHdNœdydz„ƒƒƒZ,ed{ƒe  dFdWdFdF¡ej!ddHdNœd|d„ƒƒƒZ-ed}ƒe  dFdWdFdF¡ej!dHdNœd~d5„ƒƒƒZ.edƒe  dFdWd€¡ej!ddHdNœdd„ƒƒƒZ/ed‚ƒej!dHdNœdƒd(„ƒƒZ0ed„ƒej!dHdNœd…d'„ƒƒZ1ed†ƒej!dHdNœd‡dˆ„ƒƒZ2ed‰ƒej!dHdNœdŠd‹„ƒƒZ3edŒƒej!dHdNœddŽ„ƒƒZ4edƒej!dHdNœdd
„ƒƒZ5ed‘ƒej!ddHdNœd’d	„ƒƒZ6ed“ƒej!dHdNœd”d#„ƒƒZ7ed•ƒej!dHdNœd–d.„ƒƒZ8ed—ƒej!dHdNœd˜d„ƒƒZ9ed™ƒe dE¡ej!dHdNœdšd„ƒƒƒZ:ed›ƒej!dHdNœdœd<„ƒƒZ;edƒe  dFdWdWdW¡ej!dHdNœdždŸ„ƒƒƒZ<ed ƒe  dFdWdWdWdW¡ej!dHdNœd¡d?„ƒƒƒZ=ed¢ƒe  dFdFdWdWdWd€¡ej!ddHdNœd£d=„ƒƒƒZ>ed¤ƒe  dFdWdWd€¡ej!ddHdNœd¥d8„ƒƒƒZ?ed¦ƒe  dFdWdWd€¡ej!ddHdNœd§d„ƒƒƒZ@ed¨ƒe  dFdW¡ej!ddHdNœd©d4„ƒƒƒZAedªƒej!dHdNœd«d2„ƒƒZBed¬ƒe  dFdFdWdW¡ej!ddHdNœd­d:„ƒƒƒZCed®ƒe  dFdFdWdW¡ej!ddHdNœd¯d9„ƒƒƒZDed°ƒe  dFdWdW¡ej!ddHdNœd±d>„ƒƒƒZEej!dHdNœd²d³„ƒZFed´ƒej!ddHdNœdµd„ƒƒZGed¶ƒed·ƒed¸ƒej!dHdNœd¹d0„ƒƒƒƒZHedºƒed»ƒed¼ƒej!dHdNœd½d3„ƒƒƒƒZIed¾ƒej!dHdIdIdIdId¿œdÀd,„ƒƒZJedÁƒej!dHdNœdÂd$„ƒƒZKedÃƒej!dHdNœdÄd&„ƒƒZLedÅƒej!dHdNœdÆd„ƒƒZMedÇƒe  dFdW¡ej!dHdNœdÈdÉ„ƒƒƒZNedÊƒejdEdZdËej!ddHdNœdÌd7„ƒƒƒZOedÍƒej!ddHdNœdÎd;„ƒƒZPedÏƒej!dHdNœdÐd@„ƒƒZQedÑƒej!dHdNœdÒd)„ƒƒZRedÓƒej!dHdNœdÔd „ƒƒZSedÕƒej!dHdNœdÖd„ƒƒZTed×ƒej!dHdNœdØd!„ƒƒZUedÙƒedÚƒej!dHdNœdÛdÜ„ƒƒƒZVedÝƒedÞƒej!dHdNœdßdà„ƒƒƒZWej!dHdNœdádâ„ƒZXej!dHdNœdãdä„ƒZYej!dHdNœdådæ„ƒZZedçƒe  dFdèdèdèdè¡ej!dHdNœdéd„ƒƒƒZ[edêƒej!dHdNœdëd*„ƒƒZ\edìƒe dEdZdZ¡e  dFdWdW¡ej!dHdNœdíd„ƒƒƒƒZ]edîƒe  dFdGdèdïdF¡ej!dHdðdñdòœdód%„ƒƒƒZ^edôƒe  dFdFdFdWdWdWdFdWdW¡	ej!dHdNœdõd„ƒƒƒZ_edöƒe  dFdFdGdG¡ej!dHdNœd÷d„ƒƒƒZ`edøƒej!dHdNœdùd„ƒƒZaedúƒej!ddHdNœdûd+„ƒƒZbedüƒej!dHdýdþœdÿd„ƒƒZced ƒej!dHdýdþœdd„ƒƒZdedƒej!dHdýdþœdd„ƒƒZeedƒej!dHdNœdd/„ƒƒZfedƒej!dHdIdœdd„ƒƒZged	ƒej!dHdIdœd
dA„ƒƒZhdS (  z(This file exports ONNX ops for opset 11.é    )ÚannotationsN)ÚOptionalÚSequence)Ú_C)Ú_onnx)Ú_type_utilsÚerrorsÚsymbolic_helperÚsymbolic_opset10Úsymbolic_opset9Úutils)Ú	_beartypeÚ	jit_utilsÚregistrationÚaddÚappendÚarangeÚargsortÚ
atleast_1dÚ
atleast_2dÚ
atleast_3dÚcatÚchunkÚ	clamp_maxÚ	clamp_minÚclampÚconstant_pad_ndÚcumsumÚDeleteÚembedding_bagÚembedding_renormÚflattenÚgatherÚhardtanhÚhstackÚim2colÚ
index_fillÚindexÚ
index_copyÚ	index_putÚinsertÚ
linalg_detÚlinalg_vector_normÚlogdetÚmasked_scatterÚmasked_selectÚmmÚnarrowÚnormalÚpadÚpixel_shuffleÚpopÚprim_constant_chunkÚreflection_padÚrelu6Ú	remainderÚreplication_padÚroundÚscatterÚselectÚsizeÚsortÚsplit_with_sizesÚsplitÚsqueezeÚstackÚtopkÚunbindÚ
unique_dimÚ	unsqueezeÚvstacké   )Zopsetzaten::hardtanhTÚvÚfzjit_utils.GraphContextz_C.ValueÚfloat)ÚgÚselfÚmin_valÚmax_valc                 C  s`   t j |t jj¡}| jdtj|| ¡ dd}| jdtj|| ¡ dd}tj	| d|||ddS )NÚConstant©Údtype©Zvalue_tÚClipé   ©Zopset_before)
r   ÚJitScalarTypeÚ
from_valueÚFLOATÚopÚtorchÚtensorrS   r	   Ú_op_with_optional_float_cast)rM   rN   rO   rP   Úscalar_type© r`   úM/var/www/html/venv/lib/python3.8/site-packages/torch/onnx/symbolic_opset11.pyr#   Y   s(     ÿþþ     ÿzaten::clamp©rM   c                   sº   t j‡ fdd„ƒ}tj |tjj¡}|tjjkrD|||ƒ}|||ƒ}t |¡rZtˆ ||ƒS t |¡rpt	ˆ ||ƒS t 
|¡dkr¢t 
|¡dkr¢tjˆ d|||ddS tˆ t	ˆ ||ƒ|ƒS d S )Nc                   s.   | d k	r&t  | ¡s&ˆ jd| | ¡ dS | S d S )NÚCast©Zto_i)r	   Ú_is_noner[   Ú	onnx_type)r]   rS   rb   r`   ra   Ú_cast_if_not_noneq   s    ýz clamp.<locals>._cast_if_not_noner   rU   rV   rW   )r   Úbeartyper   rX   rY   Ú	UNDEFINEDr	   re   r   r   Ú_get_tensor_rankr^   )rM   rN   ÚminÚmaxrg   r_   r`   rb   ra   r   n   s4    
 ÿ



ÿþ     ÿzaten::clamp_minc                 C  sb   | j d|tj |¡ ¡ d}t |¡dkrJt | ¡}tj	| d|||ddS tj	| d||ddS d S )Nrc   rd   r   rU   rV   rW   ZMax©
r[   r   rX   rY   rf   r	   rj   Úopset9Zunusedr^   )rM   rN   rk   rl   r`   r`   ra   r   “   s$    
     ÿ    ÿzaten::clamp_maxc                 C  sb   | j d|tj |¡ ¡ d}t |¡dkrJt | ¡}tj	| d|||ddS tj	| d||ddS d S )Nrc   rd   r   rU   rV   rW   ZMinrm   )rM   rN   rl   rk   r`   r`   ra   r   £   s$    
     ÿ    ÿzaten::relu6c                 C  sX   t j |t jj¡}| jdtjd| ¡ dd}| jdtjd| ¡ dd}t| |||ƒS )NrQ   r   rR   rT   é   )	r   rX   rY   rZ   r[   r\   r]   rS   r   )rM   Úinputr_   rO   rP   r`   r`   ra   r8   ³   s     ÿþþzaten::selectÚic                 C  s   | j d|||dS )NÚGather©Úaxis_i©r[   )rM   rN   Údimr'   r`   r`   ra   r=   Ä   s    zaten::index_putFc                   sô  t  |¡rt  |¡}n|g}t  ¡ rD|g| ||g }ˆjd|žŽ S t  |d¡}t|ƒdkr`|S t|ƒdkr tt|ƒƒD ]&}t  || ¡rzˆ 	d|| ¡||< qz|d }|dd … D ]}	t
 ˆ||	¡}q¶ˆ 	d|¡‰ ‡ ‡fdd„|D ƒ}ˆj	d|žd
diŽ}nº|d }|}
t  |
¡ržt  |¡}|d k	rF|dkrFt
 ˆ||
|¡S t  |
¡}t  |¡}|d k	r|d k	r||krt  ˆ|
tt||ƒƒ¡}
tˆ||
|ƒS ˆ 	d|¡‰ t  ˆ|dg¡}t jˆˆ 	d|¡dgt|ƒgtjgd}ˆj	d	ˆ |dd}t  |¡}|d k	r"|dkr"t
 ˆ||d ¡}t  ˆ||¡}tj |tjj¡}|tjjkr‚tj |tjj¡}||kr”ˆj	d|| ¡ d}n|r”t d|¡‚|ràˆj	dˆ 	d|¡tjdg| ¡ dd}ˆ 	d|||¡}tˆ||ƒ}nˆ 	d|||¡}|S )Nr)   Úbr   é   ZNonZeroÚShapec                   s(   g | ] }t  ˆt ˆ|ˆ d ¡dg¡‘qS )Néÿÿÿÿ)r	   Ú_unsqueeze_helperrn   Úexpand)Ú.0Úind©Zbroadcast_index_shaperM   r`   ra   Ú
<listcomp>è   s   ý  ÿzindex_put.<locals>.<listcomp>ÚConcatrt   rz   ©ZaxesZstartsZendsrs   rc   rd   z'self does not have a valid scalar type.ÚConstantOfShaperR   rT   Ú	ScatterND)r)   )r   ) r	   Ú_is_packed_listÚ_unpack_listÚis_caffe2_aten_fallbackÚatÚ
_parse_argÚlenÚrangeÚ_is_boolr[   rn   r   rj   Zmasked_fillr{   Úlistr.   Ú_slice_helperÚsysÚmaxsizer|   Ú_reshape_helperr   rX   rY   ri   rf   r   ÚSymbolicValueErrorr\   r]   rS   )rM   rN   Zindices_list_valueÚvaluesÚ
accumulateZindices_listÚargsZidx_r'   r~   Zbool_inpÚrankZ	mask_rankZ	self_rankZsub_data_shapeZvalues_shapeZself_scalar_typeZvalues_scalar_typeZzerosÚresultr`   r   ra   r)   Í   sœ    
ü(


ÿþý  ÿ 
   ÿ
 ÿ ÿ

ýzaten::pixel_shufflec                 C  s8   t  |¡}|d k	r&|dkr&t  dd¡S | jd||ddS )Né   r4   zonly support 4d inputZDepthToSpaceZCRD)Zblocksize_iÚmode_s)r	   rj   Ú_unimplementedr[   )rM   rN   Zupscale_factorr–   r`   r`   ra   r4   N  s    
zaten::upsample_nearest1dZupsample_nearest1dé   Znearest)Zdecoratezaten::upsample_nearest2dZupsample_nearest2dr˜   zaten::upsample_nearest3dZupsample_nearest3dé   zaten::upsample_linear1dZupsample_linear1dZlinearzaten::upsample_bilinear2dZupsample_bilinear2dzaten::upsample_trilinear3dZupsample_trilinear3dzaten::upsample_bicubic2dZupsample_bicubic2dZcubicÚstrÚint©Únamerv   Zinterpolate_modec                 C  s   t  | ||¡S ©N)r	   Z_interpolate_helperrŸ   r`   r`   ra   Ú_interpolateX  s    r¢   zaten::__interpolatec              	   C  s   t  | ||||||¡S r¡   )r	   Z__interpolate_helper)rM   rp   r>   Zscale_factorÚmodeZalign_cornersZrecompute_scale_factorZ	antialiasr`   r`   ra   Ú__interpolatey  s          ÿr¤   zaten::gatherc                 C  sD   t  |d¡rt  dd¡S t  ¡ r2|  d||||¡S | jd|||dS )Nrq   r"   zsparse_grad == TrueZGatherElementsrs   )r	   Ú_maybe_get_constrš   r‡   rˆ   r[   )rM   rN   rv   r'   Zsparse_gradr`   r`   ra   r"   ‹  s
    zaten::scatterc              	   C  s    t  ¡ r| jd||||ddS tj |¡}t  |¡}t  |¡rR| jd||||dS tj |¡|kr~| jd|tj |¡ 	¡ d}| jd||t
 | ||¡|dS d S )Nr<   Úsrc©Úoverload_nameZScatterElementsrs   rc   rd   )r	   r‡   rˆ   r   rX   rY   Ú_maybe_get_scalarÚ	_is_valuer[   rf   rn   Ú	expand_as)rM   rN   rv   r'   r¦   Zsrc_typer`   r`   ra   r<   –  s&    

ý    ÿzaten::cumsumÚnonec                 C  sn   | j dtj|tjdd}|rX| ¡  ¡ dkrXt |dd¡}| j d|t 	|¡ 
¡ d}n|}|   d	||¡}|S )
NrQ   rR   rT   zprim::Constantrq   rS   rc   rd   ZCumSum)r[   r\   r]   rž   ÚnodeÚkindr	   Ú
_get_constr   rX   rf   )rM   rN   rv   rS   Z
dim_tensorZparsed_dtypeÚcastZcsumr`   r`   ra   r   ®  s      ÿzaten::masked_selectc                 C  s$   t  | t  | ||¡¡}|  d||¡S )NÚGatherND)rn   Únonzeror«   r[   )rM   rN   Úmaskr'   r`   r`   ra   r/   ¾  s    zaten::masked_scatterc                 C  sr   t  | t  | ||¡¡}t | |t dg¡¡}tj| |t dg¡t dg¡t  | |t dg¡¡d}|  	d|||¡S )Nrz   r   r‚   r„   )
rn   r²   r«   r	   r‘   r\   Ú
LongTensorrŽ   r>   r[   )rM   rN   r³   Úsourcer'   r`   r`   ra   r.   Å  s    

ûz	aten::lenc                 C  sT   t  |¡s| ¡  ¡ dkr&|  d|¡S t| || jdt dg¡dƒ}t  | |dg¡S )Nzonnx::SplitToSequenceZSequenceLengthrQ   r   rT   )	r	   Ú_is_tensor_listr­   r®   r[   r>   r\   r´   Ú_squeeze_helper)rM   rN   Zsz_0r`   r`   ra   Ú_len×  s    ÿþr¸   zaten::__getitem_c                 C  s4   t  |¡r|  d||¡S ddlm} || ||ƒS d S )NÚ
SequenceAtr   )Ú
__getitem_)r	   r¶   r[   Ztorch.onnx.symbolic_opset9rº   )rM   rN   rq   Úgetitemr`   r`   ra   rº   ã  s    
rº   zaten::_set_itemc                 C  s   |   d||¡}|   d|||¡S )NÚSequenceEraseÚSequenceInsertru   )rM   Útensor_listrq   rJ   r`   r`   ra   Ú	_set_itemï  s    r¿   zaten::appendc                 C  s   |   d||¡S ©Nr½   ru   )rM   rN   r]   r`   r`   ra   r   ö  s    z	aten::addc                 C  sn   t  |¡r^t  |¡r^| ¡ }| ¡ dkr4t  dd¡S t  |¡}|}|D ]}|  d||¡}qF|S t 	| |||¡S )Nzprim::ListConstructr   z6does not support adding dynamic tensor list to anotherr½   )
r	   rª   r¶   r­   r®   rš   r†   r[   rn   r   )rM   rN   ÚotherÚalphaZtensor_list_nodeZtensorsÚlÚtr`   r`   ra   r   ü  s     ÿ
zaten::insertc                 C  s   |   d|||¡S rÀ   ru   )rM   rN   Úposr]   r`   r`   ra   r*     s    z	aten::popc                 C  s   |   d||¡S ©Nr¼   ru   ©rM   r¾   rv   r`   r`   ra   r5     s    zaten::Deletec                 C  s   |   d||¡S rÆ   ru   rÇ   r`   r`   ra   r     s    z	aten::catc                 C  s:   t  |¡rt | ||¡S t  |dd¡}| jd||dS d S )Nrq   rv   ÚConcatFromSequencers   )r	   r…   rn   r   r¯   r[   rÇ   r`   r`   ra   r      s    
zaten::stackc                 C  s<   t  |¡rt | ||¡S t  |dd¡}| jd||ddS d S )Nrq   rv   rÈ   rx   ©rt   Z
new_axis_i)r	   r…   rn   rC   r¯   r[   rÇ   r`   r`   ra   rC   +  s    
zaten::_unique2c           	      C  s$   | j d||dd\}}}}|||fS )NÚUniquer˜   )Úsorted_iÚoutputsru   )	rM   rN   ÚsortedÚreturn_inverseÚreturn_countsÚuÚindicesÚinverse_indicesÚcountsr`   r`   ra   Ú_unique25  s       ÿrÔ   zaten::unique_dimc           
      C  s&   | j d|||dd\}}}}	|||	fS )NrÊ   r˜   )rt   rË   rÌ   ru   )
rM   rN   rv   rÍ   rÎ   rÏ   rÐ   rÑ   rÒ   rÓ   r`   r`   ra   rF   ?  s        ÿz
aten::topkc              	   C  s   t j| ||||||dS )N)ÚlargestrÍ   Úout)r	   Z_topk_helper)rM   rN   Úkrv   rÕ   rÍ   rÖ   r`   r`   ra   rD   K  s          ÿz
aten::sortc                 C  s   t j| ||||dS ©N)Ú	decendingrÖ   ©r	   Z_sort_helper)rM   rN   rv   rÙ   rÖ   r`   r`   ra   r?   T  s    zaten::argsortc                 C  s   t j| ||||d\}}|S rØ   rÚ   )rM   rN   rv   rÙ   rÖ   Ú_rÑ   r`   r`   ra   r   [  s        ÿ
zaten::roundc                 C  sz   t  |¡s|S |dkr"|  d|¡S |  d|| jdt td|ƒ¡d¡}|  d|¡}|  d|| jdt tdd| ƒ¡d¡S )Nr   ZRoundÚMulrQ   é
   rT   rz   )r	   Ú_is_fpr[   r\   r]   Úpow)rM   rN   ZdecimalsÚmulr;   r`   r`   ra   r;   e  s    
$  ÿzaten::remainderc                 C  s4   t  |¡st  |¡r"t | ||¡S | jd||ddS )NÚModr   )Zfmod_i)r	   rÞ   rn   r9   r[   )rM   rp   rÁ   r`   r`   ra   r9   t  s    zaten::splitc              
     s  t  ||¡sòˆ jd|||d‰|d kr*ˆS t  |¡rÚtt  |¡ƒ|krÚ‡ fdd„t  |¡D ƒ}ˆ jdtjdgtjdd}ˆ jdtj|gtjdd}g }t	|ƒD ]2}	ˆ  d	|||	 ¡}
| 
ˆ  d
|||
|¡¡ |
}q¢|S ‡ ‡fdd„t	|ƒD ƒS t ˆ ||||¡S d S )NÚSplitToSequencers   c                   s   g | ]}t  ˆ |d g¡‘qS )r   )r	   r{   )r}   rJ   rb   r`   ra   r€   ‰  s   ÿzsplit.<locals>.<listcomp>rQ   r   rR   rT   ÚAddÚSlicec                   s2   g | ]*}ˆ   d ˆˆ j dtj|gtjdd¡‘qS )r¹   rQ   rR   rT   )r[   r\   r]   Úlong)r}   rq   ©rM   Z	split_outr`   ra   r€   —  s   ûý)r	   Z_is_split_staticr[   r…   rŠ   r†   r\   r]   rå   r‹   r   rn   rA   )rM   rN   Zsplit_size_or_sizesrv   Ú_outputsÚsplit_sizesÚstartÚaxisÚresrq   Úendr`   ræ   ra   rA   |  s6    ÿþ
þ  ÿú	zaten::split_with_sizesc                 C  s   t | ||||ƒS r¡   )rA   )rM   rN   rè   rv   rç   r`   r`   ra   r@   £  s    zaten::unbindc              	   C  sF   |d kr2| j d|| j dtjdtjdd|ddS t | |||¡S d S )Nrâ   rQ   rx   rR   rT   r   )rt   Ú
keepdims_i)r[   r\   r]   rå   rn   rE   )rM   rN   rv   rç   r`   r`   ra   rE   ª  s    ûc                 C  sz  t  |¡s0t  |¡r0t  |¡r0| jd|ddd}t | || jdt dg¡d¡}t  	|¡}|dkrx|  d|  d	|¡¡}n| jdtj|tj
d
d}|  d|  d|| jdtjdtj
d
d¡|¡}| jd|tjjd}| jd|| jd|tjdgtj
d
ddd}t  | || jdt ddg¡d¡}| jdt | |dg¡ddgd}t  | || jdt dg¡d¡}| jd|tjjd}|S )a!  Generate paddings in ONNX order based on pad in pytorch.

    Args:
        input: the input tensor.
        pad: the paddings in pytorch.
            The order is dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, ..., dim_m_begin, dim_m_end,
            where m is in range [0, n].
    rÈ   r   rx   rÉ   rQ   rT   NÚSizery   rR   ÚSubrÜ   é   rc   rd   r   rƒ   rs   rz   Ú	Transpose©Zperm_i)r	   r…   Z_is_listZ_is_scalar_listr[   rn   r>   r\   r]   rj   Úint64Ú_C_onnxÚTensorProtoDataTypeZINT64r‘   Úopset10Úflip)rM   rp   r3   Zpad_lenr–   Ú	extensionÚpaddingsZ	padding_cr`   r`   ra   Ú_prepare_onnx_paddingsº  sR    ÿþý 
"ý  ÿú  ÿ   ÿrú   zaten::constant_pad_ndc                 C  s:   d}t  |¡}t  ||¡}t| ||ƒ}| jd||||dS )NÚconstantÚPad©r™   )r	   r©   Ú_if_scalar_type_asrú   r[   )rM   rp   ÚpaddingÚvaluer£   r3   r`   r`   ra   r   ô  s
    
zaten::reflection_pad1dzaten::reflection_pad2dzaten::reflection_pad3dc                 C  s"   d}t | ||ƒ}| jd|||dS )NÚreflectrü   rý   ©rú   r[   ©rM   rp   rÿ   r£   rù   r`   r`   ra   r7   þ  s    zaten::replication_pad1dzaten::replication_pad2dzaten::replication_pad3dc                 C  s"   d}t | ||ƒ}| jd|||dS )NÚedgerü   rý   r  r  r`   r`   ra   r:     s    z	aten::pad©rM   rp   r3   r£   r   c                 C  sv   t  |d¡}|dkr t| ||ƒS |dkr4t| ||ƒS |dkrJt| |||ƒS |dkr`t | ||¡S t d|› |¡‚d S )NÚsZ	replicater  rû   ZcircularzUnrecognized padding mode )	r	   r‰   r:   r7   r   rn   Z_pad_circularr   r’   r  r`   r`   ra   r3     s    	zaten::linalg_detc                 C  s   |   d|¡S )NZDetru   ©rM   rN   r`   r`   ra   r+   (  s    zaten::logdetc                 C  s   t  | t| |ƒ¡S r¡   )rn   Úlogr+   )rM   rp   r`   r`   ra   r-   .  s    úaten::arangec                 G  s   dd„ }t |ƒdkrŒtdd„ |D ƒƒrŒtj}| jdtj|d |dd	}| jdtj|d
 |dd	}| jdtjd
|dd	}|  d|||¡S t |ƒdks¦t |ƒdkr(t |ƒdkr¸d }n||d
 ƒ}tj| |d |d\}}}}| jdtjd| ¡ dd	}	| jdtjd
| ¡ dd	}|  d|	||¡S t |ƒdksDt |ƒdkrœt |ƒdkrXd }n||d ƒ}tj| |d |d
 |d |d\}
}}}|  d|||¡S t |ƒdkr||d ƒ}tj| |d |d
 |d\}}}}| jdtjd
| ¡ dd	}|  d|||¡S t 	ddt |ƒ› d¡S d S )Nc                 S  s   t  | d¡} | S )Nrq   )r	   r¥   rR   r`   r`   ra   Ú_get_arange_dtype7  s    z!arange.<locals>._get_arange_dtyperð   c                 s  s   | ]}t |tƒV  qd S r¡   )Ú
isinstancerž   )r}   Úvalr`   r`   ra   Ú	<genexpr>;  s     zarange.<locals>.<genexpr>rQ   r   rR   rT   rx   ÚRangerœ   )rì   rS   r˜   é   r›   )ré   rì   ÚsteprS   ro   )ré   rì   rS   r	  zwith z
 arguments)
rŠ   Úallr\   ró   r[   r]   r	   Z_arange_cast_helperrS   rš   )rM   r•   r
  rS   ré   rì   Zdelta_defaultÚtype_r  Zstart_defaultrÛ   r`   r`   ra   r   4  s~    þþþ  ÿþþ    ÿ   ÿþ ÿzaten::_dim_arangec                 C  sT   |   d|¡}| j d|| j dt |¡ddd}t ¡ rB|   d|¡S t| |dd d d ƒS )	Nry   rr   rQ   rT   r   rs   z_caffe2::Ranger˜   )r[   r\   r]   r	   r‡   r   )rM   Úlikerv   Z
like_shapeÚstopr`   r`   ra   Ú_dim_arange}  s       ÿr  z
aten::size)Zquantize_outputc                 C  s"   |d kr|   d|¡S t | ||¡S )Nry   )r[   r	   Ú_size_helper©rM   rN   rv   r`   r`   ra   r>   Š  s    zaten::squeezec                 C  s|  |d kr|   d|¡S t |¡s.t | ||g¡S t |dd¡}t |¡}|}|d k	rb|dk rb||7 }t ||¡}|dk r~|d ksˆ|d kr,| j dt |g¡d}t 	| ||¡}| j dtj
dtjdd}|   d	||¡}	tj| d
|	dd\}
\}}}t |||g¡}t |j|¡ |  d|¡}t |j|¡ |
S |}|dkrlt dt|ƒ d d t|ƒ d d d ¡ |S t | ||g¡S )NZSqueezerq   rv   r   rQ   rT   rx   rR   ÚEqualÚIfrð   )Ún_blocksZIdentityz5This model contains a squeeze operation on dimension z. The size of z%this dimension in the given input is z. The model will zWbe exported without the squeeze node. If the model is intended to be used with dynamic z7input shapes, please export with dynamic_axes argument.)r[   r	   Ú_is_constantr·   r¯   rj   Z_get_tensor_dim_sizer\   r]   r  Zonesró   r   Úadd_op_with_blocksr   Ú_add_output_to_blockÚblockÚwarningsÚwarnr   )rM   rN   rv   Z
input_rankZadjusted_dimÚdim_sizeZdim_constantr>   Ú	const_oneZcondZif_opZ
if_contextZelse_contextrÛ   Zsqueeze_Z	identity_r`   r`   ra   rB   “  s^    

   ÿ
ÿþýüûúùÿ
zaten::unsqueezec                 C  s(   t  |¡rt  |dd¡}t  | ||g¡S )Nrq   rv   )r	   r  r¯   r{   r  r`   r`   ra   rG   Ç  s    
zaten::mmc                 C  s   | j d||dddS )NZGemmg        g      ð?)Zbeta_fZalpha_fru   )rM   rN   rÁ   r`   r`   ra   r0   Ð  s    zaten::indexc                 C  sš   t  ¡ r| jd||ddS t  |¡r0t  |¡}n|g}t|ƒdkrŒ|d }t  |¡sŒt  |¡srtj	 
|¡tj	jkrŒt | |¡}|  d||¡S t | ||¡S )Nr'   ZTensorr§   rx   r   r±   )r	   r‡   rˆ   r…   r†   rŠ   re   rŒ   r   rX   rY   ÚUINT8rn   r²   r[   r'   )rM   rN   r'   rÑ   r`   r`   ra   r'   Ö  s"    

ÿ
ÿþzaten::index_fillc           	      C  st   t  |d¡}t  ¡ r*| jd|||d|dS t  | |||¡\}}t  |¡}t  ||¡}t | ||d ¡}t	| ||||ƒS )Nrq   r&   Z
int_Scalar)r¨   Údim_i)
r	   r‰   r‡   rˆ   Ú_index_fill_reshape_helperr©   rþ   rn   r|   r<   )	rM   rN   rv   r'   r   Ú	dim_valueÚexpanded_index_shapeÚexpanded_indexZexpanded_valuer`   r`   ra   r&   î  s(    ú	   ÿ
zaten::index_copyc                 C  sL   t  |d¡}t  ¡ r(| jd||||dS t  | |||¡\}}t| ||||ƒS )Nrq   r(   )r$  )r	   r‰   r‡   rˆ   r%  r<   )rM   rN   rv   r'   rµ   r&  r'  r(  r`   r`   ra   r(     s       ÿzaten::bitwise_right_shiftzaten::__rshift_c                 C  sÚ   t j |t jj¡t j |¡kr:| jd|t j |¡ ¡ d}t j |t jj¡t jjkrf| jd||ddS | jdtjdtj	dd	}t
 |¡sž| jd|tjjd}|  d
||¡}| jd|t j |¡ ¡ d}|  d||¡}|S )Nrc   rd   ÚBitShiftZRIGHT©Zdirection_srQ   rð   rR   rT   ÚPowÚDiv©r   rX   rY   ri   r[   rf   r#  r\   r]   Zfloat32r	   rÞ   rô   rõ   rZ   )rM   rN   rÁ   ÚtwoÚtwo_powÚrshiftr`   r`   ra   Ú	__rshift_  s6     ÿ
þýÿÿ
ýr1  zaten::bitwise_left_shiftzaten::__lshift_c                 C  sÚ   t j |t jj¡t j |¡kr:| jd|t j |¡ ¡ d}t j |t jj¡t jjkrf| jd||ddS | jdtjdtj	dd	}t
 |¡sž| jd|tjjd}|  d
||¡}| jd|t j |¡ ¡ d}|  d||¡}|S )Nrc   rd   r)  ZLEFTr*  rQ   rð   rR   rT   r+  rÜ   r-  )rM   rN   rÁ   r.  r/  Úlshiftr`   r`   ra   Ú	__lshift_4  s6     ÿ
þýÿÿ
ýr3  c                 C  sâ   |   d|| j dt |d ¡d¡}|   d|| j dt ||d  ¡d¡}|   d| j dt d¡d|| j dt |¡d¡}t d|| |¡}| j d| d¡d}t | |dg¡}t | || j dt d	dg¡d¡}	|   d||	¡}
|
S )
Nrã   rQ   rð   rT   rï   rx   r  r   rz   )r[   r\   r]   r   rG   r	   r{   r‘   )rM   Zinput_dZkernel_size_dZ
dilation_dZ	padding_dZstride_dZblocks_dZblocks_d_indicesZkernel_gridZkernel_maskZ
block_maskr`   r`   ra   Ú_get_im2col_indices_along_dimW  s<    
  ÿýü  ÿ  ÿr4  c                 C  s.   | j dt dd||gd ¡d}|   d||¡S )NrQ   r   rð   rT   rü   )r[   r\   r´   )rM   rp   Ú	padding_hÚ	padding_wr3   r`   r`   ra   Ú_get_im2col_padded_inputƒ  s     r7  c              
   C  s˜   t | || jdt d¡dƒ}t | || jdt d¡dƒ}|  d|| jdt || ¡d¡}| jdt | |dg¡t | |dg¡| jdt dg¡dddS )	NrQ   r   rT   rx   rÜ   r   rz   rs   )r>   r[   r\   r]   r	   r{   )rM   rp   Úkernel_hÚkernel_wZ	batch_dimZchannel_dimZchannel_unfoldedr`   r`   ra   Ú_get_im2col_output_shapeŒ  s      ÿûr:  zaten::im2colÚisc              	   C  s  t | || jdt d¡dƒ}t | || jdt d¡dƒ}|d |d  }}	|d |d  }
}|d |d  }}|d |d  }}t| ||||
|ƒ}t| |||||	ƒ}t| |||ƒ}t| ||
|ƒ}| jd||dd}| jd||d	d}| jd
|dddd	ddgd}t | ||¡S )NrQ   rð   rT   r›   r   rx   rr   rs   r˜   rñ   rœ   rò   )	r>   r[   r\   r]   r4  r:  r7  r	   r‘   )rM   rp   Zkernel_sizeZdilationrÿ   ZstrideZinput_hZinput_wZstride_hZstride_wr5  r6  Z
dilation_hZ
dilation_wr8  r9  Zblocks_row_indicesZblocks_col_indicesZoutput_shapeZpadded_inputÚoutputr`   r`   ra   r%     s8         ÿ     ÿzaten::narrowc                 C  s"   |   d||¡}tj| ||||dS )Nrã   r‚   )r[   r	   rŽ   )rM   rp   rv   ré   Úlengthrì   r`   r`   ra   r1   Ó  s    zaten::flattenc                 C  sº   t  |¡}|dkr|S |dkrL|dks:|d k	r„||d kr„| jd||dS n8|dkr„|dksp|d k	r„||d kr„| jd||d dS |d kr˜t  dd	¡S |dk r¨|| }t  | ||||¡S )
Nrx   rz   ZFlattenrs   r   éþÿÿÿrð   rv   zfONNX and PyTorch use different strategies to split the input. Input rank must be known at export time.)r	   rj   r[   rš   Z_flatten_helper)rM   rp   Z	start_dimZend_dimrv   r`   r`   ra   r!   Ú  s"    
þzaten::linalg_vector_normrw   zOptional[Sequence[int]]Úbool)rM   rv   Úkeepdimc                 C  s   t  | |||||¡S r¡   )r	   Z_linalg_vector_norm_helper)rM   rN   Úordrv   r@  rS   r`   r`   ra   r,   ö  s    zaten::embedding_bagc
           
      C  s   t  | |||||||||	¡
S r¡   )r	   Z_embedding_bag_helper)
rM   Zembedding_matrixrÑ   ÚoffsetsZscale_grad_by_freqr£   ÚsparseZper_sample_weightsZinclude_last_offsetZpadding_idxr`   r`   ra   r     s    özaten::embedding_renormc              	   C  sà   |   d|¡}|   d||¡}t|ƒ}|dkr0d}n"|dkr>d}nt d|› d|¡‚| j ||dgdd	}|   d
|| j dt d¡d¡}	t |¡}|   d||	¡}
|   d||
¡}|   d|   d||¡||¡}|   d|t | |dg¡|¡S )NrÊ   rr   rx   ZReduceL1rð   ZReduceL2z8Unsupported: ONNX export of embedding_renorm with norm: z. Only 1. and 2. are supported.)Úaxes_irí   rã   rQ   gH¯¼šò×z>rT   r,  rÜ   ZWhereZGreaterr„   )r[   rž   r   r’   r\   r]   r	   r{   )rM   ÚweightrÑ   Zmax_normZ	norm_typeZunique_indicesZpartial_weightZnorm_iZpartial_weight_normZpartial_weight_norm_ÚscalesZpartial_weight_renormr`   r`   ra   r    !  s@    
ý  ÿ
üüzaten::chunkc              
   C  s¢   | j d|   d|¡|dd}|   d|| j dtjdgtjdd	¡}|   d
|   d||¡|¡}t | ||d ¡|   d||   d||¡¡g}| j d|žddiŽ}t| |||ƒS )Nrr   ry   r   rs   rï   rQ   rx   rR   rT   r,  rã   rÜ   r   rt   )r   )r[   r\   r]   rå   rn   r|   rA   )rM   rN   Úchunksrv   r!  Zchunk_size_sÚ
chunk_sizeZ	chunk_vecr`   r`   ra   r   I  s      ÿþzaten::normalc	           
      C  sD   |d k	r"t  |¡s"t | ||d ¡}t | ||  d|¡¡}	t| |	|ƒS )NZRandomNormalLike)r	   re   rn   r|   rà   r[   r   )
rM   ZmeanZstdÚsizesÚ	generatorrS   ZlayoutZdeviceZ
pin_memoryr—   r`   r`   ra   r2   [  s    zaten::atleast_1dztorch._C.Valuer  c              
   C  s°   t  |¡rzt  |¡rzt  |¡}g }|D ]D}|}t  |¡}|dkr`t  | || jdt dg¡d¡}| 	|¡ q&| jd|žŽ S t  |¡}|dkr¬t  | || jdt dg¡d¡}|S )Nr   rQ   rx   rT   ÚSequenceConstruct)rK  )
r	   rª   r…   r†   rj   r‘   r[   r\   r]   r   ©rM   rN   r¾   Znew_tensor_listr]   Z
new_tensorZtensor_rankr`   r`   ra   r   s  s,    

  ÿ
  ÿzaten::atleast_2dc                 C  sì   t  |¡r˜t  |¡r˜t  |¡}g }|D ]b}|}t  |¡}|dkrdt  | || jdt ddg¡d¡}n|dkr~t j	| |dgd}| 
|¡ q&| jd|žŽ S t  |¡}|dkrÎt  | || jdt ddg¡d¡}n|dkrèt j	| |dgd}|S )Nr   rQ   rx   rT   ©rD  rK  )rK  ©r	   rª   r…   r†   rj   r‘   r[   r\   r]   r{   r   rL  r`   r`   ra   r   Ž  s<    

  ÿ  ÿ
  ÿzaten::atleast_3dc                 C  sR  t  |¡rÈt  |¡rÈt  |¡}g }|D ]’}|}t  |¡}|dkrft  | || jdt dddg¡d¡}nH|dkr”t j	| |dgd}t j	| |dgd}n|dkr®t j	| |dgd}| 
|¡ q&| jd	|žŽ S t  |¡}|dkrt  | || jdt dddg¡d¡}nL|dkr2t j	| |dgd}t j	| |dgd}n|dkrNt j	| |dgd}|S )
Nr   rQ   rx   rT   rM  rz   rð   rK  )rK  rN  rL  r`   r`   ra   r   °  sX    

  ÿ  ÿ  ÿ  ÿ

  ÿ

zprim::ConstantChunkc              
   C  s  |   d|¡}| j dtj|gtjdd}| j d||dd}| j dtjdgtjdd}| j dtj|gtjdd}| j dtj|d gtjdd}	|   d	||	¡}
|   d
|
|¡}g }t|ƒD ]N}| j dtj|d gtjdd}|   d||¡}| |   d||||¡¡ |}qº|S )Nry   rQ   rR   rT   rr   r   rs   rx   rã   r,  rÜ   rä   )r[   r\   r]   rå   r‹   r   )rM   rN   rG  rv   Zinput_shaperê   Zinput_shape_dimré   rH  Zchunk_size_minus_1Zinput_shape_dim_shiftZ	chunk_dimrë   rq   r'   rì   r`   r`   ra   r6   Ý  s$     ÿ zaten::hstack©rM   r¾   c              
   C  sÜ   t | |ƒ}|  d|| jdtjdtjdd¡}|  d|¡}|  d|¡}| jdtjdtjdd}|  d	||¡}tj| d
|ddd\}\}}	}
|jd|ddd}t |j	|¡ |	jd|ddd}t |	j	|¡ | 
¡  ¡ }|S )Nr¹   rQ   r   rR   rT   ry   rî   rx   r  r  rð   )r  rÌ   rÈ   rÉ   )r   r[   r\   r]   rå   r   r  r   r  r  r­   r<  )rM   r¾   Zfirst_tensorZfirst_tensor_shapeZfirst_tensor_dimr"  Zequal_to_oneZif_op_greaterZif_context_equalZelse_context_equalrÛ   Z	result_ifZresult_elser—   r`   r`   ra   r$   ó  s>    
ýü   ÿ   ÿzaten::vstackc                 C  s   t | |ƒ}| jd|dddS )NrÈ   r   rÉ   )r   r[   rO  r`   r`   ra   rH     s    
)F)F)N)N)N)N)N)r   )N)N)r   N)N)N)N)NNNNNN)iÚ__doc__Ú
__future__r   Ú	functoolsr   r  Útypingr   r   r\   r   Ztorch._Cr   rô   Z
torch.onnxr   r   r	   r
   rö   r   rn   r   Ztorch.onnx._internalr   r   r   Ú__all__ÚpartialZonnx_symbolicZ_onnx_symbolicZquantized_argsÚ
parse_argsrh   r#   r   r   r   r8   r=   r)   r4   Z_apply_paramsr¢   r¤   r"   r<   r   r/   r.   r¸   rº   r¿   r   r   r*   r5   r   r   rC   rÔ   rF   rD   r?   r   r;   r9   rA   r@   rE   rú   r   r7   r:   r3   r+   r-   r   r  r>   rB   rG   r0   r'   r&   r(   r1  r3  r4  r7  r:  r%   r1   r!   r,   r   r    r   r2   r   r   r   r6   r$   rH   r`   r`   r`   ra   Ú<module>   s°   Ç<#

ÿ
þþþþþþþ"

	
$9G

2
  +3%     ÷ +