U
    Mh                    @  s&D  d Z ddlmZ ddlZddlZddlZddlZddlZddlm	Z	m
Z
mZmZmZmZ ddlZddlm  mZ ddlZddlZddlmZ ddlmZmZmZmZmZ ddlmZ ddlmZmZm Z  dd	l!m"Z" d
dddddddddddddddddddddd d!d"d#d$d%d&d'd(d)d*d+d,d-d.d/d0d1d2d3d4d5d6d7d8d9d:d;d<d=d>d?d@dAdBdCdDdEdFdGdHdIdJdKdLdMdNdOdPdQdRdSdTdUdVdWdXdYdZd[d\d]d^d_d`dadbdcdddedfdgdhdidjdkdldmdndodpdqdrdsdtdudvdwdxdydzd{d|d}d~dddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddd ddddddddd	d
dgZ#ej$e j%ddZ&ddddZ'ej(dd Z)e&dej(ddddZ*e&dej(ddddZ+e&de,dej(ddddЄZ-e&de,dej(ddddτZ.e&d ej(d<ddd!dZ/e&d"ej(d=ddd#dZ0e&d$ej(d>ddd%dԄZ1e&d&ej(ddd'dZ2e&d(ej(ddd)d<Z3e&d*e4d+d+d+d,ej(d?ddd.dZ5e4d+d+d/ej(ddd0d1Z6ej(ddd2d3Z7ej(ddd4d5Z8e&d6ej(ddd7dLZ9e&d8ej(ddd9dNZ:e&d:ej(ddd;dZ;e&d<ej(ddd=dǄZ<e&d>e4d+d?ej(ddd@d$Z=e&dAe4d+d?ej(dddBdZ>e&dCej(dddDdEZ?e&dFej(dddGdZ@e&dHej(dddId ZAe&dJej(dddKdZBe&dLe4d+d+d+ddej(dddMdZCe&dNej(dddOdZDe&dPej(dddQdZEe&dRej(dddSdӄZFe&dTej,ddUdVdWej(dddXdZGe&dYej(dddZd݄ZHe&d[ej(ddd\d6ZIe&d]ej(ddd^dZJe&d_ej(ddd`dZKe&daej(dddbdZLe&dcej(dddddZMe&deej(dddfdZNe&dgej,ddhddWej(dddidڄZOe&djej(dddkdۄZPe,dej(dddldmZQe&dneRdodpgdqe&dreRdsdtgdqe&duejRdvdwdxdygdqej(d@dddzd{d|d}ZSe&d~e4d+d?dej(dddd9ZTe&dej(ddddZUe&dej(ddddZVe&dej(ddddZWe&de,dej(ddddZXe&de,dej(ddddHZYe&de,dej(dddd"ZZe&de,ddej(ddddGZ[e&de,de4d+d+d?dd+ej(ddddAZ\e&de,de4d+d+d+d?d?d?d+d?d?	ej(dddd@Z]e&dej,ddxdej(dAddddބZ^e&de,de4d+d?d?ej(ddddZ_e&de4d+dej(ddddZ`e&de,dej(ddddZae&dej(ddddZbe&de4d+d?d?d?ej(dBddddZce&de4d+d+d?d?ej(dCddddZde&dej(dDddddZee&de4d+dd?d?ej(dEddddZfe&dej(dFddddZge&de4d+d?d?ej(dGddddZhe&de,de4d+d?d+ej(dddd؄Zie&dej(ddddZje&dej(dHddddZke&dej(ddddZle&dej(dddd܄Zme&dej(ddddZne&de,dej(ddddɄZoe&de,dej(ddddʄZpe&dej(dddd&Zqe&dej(ddddMZre&dej(dddÐdĄZse&dŃe4d+ddej(ddddZte&dǃe,de4d+d,dej(dIddȐdɐdzdʜddlZue&d̃e4d+d?ej(ddddVZve&d΃e4d+d?dej(dJddddZwe&dЃej(ddddZxe&d҃ej(ddU Zye&dejRdejzj{j|j}d֐dxd׍e'dՃgdqe&dejRdejzj{j|j~dڐdxd׍e'dكgdqe&dejRdejzj{j|jdݐdxd׍e'd܃gdqej(dސd߄ Ze&dedejzj{j|j}d֐dd׍Ze&dedejzj{j|j~dڐdd׍Ze&dedejzj{j|jdݐdd׍Ze&deRdejzj{j|j}e'dgdqe&deRdejzj{j|j~e'dgdqe&deRdejzj{j|je'dgdqej(dd Ze&deRddejzj{j|j}e'dgdqe&deRddejzj{j|j~e'dgdqe&deRddejzj{j|je'dgdqe&deRddejzj{j|j}ee'dgdqe&deRddejzj{j|j~ee'dgdqe&deRddejzj{j|jee'dgdqej(dKddZej(ddddZej(dd  Ze&dej(dddd+Zej(ddȐdȐdddZe&de&de&dej(ddd	dȄZe&d
e&de&dej(dddd΄Ze&dej(ddȐdȐdȐdȐdddZe&deRddݐde'dgdqe&deRddde'dgdqe&deRddde'dgdqe&deRdddse'dgdqe&deRdddse'dgdqe&deRdddse'dgdqej(dddd d!d"Ze&d#ej(ddd$d%Ze&d&ej(ddd'dZe&d(ej(d)d Zej(d*d Zej(d+d+d,d-dZe&d.ej(ddd/d0Ze&d1e,ddej(ddd2dDZe&d3e,ddeej(ddd4dZe&d5e,ddej(ddd6dXZej(ddd7d8Ze&d9e,ddej(ddd:dZej(ddd;d<Ze&d=e,ddeej(ddd>dSZe&d?e,ddeej(ddd@dkZe&dAej(dddBdCZe&dDej(dddEdFZe&dGej(dddHdIZe&dJedKej(dddLd{Ze&dMedKej(dddNd}Ze&dOedKej(dddPd~Ze&dQej(dddRd|Ze&dSej(dddTdUZe&dVej(dddWdXZe&dYe4d+d+d+d?ej(dLdddZdZe&d[e4d+d?dej(dMddd\dvZe&d]e4d+d?d?ej(ddd^d_Ze&d`e4d+d+d+dddd?dd?d?d?d?d?ej(dNdddadbZe&dce4d+d+d+dd/dd?ej(dddddeZe&dfe4d+d+d+dddd?dd?	ej(dddgd5Ze&dhe4d+d+d+dd+dd?ej(dddid1Ze&dje4d+d+d+dd+dd?ej(dddkd2Ze&dle4d+d+d+dd+dd?ej(dddmd3Ze&dne4d+d+d+dddd?dej(dddod.Ze&dpe4d+d+d+dddd?dej(dddqd/Ze&dre4d+d+d+dddd?dej(dddsd0Ze&dte4d+d+d+d+d+d?d,d,d?	ej(dddudZe&dve,ddxdxdxe4d+dd+d+d,ej(ddȐdwdȐdȐdɐdxdydzdZe&d{e,ddxdxdxe4d+dd+d+d,dej(ddȐdwdȐdȐdɐdzdȐd|d}djZe&d~e4d+d+d+d+d+dd,d,d	ej(ddzdddzddddZe&de4d+d?d?d?ej(ddddZe&de,de4d+dddej(dddd?Ze&de,dej(ddddلZe&de4d+d?d+ej(ddddbZe&dej(ddddaZe&dej(dddd`Ze&dej(dddd_Ze&de4d+d+ddej(dOdddd#Ze&dej(ddddZe&de4d+d+d?d,ej(dddd7Ze&dej(ddddZe&dej(dddd*Ze&dej(dddd
Ze&dej(ddddwZe&dej(ddddyZe&dej(ddddxZe&dej(ddddZe&dej(dddd)Ze&de4d+d+ej(dddd(Ze&de4d+d+ej(dddd'Ze&dej(dPddddZe&de,ddej(ddddZe&dej(dQddddZe&de,ddej(ddddZe&de,de4d+dd?ej(ddddZe&de,de4d+dd?ej(ddddZe&de,de4d+d+d?ej(ddddZe&dej(ddddFZe&de&de4d+d,d?ej(dddd>Ze&deRdgdqe&deRdgdqe&deRdgdqe&deRdgdqe&deRdgdqe&deRd¡gdqej(dddÐdĄZe&dŃe4d+ddd?d+ej(dRddddZe&dǃe4d+d+d+d?ej(dddd-Ze&dɃe4d+d?d?ej(dddʐd˄Ze&d̃e4d+d?d?d?ej(ddd͐d΄Ze&dσeאdАdѐdҡej(dddӐdԄZe&dՃeאdАdѐdҡej(ddd֐dׄZe&d؃eאdАdѐdҡej(dddِdڄZe&dۃeאdАdѐdҡej(dddܐd݄Ze&dރeאdАdѐdҡej(dddߐdZe&deאdАdѐdҡej(ddddZe&deאdАdѐdҡej(ddddZe&deאdАdѐdҡej(ddddZe&deאdАdѐdҡej(ddddZe&de4d+d?d+d+d+d+ej(dSddddCZe&de4d+d?d+d+d+d+ej(dTddddBZe&dej(dUddddZe&dej(ddddՄZe&dej(dVddddZe&dej(dWddddZe&de4d+d?d+d+d+ej(dXdddd
Ze&de4d+d?d+d+d+d+ej(dYdddd	Ze&dej(dZddddZe&dej(ddd dZe&de4d+d?d+d+d+ej(d[ddddZe&de4d+d?d+d+d+d+ej(d\ddddZe&dej(d]ddddZe&dej(d^ddddQZe&d	ej(d_ddd
dPZe&dej(d`ddddZe&dej(ddddIZe&dej(dddd߄Ze&de,de4d+d,d,ej(ddȐdɐdɐddd]Ze&de,de4d+ej(dddd\Ze&dej,ddhddWe4d+ej(dddd[Ze&de4d+ej(ddddZe&de4d+d,ej(ddddZZe&de4d+d,ej(ddddZe&dej(ddddZe&d e4d+d?ej(ddd!dZe&d"e4d+d?d?dej(daddd#dZe&d$ej(ddd%dZe&d&e4d+d?d?d?d?dej(dbddd'dZe&d(ej(ddd)d4Ze&d*ej(ddd+dZe&d,ej(ddd-d̈́Z e&d.ej(dcddd/d̄Ze&d0e4d+d?ej(ddd1dZe&d2e4d+d?ej(ddd3dZej(ddddd4d5Ze4d+d+d+d?d?d,d?d?d?	ej(ddd6d7Ze4d+d+d+d+d?d?d,d?d?	ej(ddd8d9Ze&d:ej(ddd;dZe&d<ej(ddd=dZe&d>eRd?e'd@gdqe&dAeRdBe'dCgdqe&dDeRdEe'dFgdqddGdHdIZ	e&dJe4d+d?ej(dddKdLZ
e&dMej(dddNd:Ze&dOe4d+d?ej(dddPd,Ze&dQe4d+d+d?ej(dddRdSZe&dTe4d+d+d?dd+ej(dddUdVZe&dWej(dddXdĄZe&dYej(dddZdÄZe&d[ej(ddd\dƄZe&d]ej(ddd^dZe&d_ej(deddd`dńZe&daej(dfdddbdZe&dce4d+d,d,d?dej(ddddd҄Ze&deej(dgdddfdZe&dge4d+ej(dddhduZe&die4d+ej(dddjdEZe&dke,ddxdxe4d+d?d?ej(dddldKZe&dme4d+ej(dddndZe&doej(dhdddpdZe&dqe4d+ej(dddrdgZe&dsej(dddtduZe&dvej(dddwdxZe&dye4d+d?d?d?ej(dddzdZe&d{e4d+d+dej(dd|d|dzd}d~dZ e&de4d+d+dej(dd|d|dzd}ddZ!e&de4d+d?d+d+ej(ddddׄZ"e&de4d+d?d+d+ej(ddddքZ#e&dej(ddddzZ$e&dej(ddddeZ%e&dej(ddddZ&e&deej(ddddZ'e&dej(ddddZ(e&de4d+d?d+d+ej(diddddRZ)e4d+dd?d?ej(ddddZ*e&dej(ddddZ+e&dej(ddddZ,e&dej(ddddZ-e&dej(ddddZ.e&de4d+dd?ej(ddddZ/e&dej(ddddZ0e&dej(ddddtZ1e&dej(ddddnZ2e&dej(ddddZ3e&dej(ddddZ4e&dej(ddddcZ5e&de4d+d+ddd+ej(dd|d|ddzd|dddqZ6e&de4d+d,ddd+ej(dd|dɐddzd|dddrZ7e&de4d+d+ddd+ej(dd|d|ddzd|dddpZ8e&de4d+d+d?ej(djddddoZ9e&de4d+ddej(dkddddOZ:e&de4d+d?dd+ej(dlddddZ;e&dej(ddddZ<e&de4d+d/ej(dmdddddZ=e&dej(dddd˄Z>e&dÃe4d+d/ej(dndd|ddĜddTZ?e&dƃe,ddxdxdxe4d+d?d+d+d,d?ej(ddddWZ@e&dȃe4d+d+d?ej(dddɐdʄZAe&d˃ej(dddd;ZBe&d̓ej(dddΐdτZCe&dЃej(dddѐd҄ZDe&dӃej(ddddhZEe&dՃej(ddddZFej(dddאd؄ZGej(dddِdڄZHe&dۃe4d+d+d?dej(ddddiZIe&d݃e4d+d+d?ej(ddddZJe&d߃e,de4d+d+dd?ej(doddddZKe&dej(ddddZLe&dej(ddddZMe&dej(ddddsZNe&de4d+dd?d+d+d+d+ej(dpdddddYZOe&dej(ddddZPe&dej(dddd=ZQe&de4d+ddej(ddddZRe&de4d+d+ej(ddddJZSe&dej(dqdddd^ZTe&de4d+ddej(ddddфZUe&de4d+d+d?ej(drdddd8ZVe&dej(dsdddd%ZWe&dej(ddd dmZXe&dej(dddd!ZYe&ddtddddfZZe&dej(ddddZ[e&dej(ddddZ\e&d	ej(ddd
dZ]e&dej(ddddZ^e&dej(duddddZ_e&dej(ddddZ`e&dddddZae&dej(ddddZbe&dej(dddddZce&dej(ddddZde&dej(ddddZee&dej(ddddZfe&dej(ddd dZge&d!ej(ddd"dZhe&d#ej(dd$dd%dZie&d&ej(dd'dd(dZje&d)ej(dd'dd*dZke&d+ej(ddd,dZle&d-ej(ddȐd.d/dZme&d0ej(ddd1dZne&d2e&d3ej(ddȐd4d5d Zoe&d6e&d7ej(ddȐd4d8dZpe&d9ej(dd|d|d:d;dZqdS (v  zhThis file exports ONNX ops for opset 9.

Opset 9 is supported by ONNX release 1.4.1
release on 01/23/19
    )annotationsN)CallableListOptionalSequenceTupleUnion)_C)
_constants_deprecation_type_utilserrorssymbolic_helper)GLOBALS)	_beartype	jit_utilsregistration)Numberabsacosaddaddcmuladdmmaliasamaxaminaminmaxarangeargmaxargmin
as_strided	as_tensorasinatanatan2baddbmm
batch_norm	bernoullibitwise_not
bitwise_orbmmbroadcast_tensorsbroadcast_to	bucketizecatcdistceil	clamp_max	clamp_minclampcloneconstant_pad_nd
contiguousconv_tbcconv_transpose1dconv_transpose2dconv_transpose3dconv1dconv2dconv3dconvert_element_typeconvolutioncoscosine_similaritycrosscumsumdetachdimdivdotdropouteluembedding_bag	embedding
empty_likeemptyeqerfexp	expand_asexpandeyefillflattenfloor_dividefloorfloordivfrobenius_norm	full_likefullgathergegeluget_pool_ceil_paddingglu
group_normgthann_window
hardshrinkhardsigmoid	hardswishhardtanh	index_add
index_copy
index_fill	index_putindex_selectindexinstance_normis_floating_point	is_pinnedisnanitemkl_div
layer_normle
leaky_relulerpliftlinalg_crosslinalg_matrix_normlinalg_normlinalg_vector_normlinearlinspacelog_sigmoidlog_softmaxloglog10log1plog2logical_andlogical_not
logical_orlogical_xorlogit	logsumexp	lstm_celllstmltmasked_fillmasked_fill_matmulmax_pool1d_with_indicesmax_pool2d_with_indicesmax_pool3d_with_indicesmaxmaximummeshgridminminimummishmmmovedimmse_lossmulmultinomialmvnarrownative_layer_normneneg	new_emptynew_fullnew_ones	new_zerosnonzero_numpynonzeronormnumelnumpy_Tone_hot	ones_likeonesonnx_placeholderpadpairwise_distancepermutepixel_shufflepixel_unshufflepowpreluprim_constant_chunkprim_constant_splitprim_constant	prim_dataprim_device
prim_dtypeprim_ifprim_layoutprim_list_constructprim_list_unpack	prim_loopprim_maxprim_min
prim_shapeprim_tolistprim_tuple_construct	prim_typeprim_unchecked_castprim_uninitialized	rand_likerandrandint_likerandint
randn_likerandn
reciprocalreflection_padrelurelu6	remainderrepeat_interleaverepeatreplication_pad
reshape_asreshaperollrrelursqrtrsubscalar_tensorscatter_addscatterselectselusigmoidsignsilusinsizeslicesoftmaxsoftplus
softshrinksortsplit_with_sizessplitsqrtsquaresqueezestackstd_meanstdsubttaketantanh
tanhshrinktensor	thresholdtotopk	transposetrue_dividetype_asunbindunfoldunsafe_chunkunsafe_split_with_sizesunsafe_split	unsqueezeunsupported_complex_operatorsnoop_complex_operatorsunusedvar_meanvarview_asviewwherewrap_logical_op_with_cast_towrap_logical_op_with_negation
zeros_likezeroszero	   )opsetstrnamec                   s    fdd}|S )z5Exports the function in the current global namespace.c                   s   | t   < t  | S N)globals__all__appendfuncr   L/var/www/html/venv/lib/python3.8/site-packages/torch/onnx/symbolic_opset9.pywrapper,  s    

z_export.<locals>.wrapperr!  )r  r#  r!  r  r"  _export)  s    r$  c                 C  s   |  d}|tj  |S )z%Represents "missing" optional inputs.prim::Constant)opsetTyper	   OptionalTypeZofTensor)gnr!  r!  r"  r  4  s    
zaten::_shape_as_tensorzjit_utils.GraphContextr)  c                 C  s   |  d|S NShaper&  r)  inputr!  r!  r"  _shape_as_tensor<  s    r1  zaten::_reshape_from_tensorc                 C  s*   t |tr| jd|ddi}t| ||S )NConcataxis_ir   )r2  )
isinstancelistr&  r   )r)  r0  shaper!  r!  r"  _reshape_from_tensorB  s    
r7  zaten::reshapeTc                 C  s   t | ||S r  )r   _reshape_helperr)  selfr6  r!  r!  r"  r   J  s    zaten::reshape_asc                 C  s   |  d|}t| ||S r,  r&  r   r)  r:  otherr6  r!  r!  r"  r   Q  s    z	aten::addc                 C  sZ   t |r&t |r&t dddd|S |rLt t |dkrL| d||}| d||S )NAddr     z)Add between list of tensors not supported   Mul)r   	_is_value_is_tensor_list _onnx_opset_unsupported_detailed_scalar_maybe_get_scalarr&  r)  r:  r=  alphar!  r!  r"  r   Y  s        z	aten::subc                 C  s4   |r&t t |dkr&| d||}| d||S )Nr@  rA  Sub)r   rE  rF  r&  rG  r!  r!  r"  r   e  s    z
aten::rsubc                 C  s   t | |||dS )N)rH  )r   rG  r!  r!  r"  r   m  s    z	aten::mulc                 C  s4   t |r"t |r"| d||S | d||S d S )NAndrA  )r   _is_boolr&  r)  r:  r=  r!  r!  r"  r   s  s    z	aten::divc                 G  s.   t |dkrt| ||S t| ||f| S d S Nr   )lenr  _div_rounding_mode)r)  r:  r=  argsr!  r!  r"  rF   }  s    zaten::addcmulvf      ?c              	   C  s2   | j dt|gd}t| |t| t| |||S NConstantZvalue_t)r&  torchr   r   r   )r)  r:  Ztensor1Ztensor2valueZ
value_tensr!  r!  r"  r     s    sc                 C  sT   |d krt | ||S |dkr(t| ||S |dkr<t| ||S td| d|d S )NrW   trunczUnsupported rounding mode: "z$". Expected None, "floor" or "trunc")r  _floor_divide_trunc_divider   SymbolicValueError)r)  r:  r=  Zrounding_moder!  r!  r"  rO    s    
rO  c                 C  s   |  d||}| j d|tjjd}tj|tjj}|tjjkrt	|sjt	|rj| j d|tjj
d}q| j d|| d}n| j d|tjj
d}|S )NDivCastZto_i)r&  _C_onnxTensorProtoDataTypeINT64r   JitScalarType
from_value	UNDEFINEDr   _is_fpFLOAT	onnx_type)r)  r:  r=  outscalar_typer!  r!  r"  r\    s      r\  c                 C  s   t |st |r,t| ||}| d|S | d||}| jdtjdtjdd}| dt | ||t | ||}| d|| d	||}| d
|| d| d||}| jdtjdtjdd}	| d	||	}
| d||
S d S )NFloorr^  rU  r   dtyperV  XorrI  rA  rJ  NotEqualr@  )r   rg  r  r&  rW  r   int64Z
_lt_helper)r)  r:  r=  rj  rF   r  negativemodZ
fixup_maskonefixupr!  r!  r"  r[    s     r[  zaten::floor_dividec                 C  s   t | ||S r  )r\  rL  r!  r!  r"  rV     s    zaten::floordivc                 C  s   t | ||S r  )rV   rL  r!  r!  r"  rX     s    zaten::true_dividec                 C  s   t |st |r"| d||S t }tjj}|tjksJ|tj	ksJt
t tj	kr`tjj}| jd||d}| jd||d}| d||S )a  Division where both inputs are cast to floating types

    If both inputs are floating, performs div as usual
    If only one input is a floating type, the other input is cast to its type
    If neither input is a floating type, both inputs are cast to the default scalar type
    r^  r_  r`  )r   rg  r&  rW  get_default_dtypera  rb  rh  floatdoubleAssertionErrorDOUBLE)r)  r:  r=  rk  Zonnx_scalar_typer!  r!  r"  r    s    zaten::reciprocalc                 C  s*   t |s| jd|tjjd}| d|S )Nr_  r`  
Reciprocal)r   rg  r&  ra  rb  rh  r)  r:  r!  r!  r"  r     s    
z	aten::catic                   s   t |}g  |D ]&}t |r.t |ds.q | qt dksJtt fdd D sdt| 	   D ]}| 
| qtt |}| jd|d|iS )a{  Implement concatenation of pytorch tensors in ONNX along the specified `dim` dimension.

    Parameters:
        g (jit_utils.GraphContext): Graph context.
        tensor_list (List[torch.Tensor]): List of tensors to concatenate.
        dim (int): Dimension along which to concatenate the tensors.

    Returns:
        ONNX graph node representing the concatenated tensor.
    r   c                 3  sF   | ]>}t  d  dkp<t |dkp<t |t  d  kV  qdS r   N)r   _get_tensor_rank.0r   Znonempty_tensorsr!  r"  	<genexpr>$  s   zcat.<locals>.<genexpr>r2  r3  )r2  )r   _unpack_list_is_constant_get_tensor_dim_sizer  rN  rz  allnodeZremoveAllInputsZaddInputr&  )r)  tensor_listrE   tensorsr   r!  r  r"  r.   
  s$    
 
zaten::stackc                   s.    fddt |D }jd|d iS )Nc                   s   g | ]}t | gqS r!  r   _unsqueeze_helperr  rE   r)  r!  r"  
<listcomp>7  s   zstack.<locals>.<listcomp>r2  r3  )r2  )r   r  r&  )r)  r  rE   Z
unsqueezedr!  r  r"  r   3  s    z
aten::listc                 C  s   |S r  r!  r}  r!  r!  r"  _list>  s    r  zaten::mmc                 C  s,   | j dtdgd}| j d|||dddS )NrU  r@  rV  Gemm        rS  Zbeta_falpha_fr&  rW  r   )r)  r:  r=  Cr!  r!  r"  r   D  s    z	aten::bmmc                 C  s   |  d||S NMatMulr.  rL  r!  r!  r"  r*   M  s    zaten::matmulc                 C  s   |  d||S r  r.  rL  r!  r!  r"  r   S  s    zaten::addmmc              	   C  sH  d }t |}t |}t |}	|d k	r0|}n|d k	r>|}n|	d k	rJ|	}t |}
t |}dd }|d k	r&||
ds||dr&| d||}|}t |}t |}|dkr| jdtj|| dd}| d	||}|dkr| jdtjt || dd}| d	||}| d
||S | jd|||t |t |dS )Nc                 S  s   | d k	o| |kS r  r!  )rQ  ur!  r!  r"  is_not_none_nork  s    zaddmm.<locals>.is_not_none_nor   r  r@  rU  rm  rV  rA  r>  r  r  )r   _try_get_scalar_typer  r&  rE  rW  r   rn  )r)  r:  Zmat1Zmat2betarH  rk  self_scalar_typeZmat1_scalar_typeZmat2_scalar_typeZ	mat1_rankZ	mat2_rankr  Zres1Zres2r!  r!  r"  r   Y  s\    







 
 z	aten::negc                 C  s   |  d|S )NZNegr.  r}  r!  r!  r"  r     s    z
aten::sqrtc                 C  sT   t j|t jjt jjt jjt jjt jjt jjhkrH| j	d|t
jjd}| 	d|S )Nr_  r`  Sqrt)r   rd  re  rf  UINT8INT8INT16INTrc  r&  ra  rb  rh  r}  r!  r!  r"  r     s     
zaten::rsqrtc                 C  s"   |  dttd|t| |S )Nr^  r@  )r&  r   _if_scalar_type_asrW  r   r   r}  r!  r!  r"  r     s
      z
aten::tanhg      ?   )scaleZ
zero_pointc                 C  s   |  d|S )NTanhr.  r}  r!  r!  r"  r     s    z	aten::sinc                 C  s   |  d|S )NZSinr.  r}  r!  r!  r"  r     s    z	aten::cosc                 C  s   |  d|S )NZCosr.  r}  r!  r!  r"  r@     s    z	aten::tanc                 C  s   |  d|S )NZTanr.  r}  r!  r!  r"  r     s    z
aten::asinc                 C  s   |  d|S )NZAsinr.  r}  r!  r!  r"  r"     s    z
aten::acosc                 C  s   |  d|S )NZAcosr.  r}  r!  r!  r"  r     s    z
aten::atanc                 C  s   |  d|S )NAtanr.  r}  r!  r!  r"  r#     s    zaten::atan2c              
   C  s   |  d||}|  d|}| j dtdd}| j dttjd}|  d||}|  d||  d|||  d	||}|  d
||}	|  d|	||}
|
S )Nr^  r  rU  r   rV  GreaterWherer>  rI  Less)r&  rW  r   mathpi)r)  r:  r=  sloper#   Z
const_zeroZconst_piZ"condition_second_or_third_quadrantZsecond_third_quadrantZcondition_14_or_23_quadrantresultr!  r!  r"  r$     s    zaten::sigmoidg      p?c                 C  s   |  d|S )a  Converts the corresponding PyTorch function into ONNX operators.

    It is not meant to be called directly by a user.

    Args:
        g (jit_utils.GraphContext): Graph context.
        self (Tensor): the input tensor.
    Returns:
        ONNX operator
    Sigmoidr.  r}  r!  r!  r"  r     s    z
aten::signc                 C  s   |  d|S )NZSignr.  r}  r!  r!  r"  r     s    c                 C  sR   t |t |kstt |dkr>|d dkr>|d tjkr>|S | jd||||dS )Nr@  r   Slice)axes_iZstarts_iZends_i)rN  rz  r
   	INT64_MAXr&  )r)  r0  axesstartsendsr!  r!  r"  _slice  s    &r  z	aten::sumZ	ReduceSumsum)Zdecoratez
aten::mean
ReduceMeanmeanz
aten::prodZ
ReduceProdprodF)allow_multi_dim_supportboolZonnx_opr  r  c                 C  s   t | ||S r  )r   Z_reduce_with_dtype_helperr  r!  r!  r"  _reduce_with_dtype  s
      r  zaten::cumsumnonec                 C  sJ   t  r6|  dkr&t dd|S | jd||dS t ddd| d S )Nr%  rC   rn  dim_ir  r?  )r   is_caffe2_aten_fallbackr  kind_unimplementedat_onnx_opset_unsupported)r)  r0  rE   rn  r!  r!  r"  rC   +  s
    zaten::_sample_dirichletc                 C  s8   t  r,t |s t dd|S | d|S t d|S )N_sample_dirichletz#We are not able to export generatorr   r  _is_noner  r  _onnx_unsupportedr)  r:  	generatorr!  r!  r"  r  7  s    
  r  zaten::_standard_gammac                 C  s8   t  r,t |s t dd|S | d|S t d|S )N_standard_gammaznot able to export generatorr  r  r!  r!  r"  r  C  s    
  r  zaten::tc                 C  s6   t |}|d ks|dk r&| d|S | jd|ddS )Nr  Identity	Transpose)r@  r   Zperm_i)r   r  r&  )r)  r:  rankr!  r!  r"  r   P  s    
zaten::numpy_Tc                 C  s8   t |}|d k	sttttd|}| jd||dS Nr   r  r  )r   r  rz  r5  reversedranger&  )r)  r0  ndimpermr!  r!  r"  r   \  s    
zaten::expandc              	   C  s   t |d}t |s,| jdt|d}n2t |r^t | t| |d| jdt	dgd}t
jj}t| ||}t| || jdt	dd}t| | d||||}| d||S )zXImplement the expand function for a pytorch tensor in ONNX according to specified `size`isrU  rV  r   rq  Expandr   _maybe_get_constrB  r&  rW  
LongTensor_is_packed_listr8  r   r   r   rd  rc  r   r   r  )r)  r:  r   Zimplicitrn  r   neg_onesr!  r!  r"  rR   f  s    

 
 zaten::broadcast_toc              	   C  s   t |d}t |s,| jdt|d}n2t |r^t | t| |d| jdt	dgd}t
jj}t| ||}t| || jdt	dd}t| | d||||}| d||S )Nr  rU  rV  r   r  rq  r  r  )r)  r:  r   rn  r   r  r!  r!  r"  r,   |  s    

 
 zaten::expand_asc                 C  s   t |d}t|tjr|j}|tj}g }t|	 D ]J}t
|||||r:|| | jd|j|dd|d}q:| d|}| d||S )Nr   rU  T)keepdimrV  r-  r  )r   r  r4  rW  Tensorrn  r   ry  r  rE   equalr  r  rQ   r  r&  )r)  r:  r=  Zself_t	orig_typedimsdr6  r!  r!  r"  rQ     s    
 zaten::embeddingbc                 C  s<   |rt jrtd||dkr.t jr.td | d||S )NzUnsupported: ONNX export of embedding with scale_grad_by_freq=True for training mode. ONNX does not support scaling the gradients.r   zWarning: ONNX export of embedding with padding_idx >= 0 for training mode. ONNX does not support not updating the embedding vector at padding_idx during training.Gather)r   Zexport_trainingr   r]  warningswarnr&  )r)  weightindicespadding_idxscale_grad_by_freqsparser!  r!  r"  rK     s    
zaten::embedding_bagc
           
      C  sF   t |st dS t  r:| jd|||d|||||	d
S t d|S )Nz%embedding_bag with per_sample_weightsrJ      )outputsZscale_grad_by_freq_iZmode_iZsparse_iZinclude_last_offset_iZpadding_idx_i)r   r  r  r  r  )
r)  Zembedding_matrixr  offsetsr  moder  Zper_sample_weightsZinclude_last_offsetr  r!  r!  r"  rJ     s$    
z
aten::size)Zquantize_outputc                 C  sh   |d kr|  d|S t|ddk rZt|}|d k	rZt|d| }| j dt|d}t| ||S )Nr-  r~  r   rU  rV  )r&  r   r  r  rW  r   Z_size_helperr)  r:  rE   r  r!  r!  r"  r     s    
zaten::transposec                 C  s   ||kr|S t |}|d k	rTtt|}|| ||  ||< ||< | jd||dS t  rp| jd|d||dS td|d S )Nr  r  r   int)overload_nameZdim0_iZdim1_izAUnsupported: ONNX export of transpose for tensor of unknown rank.)	r   r  r5  r  r&  r  r  r   r]  )r)  r:  Zdim0Zdim1r  r  r!  r!  r"  r     s    
zaten::permuter  c                 C  s*   |t tdt|kr|S | jd||dS r  )r5  r  rN  r&  )r)  r:  r  r!  r!  r"  r     s    z
aten::viewc                 C  s   t | ||S r  )r   )r)  r:  r   r!  r!  r"  r    s    zaten::view_asc                 C  s   |  d|}t| ||S r,  r;  r<  r!  r!  r"  r    s    zaten::unsafe_chunkc           	      C  s   |d krt dddd|S t ||}|d kr<t dd|S || d | }|g||  }|| }|rp|| | jd||||dS )	Nr  r  r?  'Dynamic number of outputs not supportedunknown dimension sizer@  SplitZsplit_ir3  r  )r   rD  r  r  r  r&  )	r)  r:  chunksrE   _outputsr   
split_sizesplitsleftoverr!  r!  r"  r  $  s*          
zaten::splitc           
      C  s   t ||st dddd|S t | d}| dkrJt| ||||S t |dd}t ||}|d kr|d k	r~|| }nt dddd	|S |g||  }|| }	|	r|	|	 | j
d
||||dS )Nr   r  r?  r  rX  r   r~  r  z$Unknown dimension size not supportedr  r  )r   _is_split_staticrD  	_node_getr  rE   r   
_get_constr  r  r&  )
r)  r:  split_size_or_sizesrE   r  Z	split_valr  r   r  r  r!  r!  r"  r   9  s8        
    
zaten::unsafe_splitc                 C  s   t | ||||S r  )r   )r)  r:  r  rE   r  r!  r!  r"  r  U  s    zaten::split_with_sizesc                 C  s2   t ||st dddd|S | jd||||dS )Nr   r  r?  r  r  r  )r   r  rD  r&  r)  r:  Zsplit_sizesrE   r  r!  r!  r"  r   ]  s        zaten::unsafe_split_with_sizesc                 C  s   t | ||||S r  )r   r  r!  r!  r"  r  h  s    zaten::unbindc                   s^   |d krt dddd|S jd|dg|  |d}|dkrB|gn|} fdd	|D }|S )
Nr  r  r?  r  r  r@  r  c                   s   g | ]}t | gqS r!  )r   _squeeze_helper)r  rj  r  r!  r"  r  {  s    zunbind.<locals>.<listcomp>)r   rD  r&  )r)  r:  rE   r  r  Zsqueezed_outputsr!  r  r"  r  p  s        zaten::selectc                 C  st   t |}t |s^|dk r^|dkr,tj}n|d }t j| ||g|g|gd}t | ||gS | jd|||dS dS )zImplement the select functionality for a pytorch tensor in ONNX.

    Selects elements from the input tensor along the specified `dim` dimension based on the `index` tensor.
    r   r  r@  r  r  r  r  r3  N)r   rF  rB  r
   r  _slice_helperr  r&  )r)  r:  rE   rm   Z	end_indexZ
slice_noder!  r!  r"  r     s    	
    zaten::squarec                 C  s   |  d||S NrA  r.  r}  r!  r!  r"  r     s    zaten::squeezec                 C  sJ  |d kr|  d|S t|dd}|dk rt|}|d k	rxtdt| d d d t||  d	 d
  ||7 }ntdd|S t||}|d krtdt| d d t| d d d d  tj	| ||gdS |dkrtdt| d d t| d d d d  |S tdt| d d  tj	| ||gdS )NZSqueezer~  rE   r   z'ONNX export squeeze with negative axis - might cause the onnx model to be incorrect. (Negative axis is not supported in ONNX. Axis is converted to & based on input shape at export time. CPassing an tensor of different rank in execution will be incorrect.r   %negative axis with unknown input rankz5This model contains a squeeze operation on dimension z on an input z7with unknown shape. Note that if the size of dimension z of the input zVis not 1, the ONNX model will return an error. Opset version 11 supports squeezing on zMnon-singleton dimensions, it is recommended to export this model using opset zversion 11 or higher.r  r@  z. The size of z%this dimension in the given input is z. The model will zWbe exported without the squeeze node. If the model is intended to be used with dynamic z-input shapes, please use opset version 11 to zexport the model.z. If the model is z_intended to be used with dynamic input shapes, please use opset version 11 to export the model.)
r&  r   r  r  r  r  r  r  r  r  )r)  r:  rE   Zsqueeze_dimr  dim_sizer!  r!  r"  r     s    



  
zaten::preluc              	   C  s   t |}t |}t|}|d k	rp|dkrJt | |ttd|d }n&|dkrp|dgkrpt | |dg}d}|d k	r|d k	r||kstd| d| | 	d||S )Nr  r@  r   z)rank(x) should be >= rank(slope) but got z < PRelu)
r   r  _get_tensor_sizesrN  r  r5  r  r  rz  r&  )r)  r:  r  	self_rankZweight_sizesZweight_rankr!  r!  r"  r     s&    

  z
aten::siluc                 C  s   |  d||  d|S )NrA  r  r.  r/  r!  r!  r"  r     s    z
aten::mishc                 C  s   |  d||  d|  d|S )NrA  r  Softplusr.  r/  r!  r!  r"  r     s    z
aten::reluc                 C  s   t j| d|ddS )NRelu   opset_beforer   _op_with_optional_float_castr/  r!  r!  r"  r     s       zaten::relu6c                 C  s   t | |ddS )Nr      )r3   r/  r!  r!  r"  r     s    z
aten::ceilc                 C  s   |  d|S )NCeilr.  r/  r!  r!  r"  r0     s    zaten::floorc                 C  s   |  d|S )Nrl  r.  r/  r!  r!  r"  rW     s    z	aten::lenc                 C  s.   t | || jdtdgd}t| |dgS NrU  r   rV  )r   r&  rW  r  r   r  )r)  r:  Zsz_0r!  r!  r"  _len  s    r  zaten::thresholdc                 C  sD   t |dkrt dd|S t |dkr8t dd|S | d|S )Nr   r   znon-zero thresholdznon-zero valuer  )r   rE  r  r&  )r)  r:  r   rX  r!  r!  r"  r   &  s
    zaten::leaky_relu_C.Valuerx  r)  r0  Znegative_slopeZinplacec                 C  s   | j d||dS )N	LeakyRelur  r.  r   r!  r!  r"  rv   2  s    z	aten::gluc                 C  sP   t ||}|d k	r$|d dks$t| jd||dd\}}| d|| d|S )Nr  r   r  )r3  r  rA  r  )r   r  rz  r&  )r)  r0  rE   r  firstsecondr!  r!  r"  r`   @  s
    zaten::softmaxc              
   C  sb  t |}|d k	r|dk r"|| }||d k}|rptt|}|d ||  ||< |d< | jd||d}|d }| jd||d}|r|  dkrt |d	d
}| jd|t	|
 d}|r| jd||d}|S | d|| jd||gdd}| d|}	t j| |	|gd}
| d|	|
}|r^|  dkr^t |d	d
}| jd|t	|
 d}|S )Nr   r@  r  r  r  ZSoftmaxr  r%  r~  rn  r_  r`  rI  	ReduceMaxr  
keepdims_iExpr  r^  )r   r  r5  r  r&  r  r  r  r   rd  ri  _reducesum_helper)r)  r0  rE   rn  	input_dimis_transpose_requiredr  r   parsed_dtyperP   r  r!  r!  r"  r   L  sB    
  zaten::softplusc                 C  s@   t |d}|dkr4| d| d| d|||S | d|S )NrR  r@  r^  r  rA  )r   r  r&  )r)  r:  r  r   Z
beta_constr!  r!  r"  r     s     zaten::get_pool_ceil_paddingc                   s   t | }|d k	r$|t d  nd d ksBtdd D rPt dd| S fddtdtD   fddtdt D   fd	dtdtD fd
dtdtD S )Nc                 s  s   | ]}|d kV  qd S r  r!  r  r~  r!  r!  r"  r    s     z(get_pool_ceil_padding.<locals>.<genexpr>r_   input size not accessiblec              	     sB   g | ]:}t t | d |   |  t|  d qS r  r@  )r  r  r0   rx  r-  )rE   kernel_sizepaddingstrider!  r"  r    s   0z)get_pool_ceil_padding.<locals>.<listcomp>r   c                   sD   g | ]<} | d  |  | |  kr8 | d  n | qS r@  r!  r-  )ceiled_output_dimrE   r1  r2  r!  r"  r    s   "c                   sP   g | ]H}| d krdn2| | d|    | d  |  d    qS )r@  r   r  r!  r-  )r4  rE   r0  r1  r2  r!  r"  r    s   

c                   sd   g | ]\}| d |    | krT|  | d k rDt | q^t  | d n
t | qS r/  r  r-  )r0  r1  padding_ceilr!  r"  r    s   
)r   r  rN  anyr  r  )r0  r0  r2  r1  sizesr!  )r4  rE   r0  r1  r6  r2  r"  r_     s*    
  
zaten::max_pool1dZ
max_pool1dr@  )return_indiceszaten::max_pool2dZ
max_pool2dr  zaten::max_pool3dZ
max_pool3d   c              	     sD   t ddddddt ddddddtj fdd}|S )NTFrQ  r  r~  c                   s<  t |dhkr t d|S |s(|}t|}|rdt||||}|tdd t||D  }n|d }|||d}r| jd|fddi|\}	}
| jd|dd	d
 tD dd
 tD d\}}tj| |dd
 tD t	dt	dd}t
| |
|}
|	|
fS | jd|fddi|}	|	S d S )Nr@  dilationc                 s  s   | ]\}}|| V  qd S r  r!  r  ar  r!  r!  r"  r    s     z1_max_pool.<locals>.symbolic_fn.<locals>.<genexpr>r  )kernel_shape_ipads_i	strides_iMaxPoolr  c                 S  s   g | ]}d qS r3  r!  r  _r!  r!  r"  r    s     z2_max_pool.<locals>.symbolic_fn.<locals>.<listcomp>c                 S  s   g | ]}d qS r3  r!  rB  r!  r!  r"  r    s     )r  r>  r@  c                 S  s   g | ]}d | qS )r  r!  r-  r!  r!  r"  r    s     r   r  )setr   r  tupler_   zipr&  r  r  r5  r   )r)  r0  r0  r2  r1  r;  	ceil_moder6  kwargsrr  rC  Zflattened_indicesrY  r  ndimsr9  tuple_fnr!  r"  symbolic_fn  sB    


z_max_pool.<locals>.symbolic_fnr   quantized_args
parse_argsr   beartype)r  rL  rK  r9  rM  r!  rJ  r"  	_max_pool  s
    4rR  zaten::max_pool1d_with_indiceszaten::max_pool2d_with_indiceszaten::max_pool3d_with_indiceszaten::avg_pool1dZ
avg_pool1dzaten::avg_pool2dZ
avg_pool2dzaten::avg_pool3dZ
avg_pool3dc                   sJ   t dt dddddddtjdddddd	d	d
 fdd}|S )NTrQ  r  r~  r  r  Sequence[int]zUnion[int, Sequence[int]]r  )r0  r0  r2  r1  rG  count_include_padc              	     s   |s|}t |||| }t|ts*t|}|r^t j| d|d| d dddd}dt| }|rt||||}	|td	d
 t|	|D  }n|d }| j	d||||d}
|
S )NPad)r   r   r  constantr  r?  r?  mode_sZvalue_fr  r   c                 s  s   | ]\}}|| V  qd S r  r!  r<  r!  r!  r"  r  {  s    z1_avg_pool.<locals>.symbolic_fn.<locals>.<genexpr>AveragePool)r>  r@  r?  )
r   Z_avgpool_helperr4  rE  rz  r  rN  r_   rF  r&  )r)  r0  r0  r2  r1  rG  rT  Zdivisor_overrideZadjusted_paddingr6  outputr  rL  r!  r"  rM  U  sJ         
	
z_avg_pool.<locals>.symbolic_fn)NrN  )r  rL  rM  r!  r\  r"  	_avg_pool>  s    	 &1r]  zaten::adaptive_avg_pool1dZadaptive_avg_pool1drZ  zaten::adaptive_avg_pool2dZadaptive_avg_pool2dzaten::adaptive_avg_pool3dZadaptive_avg_pool3dzaten::adaptive_max_pool1dZadaptive_max_pool1drA  zaten::adaptive_max_pool2dZadaptive_max_pool2dzaten::adaptive_max_pool3dZadaptive_max_pool3dc                   s(   t ddtj fdd}|S )NTFc              	     s  }zt dW n  tk
r4   t d| Y S X dgt kr\dkr\| d|S t |}z|dd   W n tk
r   d  Y nX  d kstdd  D rڈdgt kr| d	|d fS t d
|S  fddt	dt D }|dgt| kr>dgt kr0| d	|d fS t d|S  fddt	dt D }dkr| |||dt  dt  dS | j|||d}|S )Nr  z4adaptive pooling, since output_size is not constant.r@  rZ  ZGlobalAveragePoolr  c                 s  s   | ]}|d kV  qd S r  r!  r-  r!  r!  r"  r    s     z6_adaptive_pool.<locals>.symbolic_fn.<locals>.<genexpr>ZGlobalMaxPoolr.  c                   s   g | ]} | |  qS r!  r!  r-  rE   output_sizer!  r"  r    s     z7_adaptive_pool.<locals>.symbolic_fn.<locals>.<listcomp>r   z-output size that are not factor of input sizec                   s    g | ]}t  | |  qS r!  r5  r-  r^  r!  r"  r    s     rA  rY  r3  F)r>  r@  )
r   
_parse_arg	Exceptionr  rN  r&  r  r7  r  r  )r)  r0  r_  Zoutput_size_valuer8  rt  kr[  fnr  rL  typer^  r"  rM    sJ     


    
$z#_adaptive_pool.<locals>.symbolic_fn)r   rO  r   rQ  )r  re  rL  rd  rM  r!  rc  r"  _adaptive_pool  s    A
1rf  r  rE   c                 C  sF   t |dd dg| d t|   }|ddd |ddd  }|S )zGenerate paddings in ONNX order based on pad in pytorch.
    Args:
        dim: the dimension of the tensor.
        pad: the paddings in pytorch.
            The order is dim_n_begin, dim_n_end, dim_n-1_begin, dim_n-1_end, ...
    Nr   r  r  )r5  rN  )rE   r   paddingsr!  r!  r"  _prepare_onnx_paddings  s    &rj  c              
   C  sh   t | d}t |rdt |rdt |}zdd |D }W n& tk
rb   t dddd|  Y S X |S )Nr  c                 S  s   g | ]}t |d dqS )r~  r1  )r   r  )r  rQ  r!  r!  r"  r    s    z)_convert_padding_node.<locals>.<listcomp>rU  r  r?  z)The sizes of the padding must be constant)r   r  rB  r  r  ra  rD  )r0  r1  
input_listr!  r!  r"  _convert_padding_node  s     

    
rl  zaten::constant_pad_ndc              
   C  sn   d}zt |dd}W n& tk
r<   t dddd| Y S X t|}tt ||}t j| d||||ddS )	NrV  rR  rX  rU  r  r?  z*The value for the padding must be constantrW  )r   r  ra  rD  rl  rj  r  r  )r)  r0  r1  rX  r  ri  r!  r!  r"  r5   '  s,        
      )r)  r0  r   c                 C  sH  t |}t|d dkstt|d }|}t|D ]}|d| d   }|d| d   }g }	|dkrtj| |d| g| gtjgd}
|	|
 |dk s|dk rt	
d| }t	
d|  }tj| |d| g|g|gd}|	| n
|	| |dkr*tj| |d| gdg|gd}|	| | jd|	dd| i}q4|S )Nr  r   r@  r  r2  r3  )r2  )rl  rN  rz  r  r   r  r
   r  r  builtinsr   r&  )r)  r0  r   r1  r  curidxZpad_rZpad_lr  leftstartendmiddlerightr!  r!  r"  _pad_circular;  sP        


    
ru  zaten::reflection_pad1dzaten::reflection_pad2dzaten::reflection_pad3dc                 C  s2   d}t |}tt||}tj| d|||ddS )NreflectrU  r?  r?  rX  r  rl  rj  r   r  r  r)  r0  r1  r  ri  r!  r!  r"  r   e  s         zaten::replication_pad1dzaten::replication_pad2dzaten::replication_pad3dc                 C  s2   d}t |}tt||}tj| d|||ddS )NedgerU  r?  rw  rx  ry  r!  r!  r"  r   r  s         z	aten::padr)  r0  r   r  rX  c                 C  st   t |d}|dkr t| ||S |dkr4t| ||S |dkrJt| |||S |dkr^t| ||S td| |d S )NrY  Z	replicaterv  rV  ZcircularzUnrecognized padding mode )r   r`  r   r   r5   ru  r   r]  r{  r!  r!  r"  r     s    	zaten::upsample_nearest1dZupsample_nearest1dZnearestzaten::upsample_nearest2dZupsample_nearest2dr  zaten::upsample_nearest3dZupsample_nearest3d   zaten::upsample_linear1dZupsample_linear1dzaten::upsample_bilinear2dZupsample_bilinear2dzaten::upsample_trilinear3dZupsample_trilinear3d)r  rE   interpolate_modec                   s    fdd}|S )Nc                   sb   t | |\}}t  t |}|r8t d|S |d krPt | || }| jd||dS )Nzalign_corners == TrueUpsamplerX  )r   Z_get_interpolate_attributesZ_interpolate_warningrF  r  Z_interpolate_size_to_scalesr&  )r)  r0  r_  rP  scalesalign_cornersrE   r}  r  r!  r"  rM    s"      

   z!_interpolate.<locals>.symbolic_fnr!  )r  rE   r}  rM  r!  r  r"  _interpolate  s    ,r  zaten::__interpolatec           	      C  s*   t | |||||\}}| jd|||dS )Nr~  r  )r   Z _interpolate_get_scales_and_moder&  )	r)  r0  r   Zscale_factorr  r  Zrecompute_scale_factorZ	antialiasr  r!  r!  r"  __interpolate  s         r  zaten::bitwise_notc                 C  s"   t |std|| d|S NzOONNX export does NOT support exporting bitwise Not for non-boolean input valuesrp  r   rK  r   r]  r&  r/  r!  r!  r"  r(     s    
zaten::bitwise_orc                 C  s:   t |std|t |s,td|| d||S )NzVONNX export does NOT support exporting bitwise OR for non-boolean input values. self: zWONNX export does NOT support exporting bitwise OR for non-boolean input values. other: Orr  rL  r!  r!  r"  r)     s    

c                   s    fdd}|S )Nc                   s   t   fdd}|S )Nc                   s,   t  d  } | || |d|| |dS )NZ_cast_F)r  )r)  r0  r=  Zto_cast_func)rd  to_typer!  r"  wrap_with_cast  s    zGwrap_logical_op_with_cast_to.<locals>.decorator.<locals>.wrap_with_cast	functoolswraps)rd  r  r  )rd  r"  	decorator  s    z/wrap_logical_op_with_cast_to.<locals>.decoratorr!  )r  r  r!  r  r"  r    s    r   )r   returnc                   s   t   fdd}|S )Nc                   s   |  d | ||S )Nrp  r.  r)  r0  r=  r  r!  r"  wrap_with_not  s    z4wrap_logical_op_with_negation.<locals>.wrap_with_notr  )r   r  r!  r  r"  r    s    zaten::__not_c                 C  s"   t |std|| d|S r  r  r}  r!  r!  r"  __not_  s    
r  zaten::eqc                 C  s   t | tjr:t | tjr:| jdtjdtjddS | }| }|	 |	   krfdkrn nN|
d|
d  krdkrn n*| jdtj|d|dktjddS | d||S )	NrU  Trm  rV  onnx::ConstantrX  rY  rq  )r4  re  r	   DeviceObjTyper&  rW  r   r  r  r  kindOfrY  )r)  r:  r=  Z	self_nodeZ
other_noder!  r!  r"  rN   $  s"      $zaten::nec                 C  s   t | ||S r  )rN   rL  r!  r!  r"  r   @  s    zaten::gtc                 C  s   t | ||S r  _gt_implr  r!  r!  r"  rb   H  s    c                 C  sJ   t |r<t |r<| jd|tjjd}| jd|tjjd}| d||S )Nr_  r`  r  r   rK  r&  ra  rb  INT32r  r!  r!  r"  r  O  s    r  zaten::ltc                 C  s   t | ||S r  _lt_implr  r!  r!  r"  r   W  s    c                 C  sJ   t |r<t |r<| jd|tjjd}| jd|tjjd}| d||S )Nr_  r`  r  r  r  r!  r!  r"  r  ^  s    r  zaten::gec                 C  s   t | ||S r  r  r  r!  r!  r"  r]   f  s    zaten::lec                 C  s   t | ||S r  r  r  r!  r!  r"  ru   n  s    zaten::__and_c                 C  s:   t |std|t |s,td|| d||S )NzOONNX export does NOT support exporting bitwise AND for non-boolean input valuesrJ  r  r  r!  r!  r"  __and_v  s    

r  zaten::__or_c                 C  s:   t |std|t |s,td|| d||S )NzNONNX export does NOT support exporting bitwise OR for non-boolean input valuesr  r  r  r!  r!  r"  __or_  s    

r  zaten::__xor_c                 C  s:   t |std|t |s,td|| d||S )NzOONNX export does NOT support exporting bitwise XOR for non-boolean input valuesro  r  r  r!  r!  r"  __xor_  s    

r  zaten::logical_andZBoolc                 C  s   |  d||S )NrJ  r.  r  r!  r!  r"  r     s    zaten::logical_orc                 C  s   |  d||S )Nr  r.  r  r!  r!  r"  r     s    zaten::logical_xorc                 C  s   |  d||S )Nro  r.  r  r!  r!  r"  r     s    zaten::logical_notc                 C  s   |  d| j d|tjjdS )Nrp  r_  r`  r&  ra  rb  BOOLr/  r!  r!  r"  r     s    zaten::__rshift_c                 C  s   t j|}t j|t jj|kr6| jd|| d}| jdtjdtjdd}t	
|sn| jd|tjjd}| d||}| jd|| d}| d||}|S )	Nr_  r`  rU  r  rm  rV  Powr^  r   rd  re  rf  r&  ri  rW  r   Zfloat32r   rg  ra  rb  rh  )r)  r:  r=  r  twotwo_powrshiftr!  r!  r"  	__rshift_  s*    
r  zaten::__lshift_c                 C  s   t j|}t j|t jj|kr6| jd|| d}| jdtjdtjdd}t	
|sn| jd|tjjd}| d||}| jd|| d}| d||}|S )	Nr_  r`  rU  r  rm  rV  r  rA  r  )r)  r:  r=  r  r  r  lshiftr!  r!  r"  	__lshift_  s*    
r  zaten::wherec              	   C  s`   t |s| jd|tjjd}|d krPt| |}t | || jdt	dd|S | d|||S )Nr_  r`  rU  r@  rV  r  )
r   rK  r&  ra  rb  r  r   Z_unbind_helperrW  r   )r)  	conditionr:  r=  r  r!  r!  r"  r  	  s    

   zaten::log_softmaxc           	      C  s   t |}|d krt ddS |dk r.|| }||d k}|r|tt|}|d ||  ||< |d< | jd||d}|d }| jd||d	}|r|  d
krt |dd}| jd|t	
| d}|r| jd||d}|S )NrE   fONNX and PyTorch use different strategies to split the input. Input rank must be known at export time.r   r@  r  r  r  Z
LogSoftmaxr  r%  r~  rn  r_  r`  )r   r  r  r5  r  r&  r  r  r  r   rd  ri  )	r)  r0  rE   rn  r*  r+  r  Z	return_opr,  r!  r!  r"  r   	  s2    
  zaten::_log_softmaxc                 C  s>   |r2t j|t jjt jjkr2| jd|tjjd}t	| ||S Nr_  r`  )
r   rd  re  rf  HALFr&  ra  rb  rh  r   )r)  r0  rE   Zhalf_to_floatr!  r!  r"  _log_softmax4	  s     r  zaten::_convolutionc                 C  s"  t |}z|dd  }W n tk
r2   d }Y nX |d ksNtdd |D rZtd|||g}t |st |dkr|| |dd  ||| ||	d}tdd |D r|st	t
|t
|kst	||d< | j|rd	nd
f||}t |st |dkr| d||S |S d S )Nr  c                 s  s   | ]}|d kV  qd S r  r!  r-  r!  r!  r"  r  `	  s     z_convolution.<locals>.<genexpr>DUnsupported: ONNX export of convolution for kernel of unknown shape.r@  )r>  r@  r?  dilations_igroup_ic                 s  s   | ]}|d kV  qdS r  r!  )r  or!  r!  r"  r  x	  s     Zoutput_padding_iZConvTransposeConvr>  )r   r  ra  r7  r   r]  r  r  r  rz  rN  r&  )r)  r0  r  biasr2  r1  r;  
transposedoutput_paddinggroupsZ	benchmarkZdeterministiccudnn_enabledZ
allow_tf32weight_sizekernel_shaperP  rH  r*  r!  r!  r"  _convolutionC	  sB    




r  zaten::_convolution_modec                 C  s   t |}z|dd  }	W n tk
r2   d }	Y nX |	d ksNtdd |	D rZtd|||g}
t |st |dkr|
| |dkrd}n|dkrd	}|dd  ||||d
}| j	d|
|}t |st |dkr| 	d||S |S d S )Nr  c                 s  s   | ]}|d kV  qd S r  r!  r-  r!  r!  r"  r  	  s     z$_convolution_mode.<locals>.<genexpr>r  r@  validZVALIDsameZ
SAME_UPPER)r>  r@  Z
auto_pad_sr  r  r  r>  )r  )
r   r  ra  r7  r   r]  r  r  r  r&  )r)  r0  r  r  r2  r1  r;  r  r  r  rP  rH  r*  r!  r!  r"  _convolution_mode	  sB    



r  zaten::convolutionc
           
      C  s"   t | |||||||||	d d d d S r  r  )
r)  r0  r  r  r2  r1  r;  r  r  r  r!  r!  r"  r?   	  s     zaten::conv1dc           	      C  s\   t |d}|dkr*t| |||||||S t |d}t| ||||||dd|d d d d S d S NrY  )r  r  r  Fr!  r   r`  r  r  	r)  r0  r  r  r2  r1  r;  r  Zstr_paddingr!  r!  r"  r;   	  s:    zaten::conv2dc           	      C  s\   t |d}|dkr*t| |||||||S t |d}t| ||||||dd|d d d d S d S r  r  r  r!  r!  r"  r<   
  s:    zaten::conv3dc           	      C  s\   t |d}|dkr*t| |||||||S t |d}t| ||||||dd|d d d d S d S r  r  r  r!  r!  r"  r=   :
  s:    zaten::conv_transpose1dc	           	      C  s"   t | ||||||d||d d d d S NTr  	r)  r0  r  r  r2  r1  r  r  r;  r!  r!  r"  r8   `
  s     zaten::conv_transpose2dc	           	      C  s"   t | ||||||d||d d d d S r  r  r  r!  r!  r"  r9   
  s     zaten::conv_transpose3dc	           	      C  s"   t | ||||||d||d d d d S r  r  r  r!  r!  r"  r:   
  s     zaten::batch_normc
                 C  s   t |d t rDt |||||gsDtjdk rDt dddd|S t | |||||\}}}}| j	d||||||d| |sdndd	}
|s|
S |
\}}}}}|
|  |
|  |d	|   |d	|   |S d S )
Nr&      ZBatchNormalizationr  zaAll input tensors must have the same `dtype`. Turn off Autocast or export using opset version 15.r@  r|  )	epsilon_fZ
momentum_fr  zbatch_norm_dead_output-)r   check_training_moderW  Zis_autocast_enabledZargs_have_same_dtyper   export_onnx_opset_versionrD  Z_batchnorm_helperr&  r'  re  ZsetDebugNameZ	debugName)r)  r0  r  r  running_meanrunning_vartrainingmomentumepsr  rj  resZnew_running_meanZnew_running_varZ
saved_meanZ	saved_varr!  r!  r"  r&   
  sV    	     
zaten::native_layer_normrS  z#Tuple[_C.Value, _C.Value, _C.Value])r)  r0  normalized_shaper  r  r  r  c              
   C  s  dd t t|ddD }t| d}t| |}| jdk rN| jd||d}	n$| d|| jd	tj|tjd
d}	t	| ||	}
t
j|
t
jjk}|rt
j|}| jd|
t
| d}
| jdk r| jdt| |
||d}n,| dt| |
|| jd	tj|tjd
d}t| | d||}| d|
|}|rZt
j|}| jd|t
| d}|d ks|t|s|t| ||}|d kst|st| ||}|r| jd|t
| d}| d|}n
t| |}||	|fS )Nc                 S  s   g | ]
}| qS r!  r!  r-  r!  r!  r"  r    s     z%native_layer_norm.<locals>.<listcomp>r   r         @   r  r  rU  rm  rV  r_  r`  r>  r^  r|  )r  rN  r   Z_generate_wrapped_numberr  r&  rW  r   longr   r   rd  re  r  ri  r   r   r  r   r   r   )r)  r0  r  r  r  r  r  Ztwo_cstZeps_cstr  	numeratorZis_type_halfZ	eps_dtypeZvariancedenominator
normalizedZinput_dtypeZrdenominatorr!  r!  r"  r   
  sf    

  

    
zaten::layer_norm)r)  r0  r  r  r  r  cudnn_enabler  c           	   	   C  s<   t  r | jd||||||dS t| |||||\}}}|S )Nrt   )Znormalized_shape_ieps_fZcudnn_enable_i)r   r  r  r   )	r)  r0  r  r  r  r  r  r  rC  r!  r!  r"  rt   H  s    	zaten::instance_normr   )r)  use_input_statsr  r  r  c
                 C  s,  t |d t |d}
|d ks*t |rl|
d kr>td|tjdg|
 tj	
| d}| jd|d}|d ks~t |r|
d krtd|tjdg|
 tj	
| d}| jd|d}|d kst |s|d kst |r| jd	||||d
S t |}| }|d }|d kr(td||d }d|d< || |d< t| || jdtj|gtjdd}t| || jdtj|gtjdd}t| || jdtj|gtjdd}t| || jdtj|gtjdd}| d|| jdt|d}t| |||||||||	
}t| || jdt|dS d S )Nrn   r@  zCUnsupported: ONNX export of instance_norm for unknown channel size.rS  rm  rU  rV  r  InstanceNormalizationr  r   zJUnsupported: ONNX export of instance_norm training for unknown batch size.ZReshape)r   r  r  r  r   r]  rW  r   r   rd  re  rn  r&  r  copyr   rr  r  r&   r  )r)  r0  r  r  r  r  r  r  r  r  channel_sizeweight_value
bias_value
input_sizeZinput_size_reshaper*  cweight_bias_Zrunning_mean_Zrunning_var_input_reshapedrj  r!  r!  r"  rn   c  s    

    zaten::unfoldc                   s   t  rjd ||dS t }z|  }W n tk
rJ   d }Y nX |d k	rtd||}t||d |} fddt||D }	t|}
ttd|
	
   fdd|	D }jd|d	 iS t d
dS d S )Nr  )Zdimension_iZsize_iZstep_ir   r@  c              	     s*   g | ]"\}}t j g|g|gd qS )r  r   r  )r  lowhi)	dimensionr)  r0  r!  r"  r    s       zunfold.<locals>.<listcomp>c              
     s(   g | ] }t jd |d gqS )r  r  )r   r  r&  r  )r  r)  r  r!  r"  r    s     r2  r3  ZUnfoldr.  )r2  )r   r  r  r  ra  r  rF  rN  r5  r  popr&  r  )r)  r0  r  r   stepr8  ZsizedimZlow_indicesZ
hi_indicesr   r  r  r!  )r  r)  r0  r  r"  r    s2    

  z	aten::eluc                 C  sJ   |r|dkrt dd|S |r4|dkr4t dd|S | jd|t |dS )NrS  r  zdoes not support scale in Eluinput_scalez#does not support input_scale in EluElur"  )r   r  r&  rE  )r)  r0  rH  r  r  r!  r!  r"  rI     s        z
aten::seluc                 C  s   |  d|S )NZSelur.  r/  r!  r!  r"  r     s    zaten::index_selectc                 C  s   t | |||S r  )r   _select_helper)r)  r:  rE   rm   r!  r!  r"  rl     s    zaten::index_putc                 C  s   t |rt |}n|g}t  rD|g| ||g }| jd| S t |d}t|dkrp|rlt| ||S |S t ddd| d S )Nrk   r  r   r  r?  )rk   )	r   r  r  r  r  r`  rN  r   r  )r)  r:  Zindices_list_valuevalues
accumulateZindices_listrP  r!  r!  r"  rk     s    
zaten::index_fillc           	      C  sr   t |d}t  r*| jd|||d|dS t | |||\}}t |}t ||}t| ||d }t| ||||S )Nr~  rj   Z
int_Scalar)r  r  )	r   r`  r  r  _index_fill_reshape_helperrF  r  rR   r   )	r)  r:  rE   rm   rX  	dim_valueexpanded_index_shapeexpanded_indexZexpanded_valuer!  r!  r"  rj   "  s(    	   
zaten::index_copyc                 C  sL   t |d}t  r(| jd||||dS t | |||\}}t| ||||S )Nr~  ri   r  )r   r`  r  r  r  r   )r)  r:  rE   rm   sourcer  r  r  r!  r!  r"  ri   :  s       zaten::bucketizec                 C  s   t jj}|rt jj}| jd| d|| d|dd}t|}|d k	sLttt	d|d }t
| t| |||d }	|rt| ||	}
nt| ||	}
| jd|
|d}tj| |dgddS )	Nr2  r-  r   r  r@  r_  r`  r&  )ra  rb  rc  r  r&  r   r  rz  r5  r  rR   r  r]   rb   r)  )r)  r:  Z
boundariesZ	out_int32rt  Zout_type	new_shapeZtensor_rankZunsqueeze_axesZexpanded_boundariescondZcond_outr!  r!  r"  r-   F  s$    "

zaten::type_asc                 C  sj   t |}t |}||kr(|d k	r(|S |d k	rD| jd|| dS t  rZ| d||S td|d S )Nr_  r`  r  zUnsupported: ONNX export of type_as for tensor of unknown dtype. Please check if the dtype of the parameter passed to the type_as function is correct.)r   r  r&  ri  r  r  r   r]  )r)  r:  r=  
self_dtypeZother_dtyper!  r!  r"  r  l  s     

zaten::cosine_similarityc           	      C  s   t  r| jd||||dS t j| t| |||gdd}t j| t| |||gdd}t j| t| |||gdd}t| t| t| ||| jdt	|gd}t
| ||S )NrA   )r  r  r   r&  rU  rV  )r   r  r  r)  r   r   r   r&  rW  r   rF   )	r)  x1x2rE   r  rB   Zx1_l2Zx2_l2Zdiv_tensr!  r!  r"  rA     s4     
   
   
    zaten::pairwise_distancec                 C  s   t |s | jdt|gd}t| | jdtjdgtjddt| ||}t j| t	| t
| |||dgt |dd}t	| ||S )NrU  rV  r@  rm  r  r~  r&  )r   rB  r&  rW  r   rF   rx  r   r)  r   r   r`  )r)  Zinput1Zinput2pr  r  Zinv_pZ	summationr!  r!  r"  r     s    


zaten::clonec                 C  s   |S r  r!  )r)  r0  Zunused_memory_formatr!  r!  r"  r4     s    z	aten::absc                 C  s   |  d|S )NAbsr.  r}  r!  r!  r"  r     s    z	aten::logc                 C  s   |  d|S )NLogr.  r}  r!  r!  r"  r     s    zaten::log1pc              	   C  s    t | t| ttd||S )Nr@  )r   r   r   r  rW  r   r}  r!  r!  r"  r     s    zaten::log10c              	   C  s*   d}|  dt| || j dt|gdS )NgUk@r^  rU  rV  r&  r   rW  r   )r)  r:  Z_ln10r!  r!  r"  r     s    z	aten::powc                 C  sb   t j|}t|s2t jj}| jd|| d}t|sP| jd|| d}| d||}|S )Nr_  r`  r  )r   rd  re  r   rg  rh  r&  ri  )r)  r:  exponentZf_dtyper   r!  r!  r"  r     s    

zaten::clampc              	   C  s~   t |rt| ||S t |r,t| ||S t |rft |rft j| d|t |dt |dddS t| t| |||S d S )NCliprR     min_fmax_fr  )r   r  r1   r2   r  r  r`  )r)  r:  r   r   r!  r!  r"  r3     s    



	zaten::clamp_minc                 C  s^   t |r&t j| d|t |dddS tj|}| jd|| d}t j| d||ddS d S )	Nr  rR  r  )r   r  r_  r`  Maxr  	r   r  r  r`  r   rd  re  r&  ri  )r)  r:  r   rn  r!  r!  r"  r2     s"    
   
     zaten::clamp_maxc                 C  s^   t |r&t j| d|t |dddS tj|}| jd|| d}t j| d||ddS d S )	Nr  rR  r  )r  r  r_  r`  ZMinr  r  )r)  r:  r   rn  r!  r!  r"  r1     s"    
   
     z	aten::maxc                 C  s   t | |||S r  )r   Z_max_helperr)  r:  dim_or_yr  r!  r!  r"  r     s    zaten::maximumc                 C  s   t | ||dS N)r  )r   r  r!  r!  r"  r     s    z	aten::minc                 C  s   t | |||S r  )r   Z_min_helperr  r!  r!  r"  r   &  s    zaten::minimumc                 C  s   t | ||dS r  )r   r  r!  r!  r"  r   -  s    z
aten::amaxc                 C  s   | j d|||dS )Nr%  r&  r.  r)  r:  rE   r  r!  r!  r"  r   4  s    z
aten::aminc                 C  s   | j d|||dS )N	ReduceMinr&  r.  r  r!  r!  r"  r   <  s    zaten::aminmaxc                 C  sJ   d|i}t |s*t |dd}|g|d< | jd|f|| jd|f|fS )Nr'  r~  rE   r  r  r%  )r   r  r  r&  )r)  r:  rE   r  Zreduce_kwargsr!  r!  r"  r   D  s    

 z	aten::expc                 C  s   |  d|S )Nr(  r.  r}  r!  r!  r"  rP   S  s    zaten::dropout_zaten::dropoutc                 C  s.   t |d |s|S | jd||dd\}}|S )NrH   ZDropoutr  )Zratio_fr  )r   r  r&  )r)  r0  r  trainrI  rC  r!  r!  r"  rH   Y  s
    zaten::alpha_dropout_zaten::feature_alpha_dropout_zaten::feature_dropout_zaten::feature_alpha_dropoutzaten::alpha_dropoutzaten::feature_dropoutc                   s$   t dddtj fdd}|S )NrQ  r  r  c                   s   |rt  d|S |S )Nztraining mode)r   r  )r)  r0  r  r	  r  r!  r"  feature_dropout  s    z-_unsupported_dropout.<locals>.feature_dropoutr   rP  r   rQ  )r  r
  r!  r  r"  _unsupported_dropoutf  s    r  z
aten::normc                 C  sx   |dkrt d}n |dkr(t d}ntd||| |||d}|d k	rtt |dd}| jd	|t| d
}|S )Nr@  ZReduceL1r  ZReduceL2z)ONNX export only p-norms with p of 1 or 2)rE   r  r~  rn  r_  r`  )	r   Z_reduce_op_symbolic_helperr   r]  r  r&  r   rd  ri  )r)  r:  r  rE   r  rn  rR  r  r!  r!  r"  r     s     zaten::conv_tbcc              	   C  s~   t  r| jd||||dS | jd|dddgd}| jd|dddgd}t| |||dg|gdgd}| jd|dddgdS d S )Nr7   )Zpad_ir  r@  r  r   r  )r   r  r  r&  r;   )r)  r0  r  r  r   convr!  r!  r"  r7     s    zaten::_uniquec                 C  s,   t  r| jd|||ddS t d|S d S )N_uniquer  )sorted_ireturn_inverse_ir  )r   r  r  r  )r)  r0  sortedreturn_inverser!  r!  r"  r    s    r  zaten::_unique2c                 C  s2   t  r| jd||||ddS t ddd| d S )N_unique2r:  )r  r  Zreturn_counts_ir  r  r?  )r   r  r  r  )r)  r0  r  r  Zreturn_countsr!  r!  r"  r    s    	r  zaten::_cast_Bytez2.0z
the futurez8Avoid using this function and create a Cast node insteadc                 C  s   | j d|tjjdS r  )r&  ra  rb  r  r)  r0  Znon_blockingr!  r!  r"  
_cast_Byte  s    r  zaten::_cast_Charc                 C  s   | j d|tjjdS r  )r&  ra  rb  r  r  r!  r!  r"  
_cast_Char  s    r  zaten::_cast_Shortc                 C  s   | j d|tjjdS r  )r&  ra  rb  r  r  r!  r!  r"  _cast_Short  s    r  zaten::_cast_Intc                 C  s   | j d|tjjdS r  )r&  ra  rb  r  r  r!  r!  r"  	_cast_Int  s    r  zaten::_cast_Longc                 C  s   | j d|tjjdS r  )r&  ra  rb  rc  r  r!  r!  r"  
_cast_Long  s    r  zaten::_cast_Halfc                 C  s   | j d|tjjdS r  )r&  ra  rb  ZFLOAT16r  r!  r!  r"  
_cast_Half  s    r  zaten::_cast_Floatc                 C  s   | j d|tjjdS r  )r&  ra  rb  rh  r  r!  r!  r"  _cast_Float  s    r  zaten::_cast_Doublec                 C  s   | j d|tjjdS r  )r&  ra  rb  r{  r  r!  r!  r"  _cast_Double  s    r  zaten::_cast_Boolc                 C  s   | j d|tjjdS r  r  r  r!  r!  r"  
_cast_Bool)  s    r  zaten::emptyc                 C  s   t | |||||S r  )r  )r)  r8  rn  layoutdevice
pin_memorymemory_formatr!  r!  r"  rM   4  s    zaten::empty_likec                 C  s   t | |||||S r  )r  )r)  r0  rn  r  r  r   r!  r!  r!  r"  rL   C  s    zaten::new_emptyc                 C  s2   t |}t |r |d k	r |}t| |||||S r  )r   r  r  rM   r)  r:  r8  rn  r  r  r   r  r!  r!  r"  r   R  s    
zaten::scalar_tensorc                 G  s<   t |dd}|d krtjj}| jd|t| d}|S )Nr~  rn  r_  r`  )r   r  r   rd  rh  r&  ri  )r)  Zscalarrn  optionsr!  r!  r"  r   ]  s
    zaten::tensorc                 C  s  t |dd}t |r|d kr6tjt |d }t }t |D ]L}| jdt	
dgd}t | ||}| jd|t| d}|| qF| jd|d
diS |d krtj|}t |rt |st |r| jd|ddd}| jd|t| dS )Nr~  rn  r   rU  r@  rV  r_  r`  r2  r3  ZConcatFromSequence)r3  Z
new_axis_i)r2  )r   r  r  r   rd  re  r  r5  r&  rW  r  r8  ri  r  Z_is_listrC  Z_is_scalar_list)r)  datarn  r  requires_gradrk  r   Zshape_referencer!  r!  r"  r   g  s,    

zaten::as_tensorc                 C  s   t | |||S r  )r   )r)  r$  rn  r  r!  r!  r"  r!     s    zaten::zerosc                 C  sz   |d krt jj}n
t |}t|d}t|trZt|dkrZ| jdt	
g t	jd}| jd|t	j
dg| ddS )Nr  r   rU  rV  ConstantOfShaperm  r   rd  rh  r   r  r4  r5  rN  r&  rW  r   r   rr  rn  r)  r8  rn  r  r  r   rk  sizes_r!  r!  r"  r    s    

zaten::zeros_likec           	      C  sT   |  d|}t|r*tj|tjj}n
t|}| j d|tjdg|	 ddS )Nr-  r&  r   rm  rV  
r&  r   r  r   rd  re  rh  rW  r   rn  	r)  r0  rn  r  r  r   r!  r6  rk  r!  r!  r"  r    s    
 
zaten::new_zerosc                 C  s2   t |}t |r |d k	r |}t| |||||S r  )r   r  r  r  r"  r!  r!  r"  r     s    
z
aten::zeroc                 C  s   t |}t| ||S r  )r   r  r  )r)  r:  r  r!  r!  r"  r    s    
z
aten::onesc                 C  sz   |d krt jj}n
t |}t|d}t|trZt|dkrZ| jdt	
g t	jd}| jd|t	j
dg| ddS )Nr  r   rU  rV  r&  r@  rm  r'  r(  r!  r!  r"  r     s    

zaten::ones_likec           	      C  sT   |  d|}t|r*tj|tjj}n
t|}| j d|tjdg|	 ddS )Nr-  r&  r@  rm  rV  r*  r+  r!  r!  r"  r     s    
 
zaten::new_onesc                 C  s2   t |}t |r |d k	r |}t| |||||S r  )r   r  r  r   r"  r!  r!  r"  r     s    
z
aten::fullc              	   C  s   t |d}t |rX|d kr&tjjn|}t| ||||}t| ||| jdt	
ddS t |dd}|d krxtjj}	n
t|}	t |d}
t|
trt|
dkr| jdt	
g t	jd}| jd	||d|	 dS d S )
Nr   rU  r@  rV  r~  rn  r  r   r&  )r   r  rB  r   rd  rh  r  r   r&  rW  r   r  r4  r5  rN  r   rr  r  rn  )r)  r8  rX  rn  r  r  r   const_valuetmprk  r)  r!  r!  r"  r[     s"    


zaten::full_likec              	   C  s   t |d}t |dd}|d kr6tj|tjj}n
t|}t |rt| ||||}	| j	d||
 d}t| |	|| j	dtddS | 	d	|}
| j	d
|
tj|g| ddS d S )NrR  r~  rn  r_  r`  rU  r@  rV  r-  r&  rm  )r   r  r  r   rd  re  rh  rB  r  r&  ri  r   rW  r   rn  )r)  r0  
fill_valuern  r  r  r   r!  rk  r-  r6  r!  r!  r"  rZ     s$     

zaten::new_fullc           	      C  s4   t |}t |r |d k	r |}t| ||||||S r  )r   r  r  r[   )	r)  r:  r   r.  rn  r  r  r   r  r!  r!  r"  r   <  s    
	aten::eyec                 G  s   t |dkrX|\}}}}}t| |dg}| jd||dd}t| ||||}	| d|	S t |dkr|\}}
}}}}| jdt| |dgt| |
dgdd}t| ||||}	| d|	S tddt | d	S )
Nr|  r   r2  r  ZEyeLiker  r/  with 
 arguments)rN  r   r  r&  r  r  )r)  rP  r*  rn  r  r  r   r  r6  r   mr!  r!  r"  rS   N  s"    aten::slicec                 G  s2  t |dkrr|\}}}}t|d}|dkr:td||  dkoXt| t	j
}|  dkoxt| t	j
}|  dk}	|  dk}
|s|	r|s|
r|  dkrtjtjjkrtd|nBt| |dg}t| |dg}t| |dg}| d	||||S nT|r&dn
t|d}|r>tjn
t|d}t|d}tj| ||g|g|gd
S nt |dkr|\}}}d}|  dkot| t	j
}|  dkot| t	j
}|rdn
t|d}|rtjn
t|d}tj| ||g|g|gd
S tddt | dS )Nr  r~  r@  z"step!=1 is currently not supportedr%  r  zUnsupported: ONNX export of Slice with dynamic inputs. DynamicSlice is a deprecated experimental op. Please use statically allocated variables or export to a higher opset version.r   ZDynamicSlicer  r:  r3  r0  r1  )rN  r   r`  r   r]  r  r  r4  re  r	   ZNoneTyper   operator_export_typera  ZOperatorExportTypesZONNXr  r&  r
   r  r  r  )r)  r:  rP  rE   rq  rr  r  Zis_start_noneZis_end_noneZis_start_onnx_constZis_end_onnx_constZstart_unsqueezedZend_unsqueezedZdim_unsqueezedr!  r!  r"  r   g  s      

    
  

    zaten::hardtanhr)  r:  Zmin_valZmax_valc                 C  s   t j| d|||ddS )Nr  r  r  r  r5  r!  r!  r"  rg     s         zaten::hardswishc                 C  s   t | |}| d||S r  )re   r&  )r)  r:  hsr!  r!  r"  rf     s    
zaten::hardsigmoidc                 C  s   | j d|ddS )NHardSigmoidgUUUUUU?r"  r.  r}  r!  r!  r"  re     s    zaten::tanhshrinkc                 C  s   |  d|t| |S )NrI  )r&  r   r}  r!  r!  r"  r     s    zaten::hardshrinkc                 C  sx   t j|t jj}| jdtj|| dd}t| t	| ||t
| |t| |}| d||| jdtjd| ddS NrU  rm  rV  r  r   )r   rd  re  rh  r&  rW  r   rn  r   rb   r   r   )r)  r:  lambdrk  lambd_opr  r!  r!  r"  rd     s$     "zaten::softshrinkc           	      C  s   t j|t jj}| jdtj|| dd}t| ||}| d|t	| ||| jdtjd| dd}t
| |t| |}| d|t| ||| jdtjd| dd}t| ||S r8  )r   rd  re  rh  r&  rW  r   rn  rb   r   r   r   r   )	r)  r:  r9  rk  r:  Zgt_condZgt_outZlt_condZlt_outr!  r!  r"  r     s:     
	
	zaten::aliasc                 C  s   |S r  r!  r}  r!  r!  r"  r     s    zaten::unsqueezec                 C  s~   |dk rlt |}|dk	r^tdt| d d d t|| d  d d	  || d }nt d
d|S t j| ||gdS )zbImplement unsqueezing a pytorch tensor in ONNX by inserting a new dimension at the specified `dim`r   Nz)ONNX export unsqueeze with negative axis r	  r
  r  r@  r  r  r  r  r  )r   r  r  r  r  r  r  r  r!  r!  r"  r    s6    

  z
aten::sortc                 C  sp   |d k	rt dd| t |}z|| }W n tk
rD   d }Y nX |d kr\t dd|S | jd|||ddS )NZSortz'Out parameter is not supported for sortr.  TopKr  Zk_ir3  r  )r   r  r  ra  r&  )r)  r:  rE   Z	decendingrj  Z
self_sizesr  r!  r!  r"  r   1  s      

zaten::numelc                 C  s   t | |S r  )r   Z_numel_helperr}  r!  r!  r"  r   H  s    z
aten::topkc                 C  s<   |d k	rt dd| |s(t dd| | jd|||ddS )Nr;  z'Out parameter is not supported for topkzAscending TopK is not supportedr  r<  )r   r  r&  )r)  r:  rb  rE   largestr  rj  r!  r!  r"  r   N  s      zprim::convert_element_typec                 G  s,   t |d dd}| jd|t| dS )Nr   r~  rn  r_  r`  )r   r  r&  r   rd  ri  )r)  r:  rP  rn  r!  r!  r"  r>   ]  s    zaten::toc                 G  s  t jdd }||r|S t|dkr|d }t|d r|d   dkrt|d  d}t|t	j
rt|jdkr| }t|}n|}t|st|t	j
rtj|d }| jd|| dS | jd|t| dS nt|d	kr$t|d
 dd}| jd|t| dS t|dkr^t|d dd}| jd|t| dS t|dkrt|d dd}| jd|t| dS td|S )Nc                 S  s   t | dkrL| d   dkpJ| d  tj pJt| d  tj	S t | dkrrt
| d dd}|d kS t | dkrt
| d dd}|d kS d	S )
Nr  r   prim::devicer|  r@  r~  rn  )r     F)rN  r  r  re  isSubtypeOfr	   ListTypeofIntsr4  r  r   r  )rP  rn  r!  r!  r"  is_aten_to_device_onlyg  s    z"to.<locals>.is_aten_to_device_onlyr  r   r  rX  r_  r`  r|  r@  r~  rn  r  r?  zUnknown aten::to signature)r   rQ  rN  r   rB  r  r  r   r4  rW  r  r6  rr   r  r   rd  re  r&  ri  r  r  )r)  r:  rP  rC  rn  Ztvalr!  r!  r"  r   d  sD    

zaten::repeatc                 C  s0   t jj}t| ||}| d||}| d||S )Nr  ZTile)r   rd  rc  r   r&  )r)  r:  repeatsrn  Zshape_r!  r!  r"  r     s    zaten::repeat_interleavec              
   C  s  t |}t |}t |}|d kr2td||d krFtd||d krZtd|t |rt | || jdt	dgd}tj	dtj
d}n
t |}|dk r|t|7 }| }t|D ] \}	}
|
d krd	\||	< ||	< q|dks|d
kr8|d d
kr8|| dkr(t dddd|S t | |||S |d
kr|| dkrbt dddd|S |d d krt dddd|S |d || kstd|d }ntd|t }t | ||d}t | |||}d\||< ||< t|D ]\}	}t| ||	 |d
 }| jdt|d |d
  d|| jdt||d
 d  dg}| jd|ddi}t| ||d }t j| || jdt|ddd}|| q| jd|d|iS )NzGUnsupported: ONNX export of repeat_interleave for unknown repeats rank.zGUnsupported: ONNX export of repeat_interleave for unknown repeats size.zEUnsupported: ONNX export of repeat_interleave for unknown input size.rU  r  rV  r   rm  )r   r  r@  r   r     z3Unsupported along dimension with unknown input sizez*Unsupported for cases with dynamic repeatsz2repeats must have the same size as input along dimz%repeats must be 0-dim or 1-dim tensor)r  r@  r2  r3  Z	allowzero)r2  )r2  )r   r  r  r   r]  r  r8  r&  rW  r   rr  rF  rN  r  	enumeraterD  Z-_repeat_interleave_single_value_repeat_helperrz  r5  Z_repeat_interleave_split_helperr  r  rR   r  )r)  r:  rD  rE   r_  Zrepeats_dimZrepeats_sizesZinput_sizesZinput_sizes_tempro  r  ZrepsZfinal_splitsZr_splitsZi_splitsZr_splitZi_splitZr_concatr!  r!  r"  r     s    



  
"   

zaten::pixel_shufflec           	      C  s  t |}t|dkr$t dd|S tdd |dd  D rt j| t | |ddg| jd	t	d
d||d
d
gdd
d}| jd|d
dddddgd}t j| || jd	t	d
d
ddd
d
gdd
d}t j| || jd	t	d
d
d
d
ddgdd
d}t 
| |ddgS |d | | }t j| || jd	t	d||||d |d gdd
d}| jd|d
dddddgd}t j| || jd	t	d||d | |d | gdd
dS d S )Nr  r   only support 4d inputc                 s  s   | ]}|d kV  qd S r  r!  r-  r!  r!  r"  r  ,  s     z pixel_shuffle.<locals>.<genexpr>r@  r  r:  rU  r   r  rV  rF  r  r|  r  r   r  rN  r  r7  r8  r  r&  rW  r   r  )	r)  r:  Zupscale_factorr  
after_viewafter_transpose	reshape_h	reshape_woutput_channelr!  r!  r"  r   #  s    
  	

zaten::pixel_unshufflec           
      C  s  t |}t|dkr$t dd|S tdd |dd  D rt j| t | |dg| jdt	d	d	d
|d	gdd	d}t j| || jdt	d	d	d	d	d
|gdd	d}| jd|d	dddddgd}t j| || jdt	d	d
ddd	d	gdd	d}t 
| |ddgS |d | | }t j| || jdt	d
|d |d | ||d | |gdd	d}	| jd|	d	dddddgd}t j| || jdt	d
||d | |d | gdd	dS d S )Nr  r   rH  c                 s  s   | ]}|d kV  qd S r  r!  r-  r!  r!  r"  r  u  s     z"pixel_unshuffle.<locals>.<genexpr>r@  r:  rU  r   r  rV  rF  r  r|  r  r  rI  )
r)  r:  Zdownscale_factorr  rL  rM  rK  Zfinal_reshaperN  rJ  r!  r!  r"  r   l  s|    
  



c           *        s  t d d d d d  dddd	d
ddddddg}ttdd |D |}|rXdnddkrt  d|	  krtdd|S t  d|	  kst fddtdt D |
rj	d|dddgd}|r|rtdd|S 
dr|dd    }d d }t|dd krFtdd|S |	 }|}g }dkshd krn|}ndkr|\}}g }|d krtn|}d krd!d"d#gndkrd"d$d%gtjd&d' tjfd(d)}tjfd*d+}tjfd,d-}tD ]R}|rndkrL||\}}}n||\}}t}||d f}ndkr|d| \}} }!|d| d \}"}#}$j	d.|!|$dd/}n,|d| \}} |d| d \}"}#t}j	d.||"dd/}j	d.| |#dd/}d| d| d f}|||||g}%|%||f|  dkrX|%||f|  |rbi nd0d1i}&dkr|	r||g}'n|g}'j	d;|%d|'d2|&\}}(nVd kr؈j	d<|%ddd3|&\}}(n*dkrj	d=|%d4d5|&\}}(})|	rJj	d|dddd4gd}tj|j	d6tddd7gd8dd9}nt|dg}||( dkr&||) q&|
rj	d|dddgd}dkr|(nj	d>|d:di}dksΈd kr||fS dkrdkr|)nj	d?|d:di}|||fS d S )@NzVExporting a model to ONNX with a batch_size other than 1, with a variable length with z can cause an error z9when running the ONNX model with a different batch size. z4Make sure to save the model with a batch size of 1, z=or define the initial states (h0/c0) as inputs of the model. r  r  r  ZAffiner!  ZThresholdedReluZ
ScaledTanhr7  r  ZSoftsignr  c                 S  s   g | ]}|  qS r!  )lower)r  Zact_funr!  r!  r"  r    s     z _generic_rnn.<locals>.<listcomp>r  r  LSTMr@  zLSTMs with projectionsc                   s   g | ]} ||  qS r!  r!  r-  )all_weightsweights_per_layerr!  r"  r    s   r   r  r  zRNN/GRU/LSTMzdropout in training modeRNNzunknown hidden sizeGRU)r@  r  )r   r@  )r  r:  )r:  r  )r@  r:  c                   s*    fdd|D } j d|ddiS )Nc              	     s2   g | ]*\}}t j d g| g| gdqS )r   r  r  )r  xyr)  r*  wr!  r"  r    s   z8_generic_rnn.<locals>.reform_weights.<locals>.<listcomp>r2  r3  r   )r2  r.  )r)  rX  r*  Z	intervalsZslicesr!  rW  r"  reform_weights  s    z$_generic_rnn.<locals>.reform_weightsc                   s`   |  }dkr|\}}n,dks*dkrF fdd|D \}}t  fdd||fD S )NrS  rT  rP  c                 3  s   | ]} |V  qd S r  r!  r  rX  r)  hidden_sizereform_permutationrY  r!  r"  r    s    zB_generic_rnn.<locals>.transform_weights_no_bias.<locals>.<genexpr>c                 3  s   | ]}t  |d gV  qdS r  r  r  rU  r+  r!  r"  r     s    )rE  )layer_indexweights	weight_ih	weight_hhr)  r\  layer_weightsr]  rY  variantr!  r"  transform_weights_no_bias  s    

z/_generic_rnn.<locals>.transform_weights_no_biasc                   s|   |  }dkr|\}}}}n0dks.dkrN fdd|D \}}}} j d||dd}t fd	d|||fD S )
NrS  rT  rP  c                 3  s   | ]} |V  qd S r  r!  rZ  r[  r!  r"  r  *  s    z:_generic_rnn.<locals>.transform_weights.<locals>.<genexpr>r2  r   r  c                 3  s   | ]}t  |d gV  qdS r  r  r^  r+  r!  r"  r  .  s   )r&  rE  )r_  r`  ra  rb  Zbias_ihZbias_hhbias_concatrc  r!  r"  transform_weights$  s    z'_generic_rnn.<locals>.transform_weightsc                   s&   dkr| S t j | dg|g|gdS )Nr@  r   r  r  )rU  rq  rr  )r)  
num_layersr!  r"  retrieve_state3  s        z$_generic_rnn.<locals>.retrieve_stater2  r  Zdirection_sbidirectional)r  hidden_size_iZactivations_s)r  rl  Zlinear_before_reset_ir:  )r  rl  rU  r  rV  rF  r3  )rS  )rT  )rP  )r2  )r2  )r  r  dictrF  rN  r   r  rz  r  r&  
startswithrO  r  r  r   rQ  r  r8  rW  r  r  )*r)  re  r0  Zinitial_statesrQ  
has_biasesri  rH   r	  rk  batch_firstbatch_sizesZonnxActivationsZvariantToOnnxActivationMapZnonlinearityw_hhZunidirectionalZprev_outputh_outsZh0Zc0c_outsZsequence_lensrf  rh  rj  r~  ra  rb  rg  Zstate_indicesZweight_ih_fZweight_hh_fZbias_fZweight_ih_bZweight_hh_bZbias_binputsextra_kwargsZ
activationZh_outZc_outr!  )	rQ  r)  r\  rd  ri  r]  rY  re  rR  r"  _generic_rnn  s:     
  




	




 
 
  

"
"rw  c
                 C  s2   t |t | }
}t| d||
|||||||	S )NrP  r   r  rw  )r)  r0  hidden_vweight_vro  ri  rH   r	  rk  rp  hiddenr  r!  r!  r"  
_lstm_full  s$    r|  c
                 C  s4   t |t | }
}t| d||
||||||	|dS )NrP  rq  rx  )r)  r0  rq  ry  rz  ro  ri  rH   r	  rk  r{  r  r!  r!  r"  _lstm_packed  s$    r~  z
aten::lstmc                 G  s.   t |d rt| f| S t| f| S d S Nr:  )r   rC  r~  r|  r)  rP  r!  r!  r"  r     s    zaten::lstm_cellc                   s   t  |dg}t |} fdd|D }t |rB||||fn||f}t |rXdnd}	t d||||	dddddd\}
}}t  |dgt  |dgfS )	Nr   c                   s   g | ]}t  |d gqS rY  r  r^  r+  r!  r"  r    s     zlstm_cell.<locals>.<listcomp>TFrP  r@  )ri  rH   r	  rk  rp  )r   r  r  Z
_is_tensorrw  r  )r)  r:  r{  Zw_ihrr  Zb_ihZb_hhr0  r  ro  rC  rs  rt  r!  r+  r"  r     s4    
  z	aten::grurT  Zgruzaten::rnn_tanhZRNN_TANHZrnn_tanhzaten::rnn_reluZRNN_RELUZrnn_relur  c                   sd   t ddddddddd	tjfdd t ddddddddd	fdd fdd	}|S )
NrQ  r~  rR  c
                   s&   t |}
t|  |||
||||||	S r  rx  )r)  r0  r{  rz  ro  ri  rH   r	  rk  rp  r  r  r!  r"  	_rnn_full	  s    
z"_one_hidden_rnn.<locals>._rnn_fullc
                   s(   t |}
t|  |||
|||||	|dS )Nr}  rx  )r)  r0  rq  r{  rz  ro  ri  rH   r	  rk  r  r  r!  r"  _rnn_packed&  s    
z$_one_hidden_rnn.<locals>._rnn_packedc                   s.   t |d r| f| S  | f| S d S r  )r   rC  r  )r  r  r!  r"  symbolicB  s    z!_one_hidden_rnn.<locals>.symbolicr  )r  r  r!  )r  r  r  r"  _one_hidden_rnn  s    r  zaten::_dim_arangec                 C  sX   |  d|}| j d|| j dt|ddd}t rB|  d|S t| |dd d d S d S )	Nr-  r  rU  rV  r   r  z_caffe2::Ranger  )r&  rW  r   r   r  r   )r)  likerE   Z
like_shapestopr!  r!  r"  _dim_arangeK  s       r  zaten::detachc                 C  s   |S r  r!  r/  r!  r!  r"  rD   Z  s    zaten::contiguousc                 C  s   |dkrt d||S )Nr  z-onnx memory_format support is not implemented)r   r]  )r)  r0  r!  r!  r!  r"  r6   a  s     zaten::_pack_padded_sequencec                 C  s|   |r| j d|dddgd}| tjj s<td|t	j
|t	j
jt	j
jkrj| j d|tjjd}| j d	||dd
S )Nr  r@  r   r  r  z*'lengths' must be a Tensor for ONNX exportr_  r`  zprim::PackPaddedr  )r&  re  r@  rW  r	   Z
TensorTypegetr   r]  r   rd  re  rf  r  ra  rb  r  )r)  r0  lengthsrp  r!  r!  r"  _pack_padded_sequencel  s       r  zaten::_pad_packed_sequencec                 C  s8   | j d||dd\}}|r0| j d|dddgd}||fS )Nzprim::PadPackedr  r  r  r@  r   r  r.  )r)  r$  rq  rp  Zpadding_valuetotal_lengthr  r!  r!  r"  _pad_packed_sequence  s    r  zaten::randintc                 G  s  t |dd}t |dd}t |dd}|d kr<tjj}n
t|}|d krZt d||d krnt d|t |d}	t |	r| jd|t	j
dgt	jd	d
}
| jd|
||d}n| jd|	||d}tjj}| jd|| d}||kr| jd|| d}|S )Nr~  rn  r  highr   r  r&  r   rm  rV  RandomUniformLikelow_fhigh_fRandomUniform)shape_ir  r  r_  r`  )r   r  r   rd  rc  r  r  rB  r&  rW  r   rx  ri  )r)  r  r  shapesrn  r#  low_ihigh_irk  r6  shape_constr   	int_dtyper   r!  r!  r"  r     sD    



zaten::randint_likec                 G  s   t |dd}t |dd}t |dd}|d kr<tjj}n
t|}|d krZt d||d krnt d|| jd|||d}	tjj}
| jd|	|
 d	}|
|kr| jd|| d	}|S )
Nr~  rn  r  r  r   r  r  r_  r`  )r   r  r   rd  rc  r  r&  ri  )r)  r:  r  r  rn  r#  r  r  rk  r   r  r   r!  r!  r"  r     s*    

zaten::randnc                 G  s   t |dd}|d kr tjj}n
t|}t |d}t |rr| jd|tj	dgtj
dd}| jd|| d	S | jd
|| dS )Nr~  rn  r  r&  r   rm  rV  RandomNormalLikedtype_iZRandomNormalr  r  r   r  r   rd  rh  r  rB  r&  rW  r   rx  ri  r)  r  rn  r#  rk  r6  r  r!  r!  r"  r     s*    


z
aten::randc                 G  s   t |dd}|d kr tjj}n
t|}t |d}t |rr| jd|tj	dgtj
dd}| jd|| d	S | jd
|| dS )Nr~  rn  r  r&  r   rm  rV  r  r  r  r  r  r  r!  r!  r"  r     s*    


zaten::randn_likec                 C  sH   t |dd}|d kr*tj|tjj}n
t|}| jd|| dS )Nr~  rn  r  r  r   r  r   rd  re  rh  r&  ri  )r)  r:  rn  r  r  r   r!  rk  r!  r!  r"  r     s     
zaten::rand_likec                 C  sB   t |dd}|d kr(tj|tjj}| jd|t| dS )Nr~  rn  r  r  r  )r)  r:  rn  r  r  r   r!  r!  r!  r"  r   /  s       zaten::rreluc                 C  s@   |s || d }| j d||dS | j d|||d}|  d||S )Nr  r!  r"  r  )r  r  r  r.  )r)  r0  rO  upperr  r  r  r  r!  r!  r"  r   D  s
    zaten::bernoullic           	      C  s   |d k	r t |s t dd| |d k	r@t |s@t dd| tj|tjj}|tjjkrlt dd|S | jd|dd| d}|d k	rt |s|n|}| d	||}| jd
|| dS )NZ	Bernoulliz,out parameter is not supported for bernoulliz(generator is not supported for bernoulliinput dtype not accessibler  rS  r  )r  r  r  r  r_  r`  )	r   r  r  r   rd  re  rf  r&  ri  )	r)  r0  r  r  rj  rn  ZrandsZprobr[  r!  r!  r"  r'   O  s@           zaten::log_sigmoidc                 C  s   |  d|}|  d|S )Nr  r  r.  )r)  r0  r  r!  r!  r"  r   o  s    z	aten::erfc                 C  s   |  d|S )NErfr.  r/  r!  r!  r"  rO   w  s    zaten::flattenc                 C  s   t |}|d kr t dd|S |dkr8t | |dgS |dkrL| d|S |dk r\|| }|dkr||d kr| jd||dS |dkr||d kr| jd||d dS t | ||||S )	NrE   r  r   r@  r  Flattenr  r  )r   r  r  r8  r&  Z_flatten_helper)r)  r0  Z	start_dimZend_dimrE   r!  r!  r"  rU   ~  s$    
zaten::nonzeroc                 C  s   t | | d|S )z/Emitted from `torch.nonzero(x, as_tuple=False)`ZNonZero)r   r&  r/  r!  r!  r"  r     s    zaten::nonzero_numpyc                 C  s   t | t| |d|dS )Nr@  )r  )r  r   )r)  r0  r  r!  r!  r"  r     s    zaten::isnanc                 C  s   |  d|}|S )NZIsNaNr.  )r)  r0  r[  r!  r!  r"  rq     s    z	aten::anyc              	   G  s   t |dkr|d }d\}}n6|\}}}t|d}dd |dD }t|d}| jd	|tjjd
}tj| |||d}t	| || jdt
jdt
jddS )Nr@  r   rM  r   c                 S  s   g | ]}t |qS r!  r5  )r  r  r!  r!  r"  r    s     z_any.<locals>.<listcomp>r  r~  r_  r`  r&  rU  rm  rV  )rN  r   r`  r  r&  ra  rb  rc  r)  rb   rW  r   r  )r)  rP  r0  rE   r  Z	input_sumr!  r!  r"  _any  s    

   r  z	aten::allc              	   G  sP   |  d|d }t|dkr.|  dt| |S |  dt| ||d |d S d S )Nrp  r   r@  r  )r&  rN  r  )r)  rP  r0  r!  r!  r"  _all  s    r  zaten::narrowc                 C  s   t j| ||g|g|| gdS )Nr  r  )r)  r0  rE   rq  lengthr!  r!  r"  r     s        zaten::argmaxztorch._C.Valuer)  r0  rE   r  c                 C  s   t | |||dS )NZArgMaxr   Z_argmin_argmax_helperr  r!  r!  r"  r     s    	zaten::argminc                 C  s   t | |||dS )NZArgMinr  r  r!  r!  r"  r     s    	zaten::scatterc                 C  s   t j|t jj}t|}t|r:| jd||||dS t j|}||krb| jd|| d}| jd||t	| |||dS d S )NZScatterr  r_  r`  )
r   rd  re  rf  r   rF  rB  r&  ri  rQ   )r)  r:  rE   rm   srcZsrc_typer  r!  r!  r"  r     s     

zaten::scatter_addc                 C  sz   t |}|d kr t dd|S t j|dd}|rP| jdtj|| dd}nt| ||}t 	| ||||}t
| ||S )Nr   r  F)Zallow_nonstaticrU  rm  rV  )r   r  r  r  r&  rW  r  rn  r  Z_scatter_helperr   )r)  r:  rE   rm   r  rk  r8  Zto_addr!  r!  r"  r     s    
  z
aten::log2c              	   C  s(   d}|  dt| || j dt|dS )Ng9B.?r^  rU  rV  r  )r)  r:  Z_ln2r!  r!  r"  r     s    zaten::is_floating_pointc                 C  s6   t |r | jdtdgdS | jdtdgdS NrU  r@  rV  r   )r   rg  r&  rW  
BoolTensorr}  r!  r!  r"  ro   !  s    
zaten::__is_c                 C  sL   t |r@t |r*| jdtdgdS | jdtdgdS t| ||S r  )r   r  r&  rW  r  rN   rL  r!  r!  r"  __is_)  s
    

r  zaten::__isnot_c                 C  s   t | ||S r  )r  rL  r!  r!  r"  __isnot_3  s    r  zaten::one_hotc                 C  sn   | j dtddgd}tj|tjjtjjtjjtjj	tjj
hkrZ| j d|tjjd}| j d|||dd	S )
NrU  r   r@  rV  r_  r`  OneHotr  r  )r&  rW  r  r   rd  re  rf  r  r  r  r  ra  rb  rc  )r)  r:  Znum_classesr  r!  r!  r"  r   :  s     zaten::gatherc           	   	   C  s   t |drt dd|S tj|}| jdtddgd}t	| || jdt|gd}| jd| jd	||||d
|
 d}| dt | ||d g|}t j| ||gddS )Nr~  r\   zsparse_grad == TruerU  r   r@  rV  r_  r  r  r`  rA  r&  )r   r  r  r   rd  re  r&  rW  r  r   ri  r  r)  )	r)  r:  rE   rm   Zsparse_gradrk  r  depthr   r!  r!  r"  r\   K  s    c                 C  s   t | ||||S r  )r   Z_var_mean_helper)r)  r0  rE   Z
correctionr  r!  r!  r"  	_var_mean_  s    r  z	aten::stdc                 G  s    t | |f| \}}| d|S Nr  r  r&  r)  r0  rP  r  rC  r!  r!  r"  r   e  s    z	aten::varc                 G  s   t | |f| \}}|S r  )r  r  r!  r!  r"  r  l  s    zaten::var_meanc                 G  s4   t |dkr t| |d |d d S t| |f| S d S )Nr@  r   )rN  r  )r)  r0  rP  r!  r!  r"  r  s  s    zaten::std_meanc                 G  s$   t | |f| \}}| d||fS r  r  )r)  r0  rP  r  r  r!  r!  r"  r   |  s    zaten::logsumexpc                 C  s   | j d|||dS )NZReduceLogSumExpr&  r.  r  r!  r!  r"  r     s    aten::arangec           
        s  t  r jd| S tjdd }tj fdd}t|dksNt|dkrt|dkr`d }n||d }t j |d	 |d
\}}}}t  |d	g}||}t  t	 t
 ||d d dg}	 jd|	t| dS t|dkst|dkrt|dkr
d }n||d }t j |d	 |d |d |d\}}}}t  |d	g}t  |d	g}t  |d	g}| d d|||}t  t	 t
 |d d d dg}	 d d|	||}	 jd|	t| dS t|dkr||d }t j |d	 |d |d\}}}}t  |d	g}t  |d	g}| d||} dt  t	 t
 ||f|dd   dg|}	 jd|	t| dS t ddt| dS )Nr   c                 S  s   t | d} | S )Nr~  )r   r  rm  r!  r!  r"  _get_arange_dtype  s    z!arange.<locals>._get_arange_dtypec                   s.   t | r* jd d| tjj d} | S )Nr_  r  r`  )r   rg  r&  r   rd  rc  ri  )range_tensorr+  r!  r"  _float_step_convert  s    


z#arange.<locals>._float_step_convertr  r|  r@  r   )rr  rn  r_  r`  r  r?  r:  )rq  rr  r  rn  r^  rI  r>  rA  r  )rq  rr  rn  r  r0  r1  )r   )r   r  r  r   rQ  rN  Z_arange_cast_helperr  r  r   r   r&  r   rd  ri  r  )
r)  rP  r  r  rn  rr  rq  r  r  Zarange_tensorr!  r+  r"  r     s    
	                     zaten::linspacec           
      C  sT   t | |d }t| t| ||t| || jdtjdtjdd}	t| t	| ||	|S )NrU  r@  rm  rV  )
r   Z_arange_helperrF   r   r&  rW  r   rr  r   r   )
r)  rq  rr  Zstepsrn  r  r  r   r  r  r!  r!  r"  r~     s    
 z
aten::liftc                 C  s   |S r  r!  r}  r!  r!  r"  rx     s    zaten::masked_fillc                 C  s6   | j d|tjjd}t|}|  d|t|||S )zImplement the masked_fill functionality available for a pytorch tensor in ONNX.

    Fills elements of the input tensor with `value` where `mask` is True.
    r_  r`  r  )r&  ra  rb  r  r   rF  r  r)  r:  maskrX  r!  r!  r"  r     s    
zaten::masked_fill_c                 C  s   t | |||S r  )r   r  r!  r!  r"  r     s    aten::indexc                   s,  t  rjd|ddS t |r0t |}n|g}tjfddfdd|D }t|dkrt jd	|d	 d
dS dd t	|D  t d	krS t dkrt
 d	 | d	  S t }|d krt ddS tdtj d t }tfddt|D jd  fddt|D  djd|d| d  } d  }t|d ddD ]@}d| |  |}	d||	}d| |  }qt
d	|t|}
 tt d	  d d krjdtdgdg fddt|D  }jd#|dd	i}t |ttd d	 d d	g tt d	 d || d  }jd|dfd dt d	 D |
g  fd!dt d	 |D  }jd$|dd	i}n,jd|
f fd"dt|D dd	i}t |S d S )%Nrm   r  )r  c                   sh   t | sdtj| tjjtjjks.t | rd jdk rDt	
dtd t  t | dg} | S )Nr  z?Exporting masked indices are only supported after ONNX opset 9.zExporting aten::index operator with indices of type Byte. Only 1-D indices are supported. In any other case, this will produce an incorrect ONNX graph.r@  )r   r  r   rd  re  rf  r  rK  r  r   r]  r  r  r  r   )rm   r}  r!  r"  try_mask_to_index  s(    
 
z index.<locals>.try_mask_to_indexc                   s   g | ]} |qS r!  r!  )r  ro  )r  r!  r"  r  %  s     zindex.<locals>.<listcomp>r@  r   F)Zapply_reshapec                 S  s   g | ]\}}t |s|qS r!  )r   r  )r  r~  ro  r!  r!  r"  r  =  s    
 r  zoperator of advanced indexing on tensor of unknown rank. Try turning on shape inference during export: torch.onnx._export(..., onnx_shape_inference=True).z=Exporting aten::index operator of advanced indexing in opset z is achieved by combination of multiple ONNX operators, including Reshape, Transpose, Concat, and Gather. If indices include negative values, the exported graph will produce incorrect results.c              
     s0   g | ](} j d  j dt|gdddqS )r  rU  rV  r   r  )r&  rW  r  r  rE   )r)  shape_tensorr!  r"  r  \  s   r  c                   s   g | ]}| kr|qS r!  r!  r-  )adv_idx_indicesr!  r"  r  j  s      r  r  r  r  r  rA  r>  rU  rV  c                   s   g | ]}| kr| qS r!  r!  r-  r  dim_tensor_listr!  r"  r    s     r2  r3  c                   s   g | ]} | qS r!  r!  r-  )r  r!  r"  r    s     c                   s   g | ]}| kr| qS r!  r!  r-  r  r!  r"  r    s   c                   s   g | ]}| kr| qS r!  r!  r-  r  r!  r"  r    s   )r2  )r2  )r   r  r  r  r  r   rQ  rN  r  rG  rl   r  r  r  r  r   r  r1  r  r&  r5  rW  r  r8  )r)  r:  rm   r  r  Zadv_idx_countZcum_adv_indexZ
multiplierr~  Z	adv_indexZcum_adv_index_shape_tensorZfolded_adv_idx_shape_listZfolded_adv_idx_shapeZadv_idx_permuteZfinal_shape_listZfinal_shaper!  )r  r  r)  r:  r  r  r"  rm     s    
       

	

  

 	zaten::linalg_normzOptional[Sequence[int]]r)  r:  ordrE   r  rn  c                 C  s   d }|d kr|t |r<t | |dg}| jdtdgd}t |}|d kr\t dd|S |dkrrt |d}qd	dg}n8t	|dkrt |r| jdtdgd}t |d}|rt
| |||||S t| |||||S )
Nr  rU  r  rV  rE   (Input rank must be known at export time.r@  rR  r   )r   r  r8  r&  rW  r  r  r  r`  rN  r|   rz   )r)  r:  r  rE   r  rn  	ord_valueself_dimr!  r!  r"  r{     s,    

  

zaten::linalg_vector_normc                 C  s   t | |||||S r  )r   Z_linalg_vector_norm_helperr  r!  r!  r"  r|     s    zaten::linalg_matrix_normz	List[int]c              	   C  s  t |d}|dkr"t| |||S |dkr8t dd|S t |d}|d krZt| |||S |dksj|dkrxt dd	|S t |}|d krt dd
|S |d dk r|d  |7  < |d dk r|d  |7  < |tjks|tj kr|d |d  |d< |d< |d |d kr*|s*|d  d8  < t j| | d||d g|d}|dkrt	| || jdt
|d gd|d\}	}
n*t| || jdt
|d gd|d\}	}
|	S d S )NrY  ZfroZnuczlinalg.matrix_normzord==nucrR  r  rh  zord==2r  r   r@  r  r&  rU  rV  )r  r  )r   r`  rY   r  r  r  infr)  r&  r   rW  r  r   )r)  r:  r  rE   r  rn  r  r  r  r  r  r!  r!  r"  rz     sZ    
   
  

zaten::linalg_crossr  c                 C  s   t | |||S r  )rB   )r)  r0  r=  rE   r!  r!  r"  ry     s    zaten::frobenius_normc                 C  s,   |  d||}tj| |||d}|  d|S )NrA  r&  r  )r&  r   r)  )r)  r:  rE   r  ZsqrZsumsqrr!  r!  r"  rY   &  s    zaten::multinomialc                 C  sZ   |d k	r t |s t dd| |s:|dkr:t dd| t| |}| jd|tjj|dS )NZMultinomialz*generator is not supported for multinomialr@  zGreplacement=False when num_samples > 1 is not supported for multinomial)r  Zsample_size_i)r   r  r  r   r&  ra  rb  rc  )r)  r0  Znum_samplesreplacementr  Z	log_inputr!  r!  r"  r   /  s&      
zaten::baddbmmc           
      C  s\   t j|}t| ||}t| || jd|| d}t| || jd|| d}	t| ||	S r  )r   rd  re  r   r   r&  ri  r   )
r)  r:  Zbatch1Zbatch2r  rH  rk  Z	batch_mulZmul_aZmul_br!  r!  r"  r%   I  s    zaten::meshgridzOptional[str])r)  indexingc                   s0  |d krd}n|dkr(t d| |t|}|dkrP|dd d |d d<  fdd	|D } fd
d	|D } jd|ddi}g }t|D ]h\}}	 jdtjdtjddgt	| }
|| |
|< t
 |	 jd|
ddi}| d|| q|dkr"|d |d  |d< |d<  jd| S )Nij>   r  xyzUnsupported indexing: r  r@  r  r  c                   s,   g | ]$}t  | jd tdgdqS )rU  r  rV  )r   r8  r&  rW  r  r  r+  r!  r"  r  h  s     zmeshgrid.<locals>.<listcomp>c                   s   g | ]}  d |qS )r-  r.  r  r+  r!  r"  r  n  s     r2  r3  r   rU  rm  rV  r  prim::ListConstruct)r2  )r2  )r  )r   r]  r   r  r&  rG  rW  r   rr  rN  r7  r  )r)  r  r  Zunpacked_tensor_listr  Ztensors_shapeZ	out_shaperj  r~  r   r  Z
t_reshapedr!  r+  r"  r   [  s4     


zaten::remainderc                 C  s(   t | ||}| d||}| d||S )NrA  rI  )r[  r&  )r)  r0  r=  rF   Zquor!  r!  r"  r   }  s    z
aten::gelu)r)  r:  approximatec                 C  s&  |dkrt dt j }d}tj|tjd}tj|tjd}tjdtjd}tjdtjd}t| |t| ||}	t| |t| |t| ||	}
t| |t| |t| || d|
S d}| d	| d
|tj|tjd}t| || jdtjdtjdd}t| t| ||| jdtjdtjddS d S )Nr   r  gHm?rm  rS        ?r  g;f?r  r^  rU  r@  rV  )	r  r   r  rW  r   ry  r   r   r&  )r)  r:  r  ZkBetaZkKappar  kapparu  ZhalfZ	self_cubeinnerZ_sqrt2rO   Zerf_plusoner!  r!  r"  r^     s,    $"  
zaten::group_normc              
   C  s  t  r | jd||||||dS t |d}|d k	rD|| dksDtt |}|d krdt dd|S d|dg}	t | || jdt	
|	d}
| jdt	jd	g| tj| d
d}| jdt	jdg| tj| d
d}| jd|
|||d}t | || d|}|d ks"|  rLt	jd	gtj| d
}| jd|d}|d ksd|  rt	jdgtj| d
}| jd|d}ttd|d }t| t| |t | ||t | ||S )Nra   )Znum_groups_ir  Zcudnn_enabled_ir@  r   zunknown input rankr  rU  rV  rS  rm  r  r  r  r-  )r   r  r  r  rz  r  r  r8  r&  rW  r  r   r   rd  re  rn  r  
mustBeNoner5  r  r   r   r  )r)  r0  Z
num_groupsr  r  r  r  r  Z
input_rankr6  r  r  r  Znorm_reshapedr   r  r  r  r!  r!  r"  ra     s|    


        zaten::_weight_normc                 C  s   t |}|d k	rttt|}|d k	rH|dk r6||7 }|dkrH|| t| |d|d}| d||}| d||S t  r| jd|||dS t	
d|d S )	Nr  r  r@  r^  rA  _weight_normr  zDUnsupported: ONNX export of _weight_norm for tensor of unknown rank.)r   r  r5  r  remover   r&  r  r  r   r]  )r)  rz  Zweight_grE   r  r  Znorm_vrF   r!  r!  r"  r    s"    

r  z	aten::dimc                 C  s   |  d|}|  d|S )zFImplement the dim functionality available for a pytorch tensor in ONNXr-  Sizer.  r9  r!  r!  r"  rE     s    zaten::__contains_c                 C  sd   t |}tdd |D rTt |rT| jdtt | ddd |D kdS t	
d|d S )Nc                 s  s   | ]}t |V  qd S r  )r   r  r^  r!  r!  r"  r    s    z__contains_.<locals>.<genexpr>rU  rX  c                 s  s   | ]}t | d V  qdS )rX  N)r   r   r  r^  r!  r!  r"  r    s     rV  zJUnsupported: ONNX export of __contains__ for non-constant list or element.)r   r  r  r  r&  rW  r   r   r  r   r]  )r)  r:  elementZunpacked_listr!  r!  r"  __contains_  s$    
r  zaten::__getitem_c                 C  s    t | || jdtdgd|S r  )r   r&  rW  r   )r)  r:  r~  r!  r!  r"  
__getitem_&  s    r  z
aten::itemc                 C  s   |S r  r!  r}  r!  r!  r"  rr   ,  s    z
aten::takec              
   C  sD   t | || jdtjdgtjdd}t| |d|}t| ||}|S )NrU  r  rm  rV  r   )r   r8  r&  rW  r   rr  rl   r   )r)  r:  rm   Zself_flattenedrj  r!  r!  r"  r   2  s      c                 C  s&   t | ||}t| |}t| ||}|S r  )r   rP   r   )r)  r0  targetdiff_Zexp_r[  r!  r!  r"  _kl_div_log_target_impl=  s    
r  c           	      C  sZ   t | |}t| ||}t| ||}t| |}t| || jdtdd}t| |||}|S r  )	r   r   r   r  rb   r&  rW  r   r  )	r)  r0  r  Zlog_r  Z
output_posZzeros_Zmask_r[  r!  r!  r"  _kl_div_non_log_target_implE  s    

r  zaten::kl_divc                 C  sj   |rt | ||}nt| ||}|dkr*|S |dkrB| jd|ddS |dkrZtj| |ddS td|S d S )Nr   r@  r  r'  r  z4kl_div with reduction other than none, mean, or sum.)r  r  r&  r   r)  r  )r)  r0  r  	reductionZ
log_targetr[  r!  r!  r"  rs   P  s     zaten::mse_lossc                 C  sh   t | t| ||t| ||}|dkr(|S |dkr@| jd|ddS |dkrXtj| |ddS td|S d S )Nr   r@  r  r  r  z6mse_loss with reduction other than none, mean, or sum.)r   r   r&  r   r)  r  )r)  r0  r  r  r[  r!  r!  r"  r   e  s     zaten::as_stridedc                 C  s  t |d}t|}t | || jdtjdgtjdd}t |stjdgtj	d}t
t||D ]6\}\}	}
dg| }d||< |t|	||
  }qd|r|| }| d|| jd|dS d }t
|D ]\}}
dg| }d||< t| || jdtdgd| jdt|d}	t | t| |	d	d d d | jdt|d}| d
|| jdt|
gd}|d krr|}q| d||}q|r| d|| dt|g}| d||S d S )Nr  rU  r  rm  rV  r   r@  r  r  rA  r>  )r   r  rN  r8  r&  rW  r   rr  rB  r  rG  rF  r   r  r   )r)  r:  r8  stridesoffsetr  Zself_1dindr~  r   r2  Zr_sizeZtmp_indr!  r!  r"  r    v  sT      


  
zaten::__derive_indexc              	   C  s   |  d||  d||S )Nr>  rA  r.  )r)  rm   rq  r  r!  r!  r"  __derive_index  s    r  zaten::__range_lengthc                 C  s6   |  d||}|  dt| ||}| j d|tjjdS )NrI  r  r_  r`  )r&  r  ra  rb  rc  )r)  lor  r  r   rF   r!  r!  r"  __range_length  s    r  zaten::linearc                 C  s   t |}t| |}|dkrp|  sp| jdtjdtjdd}| jdtjdtjdd}t	| |||||}n$t
| ||}|  st| ||}|S )Nr  rU  r@  rm  rV  )r   r  r   r  r  r&  rW  r   rr  r   r   r   )r)  r0  r  r  r  rH  r  r[  r!  r!  r"  r}     s    

zaten::hann_windowzOptional[int])r)  rn  c              	   C  s   |d kr.t  }|r|js t j}tj|}	n
t|}	t| |dd d d }
| jd|
t	j
jd}t| | jdt jtjt jdd|}|dkrt| || jdt jdt jdd}t| ||}| jdt| t| ||	 d}|S )	Nr  r_  r`  rU  rm  rV  Fr@  )rW  rw  ro   rx  r   rd  Z
from_dtyper   r&  ra  rb  rh  r   r   r  r  r   r  rF   r   r   ri  )r)  Zwindow_lengthZperiodicrn  r  r  r   r%  Zdtype_rk  Zn_arrayr[  r!  r!  r"  rc     s4    

    zaten::mvc                 C  s   t | ||S r  r   )r)  r:  Zvecr!  r!  r"  r     s    z	aten::dotc                 C  s   t | ||S r  r  rL  r!  r!  r"  rG     s    zaten::movedimc           
      C  s   | d}| d}| | ks(t||k r8|S t|}|d k	sNttt|}| }| }t	|
 |
 D ] \}}	|||	< d||< d||	< q|dd |D }dd |D }t	||D ]\}}	|||	< q| jd||dS )Nr  c                 S  s   g | ]}|d kr|qS r  r!  r  r!  r!  r"  r    s      zmovedim.<locals>.<listcomp>c                 S  s   g | ]}|d kr|qS r  r!  r  r!  r!  r"  r    s      r  r  )r  r   rz  r  r   r  r5  r  r  rF  tolistr&  )
r)  r:  r  Zdestinationr  r  Zsrc_dimsZdst_dimsr  dstr!  r!  r"  r     s&    




z
aten::fillc                 C  s    t j|t jj}t| |||S r  )r   rd  re  rh  rZ   )r)  r:  rX  rk  r!  r!  r"  rT   %  s
     zaten::index_addc                   s  t d |r0tt|dkr0tdd|S t d  d krPtd|t	|}t	|}|d kst|d krtd|||kr|| }t
|D ]}	t| |t	|g}qt| }
t| }|
d k	r|d k	r|
|krtd|tt
|}d	d
 t
|D } fdd
t
|D }tj| ||||d}t| ||}t
 D ]}	t| |dg}qLt
|  d D ]}	t| |t	|g}qtt| | t| |||S )NzyWarning: ONNX export does not support duplicated values in 'index' field, this will cause the ONNX model to be incorrect.r@  rh   z
alpha != 1r~  zXONNX export does NOT support exporting 'index_add_()' function with unknown 'dim' value.z~ONNX export does NOT support exporting 'index_add_()' function while the rank of self tensor or tensor to be added is unknown.zoONNX export does not support exporting 'index_add_()' function with duplicated values in 'index' parameter yet.c                 S  s   g | ]}d qS rY  r!  r-  r!  r!  r"  r  c  s     zindex_add.<locals>.<listcomp>c                   s   g | ]}| krt jnd qS r3  )sysmaxsizer-  rg  r!  r"  r  d  s     r  r   )r  r  r   rE  rF  r  r  r   r]  r  r  r  r  r5  r  rQ   r   )r)  r:  rE   rm   r=  rH  Zself_dim_rankZother_dim_rankdeltar~  Zother_dim_sizeZself_dim_sizeZnew_shape_axesZnew_shape_startsZnew_shape_endsr  r!  rg  r"  rh   /  sl    

  
      
z
aten::rollc                 C  s   t |t |kst|}tt |D ]~}g }tj| ||| g||  gtjgd}|| tj| ||| gdg||  gd}|| | jd|d|| i}q$|S )Nr  r   r2  r3  )r2  )	rN  rz  r  r   r  r  r  r  r&  )r)  r:  Zshiftsr  r  r~  r  r6  r!  r!  r"  r   v  s,       
 
    

zaten::crossc                 C  sp   t ||}t| |dg|g}t| |dg|g}t| |dg|g}t| |dg|g}t| t| ||t| ||S )Nr  r@  )r   Z_get_dim_for_crossr   r   r   )r)  r0  r=  rE   Zroll_x_1Zroll_y_1Zroll_x_2Zroll_y_2r!  r!  r"  rB     s    zaten::cdistr  #use_mm_for_euclid_dist_if_necessaryc                 C  sR   t |}|d k	stt | ||d g}t | ||d g}t| |||dddS )Nr@  r  gư>F)r  r  )r   r  rz  r  r   )r)  r  r  r  Zcompute_moder  Zbroadcasted_x1Zbroadcasted_x2r!  r!  r"  r/     s    
     z
aten::lerpc                 C  sx   |  d||}t| |  d|| j dtdd|  d||  d|||  d||  d||  d| j dtdd|S )	NrI  r  rU  r  rV  r>  rA  rS  )r&  r  rW  r   )r)  r:  rr  r  diffr!  r!  r"  rw     s    zaten::broadcast_tensorsc                   sP   t |}t |d |D ]}t |q fdd|D } jd| S )Nr   c                   s   g | ]}t  |qS r!  )rQ   r  r)  Zt_with_final_shaper!  r"  r    s     z%broadcast_tensors.<locals>.<listcomp>r  )r  )r   r  r  r   r&  )r)  r:  Zall_tensorsr   Zt_listr!  r  r"  r+     s    
zaten::is_pinnedc                 C  s   d S r  r!  )r)  r:  r  r!  r!  r"  rp     s    prim::ConstantSplitc                 C  s^   t ||}|d kr"t dd|S |g||  }|| }|rF|| | jd|||t|dS )Nr  r  r  r  )r   r  r  r  r&  rN  )r)  r:  r  rE   r   r  r  r!  r!  r"  r     s      
prim::ConstantChunkc                 C  s@   t ||}|d kr"t dd|S || d | }t| |||S )Nr  r  r@  )r   r  r  r   )r)  r:  r  rE   r  r  r!  r!  r"  r     s      zprim::shapec                 C  s   |  d|S r,  r.  r}  r!  r!  r"  r     s    z	prim::maxc                 C  s   t j| d||ddS )Nr  r  r  r  rL  r!  r!  r"  r     s        z	prim::minc                 C  sB   |s6t |r,t| || jdtdgd}t| |S t| ||S r  )r   r  r   r&  rW  r   r   rL  r!  r!  r"  r     s
    

z
prim::datac                 C  s   |S r  r!  r}  r!  r!  r"  r     s    zprim::layoutc                 C  s   | j dtddS r  r  r}  r!  r!  r"  r      s    r  c                 O  s   d S r  r!  r)  ru  rH  r!  r!  r"  r   '  s    zprim::ListUnpackzOptional[List[_C.Value]])r)  r  c                 O  s2   t |dkr.|d   dkr.t|d S d S )Nr@  r   r  )rN  r  r  r   r  r  r!  r!  r"  r   -  s     zprim::TupleConstructc                 O  s   d S r  r!  r  r!  r!  r"  r   :  s    zprim::Uninitializedc                 O  s   d S r  r!  r  r!  r!  r"  r   @  s    zprim::unchecked_castc                 C  s   |S r  r!  r}  r!  r!  r"  r   J  s    zprim::dtypec                 C  s.   t |}|d krtjj}| jdt|dS rT  )r   r  r   rd  rh  r&  rW  r   )r)  r:  rk  r!  r!  r"  r   P  s    
prim::tolistc                 C  s&   t |d}|dkr"t dd|S |S )ztolist is currently supported only for 1D input tensors.

    dim_val and elem_ty_val represent dimension and type annotations
    that need to match dimension and type of the input tensor.
    r~  r@  r  zdim_val > 1)r   r  r  )r)  r0  Zdim_valZelem_ty_valrE   r!  r!  r"  r   Z  s    r>  Nonec                 O  s>   | j   }t|tjrd S tdd|  d| j  S )Nr>  z,output type should be 'DeviceObjType', not '')	original_noder[  re  r4  r	   r  r   r  r  )r)  ru  rH  output_typer!  r!  r"  r   k  s    z
prim::LoopzList[_C.Value]c              	   O  s(  | j }| j}| j}| j}tj}tj}t| }	t	j
| df|| t|	d\}
}}t|	|D ]\}}t| D ]l\}}|dkr|t|k r|||   |dkrv|d t|k rvt| tjsv|||d    qvtj||j|||d qbtj||}tjr$tj||| |S )NZLoopr  Zn_blocksr   r@  F)r  envvalues_in_envparams_dictr   r4  r  rE  blocksr   add_op_with_blocksoutputsSizerN  rF  rG  ru  r'  re  r4  r	   r(  rW  _jit_pass_onnx_blockblock%_jit_pass_fixup_onnx_controlflow_nodeonnx_shape_inference(_jit_pass_onnx_node_shape_type_inference)r)  ru  attrsr  r  r  r  r4  opset_version
old_blocksnew_op_outputsnew_block_contextsnew_node	old_blocknew_block_contextr~  Zb_infixed_outputsr!  r!  r"  r   y  s\         zprim::Ifc              	   O  s  | j }| j}| j}| j}| j}tj}tj}	|d  	 dk}
|
rt
|d  d }t|trnt|nt|}|r~dnd}t| | }tj|||||d}t| }t| }g }tt|D ]B}|| |krtd||  d|| |||  }|| q|S t| }tj| df|| t|d	\}}}t||D ]"\}}tj||j|||d
 qVtj ||	}tj!rtj"|||	 |S d S )Nr   r  rX  r@  TzThe sub block ATen output z is not in env.Ifr  F)#r  r  r  r  r  r   r4  r  r  r  r   r   r  r4  r5  r  r  r  rW  r	   r  r  r  rN  r   r]  r  rE  r   r  r  rF  r   r  r  )r)  ru  r  r*  r  r  r  r  r4  r  Z	static_ifZ
input_flagr,  Z	block_idxZ	current_bZif_output_listZcurrent_b_listZfinal_b_listro  Zonnx_br  r  r  r  r	  r
  r  r!  r!  r"  r     s         r%  c                   s*   j }| rd S t|  tjr*d S |ddkrN jdt	
|ddS |ddkrr jdt	
|ddS |  tj s|  tj r jdtt	
|ddS |  tj r  fddt	
|dD } jd| S td
|d dtj d| d S )NrX  r   rU  rV  rY  Zvalue_sc                   s   g | ]} j d |dqS )rU  r  r.  )r  rY  r+  r!  r"  r  "  s   z!prim_constant.<locals>.<listcomp>r  z"Unsupported prim::Constant kind: 'z'. Please send a bug report at .)r  )r  r  r4  r[  re  r	   r  r  r&  r   r   r@  rA  rB  ZofFloatsrW  r   Z	ofStringsr   r]  r
   ZPYTORCH_GITHUB_ISSUES_URL)r)  ru  r  r  Zstr_constantsr!  r+  r"  r     s6     


prim::type)r)  device_valuec                 O  sJ   |   dkr<t|   }|d k	r<| jdt|dS tdd|S )Nr>  rU  r  r  z,Device type cannot be statically determined.)	r  r  r   Zget_device_from_valuer0  r&  r  r   r  )r)  r  rP  rH  r  r!  r!  r"  r   /  s    zonnx::Placeholderc                 O  s*   | j }| j}| j}| j}tj||||S r  )r  r  r  r  rW  r	   Z'_jit_onnx_convert_pattern_from_subblock)r)  ru  r  r  r  r  r  r!  r!  r"  r   >  s       zaten::resolve_conjzaten::resolve_negr/  c                 C  s   |S r  r!  r/  r!  r!  r"  r
  K  s    	zaten::_conjzaten::conj_physicalc                 C  s    t |rt d|S t| |S )Nz aten::_conj, aten::conj_physical)r   Zis_complex_valuer  r
  r/  r!  r!  r"  r	  W  s    	
zaten::logit)r)  r:  r  c                 C  s   | j dtdd}t|s| j d|tj| d}|  d||}|  d||}|  d|||}|  d	||}|  d|||}n|}|  d||}	|  d
||	}
|  d|
S )NrU  rS  rV  r_  r`  rI  r  r  r  r^  r  )	r&  rW  r   r   r  r   rd  re  ri  )r)  r:  r  ru  Zone_sub_epsZself_less_equal_one_sub_epsZtemporary_selfZtemporary_self_less_epszr   rF   r!  r!  r"  r   k  s     
  )N)N)N)rS  )T)N)N)N)N)N)N)r   N)N)F)N)N)NNN)N)N)FF)NN)NN)N)FN)NNNFN)F)NNF)NN)F)NNNFN)F)F)NNNFN)F)F)NNNFN)F)N)N)NN)NN)NNFN)NNFN)NNN)N)F)r  )NF)FN)N)r  )N)TNNNNF)N)N)r  r  )N)N(r  __doc__
__future__r   rm  r  r  r  r  typingr   r   r   r   r   r   rW  Ztorch._C._onnxr	   Z_onnxra  Ztorch.nn.modules.utilsZ
torch.onnxr
   r   r   r   r   Ztorch.onnx._globalsr   Ztorch.onnx._internalr   r   r   Ztorch.typesr   r  partialZonnx_symbolicZ_onnx_symbolicr$  rQ  r  r1  r7  rO  r   r   r   r   r   r   rF   rP  r   rO  r\  r[  rV   rX   r  r   r.   r   r  r   r*   r   r   r   r   r   r   r   r@   r   r"   r   r#   r$   r   r   r  Z_apply_paramsr  rC   r  r  r   r   rR   r,   rQ   rK   rJ   r   r   r   r  r  r  r   r  r   r  r  r   r   r   r   r   r   r   r   r0   rW   r  r   rv   r`   r   r   r_   nnmodulesutilsZ_singleZ_pairZ_triplerR  r   r   r   r]  rf  rj  rl  r5   ru  r   r   r   r  r  r(   r)   r  r  r  rN   r   rb   r  r   r  r]   ru   r  r  r  r   r   r   r   r  r  r  r   r  r  r  r?   r;   r<   r=   r8   r9   r:   r&   r   rt   rn   r  rI   r   rl   rk   rj   ri   r-   r  rA   r   r4   r   r   r   r   r   r3   r2   r1   r   r   r   r   r   r   r   rP   rH   r  r   r7   r  r  
deprecatedr  r  r  r  r  r  r  r  r  rM   rL   r   r   r   r!   r  r  r   r  r   r   r   r[   rZ   r   rS   r   rg   rf   re   r   rd   r   r   r  r   r   r   r>   r   r   r   r   r   rw  r|  r~  r   r   r  r  rD   r6   r  r  r   r   r   r   r   r   r   r'   r   rO   rU   r   r   rq   r  r  r   r   r   r   r   r   ro   r  r  r   r\   r  r   r  r  r   r   r   r~   rx   r   r   rm   r{   r|   rz   ry   rY   r   r%   r   r   r^   ra   r  rE   r  r  rr   r   r  r  rs   r   r    r  r  r}   rc   r   rG   r   rT   rh   r   rB   r/   rw   r+   rp   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r
  r	  r   r!  r!  r!  r"  <module>   sb       


&5
    *	






*!


>

	
*
	>5 
  	 
  	 
  	;



8  
	  
	  
	


7 )		(,

             C	7*"#"#"#&&&*74J8**_$


  #


	$     	         
    	    	H
(



K  gFC  f** C*      




(	(	 Q
 *0!00;    $(E	

",	"     "& E  
4Z""	 