U
    Mht                    @   s!  U d dl mZ d dlZd dlZd dlmZ d dlmZmZ d dl	m
Z
 d dlmZ d dlZd dlZd dlmZ d dlm  mZ d dlmZ d dlmZmZmZmZmZmZmZmZ d d	lm Z m!Z! d d
l"m#Z#m$Z$ d dl%m&Z& d dl'm(Z( d dl)Zd dl*m+Z+m,Z,m-Z-m.Z.m/Z/m0Z0m1Z1 ej2Z2dZ3dd Z4dd Z5e6dddddd ddej7de6ddddddd ddej7d 	e6d!d"d#ej7d$e6d!d%d&d'd(d#ej7d)e6d*d+d,d ej7d-e6d.d/d0d1d#d#ej7d2gZ8d3d4 Z9d5d6 Z:d7d8 Z;d9d: Z<d;d< Z=d=d> Z>d?d@ Z?dAdB Z@dCdD ZAdEdF ZBdGdH ZCdIdJ ZDdKdL ZEdMdN ZFdOdP ZGdQdR ZHdSdT ZIdUdV ZJdWdX ZKdYdZ ZLd[d\ ZMd]d^ ZNd_d` ZOdadb ZPdcdd ZQdedf ZRdgdh ZSdidj ZTdkdl ZUdmdn ZVdodp ZWdqdr ZXdsdt ZYdudv ZZdwdx Z[dydz Z\d{d| Z]d}d~ Z^dd Z_dd Z`dd Zadd Zbdd Zcdd Zddd Zedd Zfdd Zgdd Zhdd Zidd Zjdd Zkdd Zldd Zme; e< e> e@ eA e= e? eB eC eD eE eF eG eH eI eJ eK eL eM eN eO eP eQ eR eS eU eT eV eX eW eY eZ e[ e\ e] e^ e_ e` ea eb ec ed ee ef eg eh ei ej ek el e6dddddddej7de6ddddddddej7d	e6ddddddddej7d	e6ddddddddej7d	e6ddddddddej7d	e6ddddddddej7d	e6ddddddddde6ddd ddddej7de6ddd dddddej7de6ddd dddddej7de6ddd dddddej7de6ddd dddddej7de6ddd dddddej7de6ddd dddddej7dЍe6ddddddddej7d	e6ddddddddej7d	e6ddd dddddej7dЍe6ddddddddej7dߍ	e6dddd1dddddej7d
e6dddd1dddddej7d
e6dddddddddej7d
e6dddddddddej7d
e6ddddddddde6ddd ddddddej7d	e6ddd dddddej7de6ddd dddddej7de6ddd dddddej7de6dd d dddddej7de6ddddddddej7d	e6dddd	dddddej7d
e6dd
dd	dddddej7d
e6ddd ddddddej7d	e6ddd ddddej7de6ddd ddddej7de6ddd ddddej7de6ddd ddddej7de6ddd dd ddej7de6d!d"d#d$dddd%ej7dߍ	e6d!d&d'd(ddddd%ej7d
e6d!d)d'd(dd*d#dd%ej7d
e6d!d+d,d-ddddd%ej7d
e6d!d.d/d-dd0ddd%ej7d
e6d!d1d2d3ddddd4e6d5d6d d7d$ddddej7d	e6d8d9d d:d-dd%ej7de6d;d<d d=d-dd%ej7de6d>d?d d@dAddd%ej7de6dBdCd dDdAddd%ej7de6dEdFd dGdAddd%ej7de6dHd"dIdd$dd%ej7dJe6dHdKdLdd$ddd%ej7dM	e6dNdOdPdQej7dRe6dNdOdPdSemdTej7dUe6dNdOdPdVd ddWdXe6dYdZd[d\d d#ej7ed]d^e6dYdZd[d_d d#d`ej7ed]dae6dbdZdcddd d#deej7dfe6dbdZdcdgd d#d`ej7dfe6dbdhdidjd d#dkej7dfe6dbdldmdnd d#doej7dfe6dpdqd drdsd d#ej7dte6dudvd dwdxd d#ej7dte6dydzd d{d|d d#ej7dte6d}d~d ddd d#dde6dd ddd dd#dde6ddddej7dRe6ddddej7dRe6e:ejndddddddd#ej7de6e:ejndddddddd#de6e:ejndddddddd#ej7de6e:ejndddddddd#ej7de6e:ejndddd#ddddd#ej7de6e:ejndddd#ddddd#ej7de6e:ejndddd#ddddd#ej7de6e:ejndddd#ddddd#de6e:ejnddddddddd#ej7de6e:ejnddddddddd#ej7de6e:ejndddddddd#ej7de6e:ejndddddddd#ej7de6e:ejndddddddd#ej7de6e:ejndddddddd#ej7de6e:ejndddddddd#de6e:ejndddd#ddddd#ej7de6e:ejndddd#ddddd#de6e:ejndddd#ddddd#ej7de6e:ejndddd#ddÐddd#ej7de6e:ejnddŐdd#ddƐddd#ej7de6e:ejnddȐdd#ddɐddd#ej7de6e:ejnddddddːddd#ej7de6e:ejnddȐdddd͐ddd#ej7de6e:ejndddd#ddАddd#ej7de6e:ejndddd#ddАddd#de6e:ejndddd#ddӐddd#ej7de6e:ejndddd#ddՐddd#ej7de6e:ejnddŐdd#ddאddd#ej7de6e:ejnddȐdd#ddِddd#ej7de6e:ejnddddddېddd#ej7de6e:ejnddȐddddݐddd#ej7de6e:ejndddddߐddd#ej7de6e:ejndddddߐddd#de6e:ejndddddddd#ej7de6e:ejndddddddd#ej7de6e:ejndddd#ddddd#ej7de6e:ejndddd#ddddd#de6e:ejndddd#ddddd#ej7de6e:ejndddd#ddd(ddd#ej7de6e:ejnddddddddd#ej7de6e:ejndddddddddd#ej7de6e:ejodddddd#ej7de6e:ejod ejpddddd#d#ej7de6e:ejod ddddd#ej7de6e:ejod dddd	d#ej7de6e:ejod ejpdddd
d#d#ej7de6e:ejod ddd+dd#d#ej7de6e:ejodddd+dd#d#ej7de6e:ejodddddd#d#de6e:ejqdddddd#ej7de6e:ejqd ddddd#ej7de6e:ejqd ddddd#ej7de6e:ejqd ddd+dd#ej7de6e:ejqdddd+dd#ej7de6e:ejqd ddddd#de6ddd d d!d#dej7d"e6d#d$d d%d&d#dej7d"e6d'd(d d%d)d#emdej7d*e6d+d,d d-d!d#dej7d"e6d.d/d d0d&d#dej7d"e6d1d2d d0d)emd#dej7d3e6d!d%d&dd4d#ej7d)e6d5d6d ej7d7e6d5d8d d9ej7d:e6d5d;d d<ej7d:e6d5d=d>d?d d@ej7dAe6d5dBd emdTej7dCe6dDdEdFdGdHddId#d#ej7dJ
e6dDdKddLdMejrfdNdGd#dOde!rdPnd%ej7dQ	e6dRdSdTdUd d#dHdd%ej7dV	e6dRdKddLdMejrfdWdXd d#dOdd%ej7dV	e6dYdKddddLdMejsfdZd[d d#d\de!rnd%nd]ej7dV	e6dd^d_d`d dad dTddej7db	e6d*dcdddeemdTej7dfe6dgdhdietddgfdeemdTej7dfe6djdkdkdkgdd#fdldmdddddndo	gZuevdpdqdrdsgdtdudvdwgD ]
\ZwZxdD ]Zyeydkrewdpkrqpeze{d eyd  Z|dxdy}e~ee| dz Zd{dey  Zd|ezd}d~ e|D  Zeue6dey ddddde|d d dewf	dey de dex deedew ddd%ej7d
 qpq`ddddddddddd!dddddddddddddgZdej7dddid#d#ej7dd#d#ej7dd#ej7ddej7idej7idej7idej7idej7idej7idej7idej7idej7idej7idej7idej7idej7idej7idej7idej7idej7idej7idej7idZe+ee6f ed< eD ]>Ze6edemdTd#dZeei Zee eue qdddZdddZdddZdddZdddZdddZdddZd ddZdd Zd!ddZd"ddZd#ddZdd Zd$dÐdĄZd%dŐdƄZd&dȐdɄZd'dʐd˄Zd(d̐d̈́Zeeedd΍eeeeeeeeeeeeedϜZe+de,f ed< g ZdҐdӄ ZdԐdՐd֐dאdgZdِdedkgZe
eeD ]L\ZZe6e de edۜdddded#ej7dݍZee qeD ]DZe6de dd dd dd ed#ej7dZee qddd dd fddd dd fddd dd fddd dd fddd dd fddd dd fddd dd fddd dd fddd dd fddd dd fg
Zdd d#iiZe+ee6f ed< d#d#d#d#d#dZdِdedkgZe
eeD ]~\\ZZZZe6e de edۜddefddefddedeeddZeei Zee ee  qG dd deZG d	d
 d
ZG dd deZG dd dZG dd deeZG dd deeZd)ddZd*ddZdd ZdS (+      )abstractmethodN)deepcopy)reducepartial)product)mul)
_reduction)TestCaseto_gpufreeze_rng_stateis_iterable	gradcheckgradgradcheckset_default_dtypeskipIfTorchDynamo)	TEST_CUDASM90OrLater)_get_numerical_jacobian_iter_tensors)Variable)_TensorOrTensors)DictCallableTupleListSequenceUnionAnyh㈵>c                 C   s<   t | dd }|d kr,tjt | dd ddd}|d k	s8t|S )N	reductionZsizeAverageTF)Zemit_warning)getattr
_ReductionZlegacy_get_stringAssertionErrormresult r&   S/var/www/html/venv/lib/python3.8/site-packages/torch/testing/_internal/common_nn.pyget_reduction    s
    r(   c                 C   s$   t | dd }|d k	r|S t | dd S )Nweightweights)r    r#   r&   r&   r'   
get_weight(   s    r+   ZLinear)
      ztorch::nn::LinearOptions(10, 8))   r,   c                 C   s,   t | |d  |d dddd S )Nr      r.   r-   )torchmmtviewexpandip_r&   r&   r'   <lambda>k       r:   T{Gzt?)module_nameconstructor_argscpp_constructor_args
input_sizereference_fn	with_tf32tf32_precisiondefault_dtype)r,   r-   Fz+torch::nn::LinearOptions(10, 8).bias(false)Zno_biasc                 C   s   t | |d  S )Nr   )r1   r2   r3   r6   r&   r&   r'   r:   v   r;   )	r=   r>   r?   r@   descrA   rB   rC   rD   RReLU)r/      rG   F)r=   r@   	test_cudarD   )皙?g?z/torch::nn::RReLUOptions().lower(0.1).upper(0.9))r.   r.      Zwith_up_down)r=   r>   r?   r@   rE   rH   rD   ZFlattenrG      r.   rJ   c                 G   s   t | dS Nr/   )r1   flattenr7   r9   r&   r&   r'   r:      r;   )r=   r@   rA   rD   ZCrossMapLRN2d)rJ   r<   MbP?rG   z>torch::nn::CrossMapLRN2dOptions(5).alpha(5e-3).beta(1e-3).k(2))rG   rL      rQ   )r=   r>   r?   r@   check_gradgradcheck_batched_gradrD   c                  G   s    t t| d}t|j|   S rM   )r   r   r1   randpermr4   double)sizetotalr&   r&   r'   _rand_tensor_non_equal   s    rX   c                    s   G  fdddt j}|S )Nc                       s   e Zd Z fddZdS )z)wrap_functional.<locals>.FunctionalModulec                    s
    |S Nr&   )selfargsfnkwargsr&   r'   forward   s    z1wrap_functional.<locals>.FunctionalModule.forwardN)__name__
__module____qualname__r_   r&   r\   r&   r'   FunctionalModule   s   rc   )nnModule)r]   r^   rc   r&   r\   r'   wrap_functional   s    rf   c                
      sD   t dd tdt fddddd d d fd	dd
t jdS )Nr,   ZPoissonNLLLoss_no_reducec                    s   t j|  | ddS Nnoner   )FZpoisson_nll_losstype_asr7   r3   r&   r'   r:      r;   z/poissonnllloss_no_reduce_test.<locals>.<lambda>zaF::poisson_nll_loss(i, t.to(i.options()), F::PoissonNLLLossFuncOptions().reduction(torch::kNone))c                   S   s   t ddS Nr,   r1   randr&   r&   r&   r'   r:      r;   _get_input()r7   r3   c                    s   |    |  S rY   )expr   rO   rm   r&   r'   r:      r;   Ffullnameconstructorcpp_function_callinput_fncpp_var_maprA   picklerD   r1   randndictrf   rU   r&   r&   rm   r'   poissonnllloss_no_reduce_test   s    

r~   c                      sX   t tdddtj tdt fddddd d	 d
 fddddtjd	S )N   r,   r   ZBCELoss_no_reducec                    s   t j|  | ddS rg   rj   Zbinary_cross_entropyrk   rl   rm   r&   r'   r:      r;   z(bceloss_no_reduce_test.<locals>.<lambda>iF::binary_cross_entropy(i, t.to(i.options()), F::BinaryCrossEntropyFuncOptions().reduction(torch::kNone))c                   S   s   t ddddS Nr   r,   y&1?v?r1   rp   clamp_r&   r&   r&   r'   r:      r;   rq   rr   c                    s"    |    d  d|       S rM   logrO   rm   r&   r'   r:      r;   FgǺF?	ru   rv   rw   rx   ry   rA   rz   	precisionrD   )r   r1   r|   gttorU   r}   rf   r&   r&   rm   r'   bceloss_no_reduce_test   s    

r   c                
      sP   t ddt j tdt fddddd d d	 fd
ddt jdS )Nr&   r   ZBCELoss_no_reduce_scalarc                    s   t j|  | ddS rg   r   rl   rm   r&   r'   r:      r;   z/bceloss_no_reduce_scalar_test.<locals>.<lambda>r   c                   S   s   t dddS Nr&   r   r   r   r&   r&   r&   r'   r:      r;   rq   rr   c                    s"    |    d  d|       S rM   r   rO   rm   r&   r'   r:      r;   Frt   )r1   r|   r   r   rU   r}   rf   r&   r&   rm   r'   bceloss_no_reduce_scalar_test   s    

r   c                      st   t tjddtjddtj tjdtjdtdt fdddd	d d
 d fddddtjd	S )Nr   r,   dtyper   ZBCELoss_weights_no_reducec                    s   t j|  | | ddS Nrh   r)   r   r   rl   r3   r*   r&   r'   r:      s    z0bceloss_weights_no_reduce_test.<locals>.<lambda>zF::binary_cross_entropy(i, t.to(i.options()), F::BinaryCrossEntropyFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))c                   S   s   t ddddS r   r   r&   r&   r&   r'   r:      r;   rq   r7   r3   r*   c                    s&    |    d  d|        S rM   r   )r7   r8   r$   r   r&   r'   r:      r;   Fa2U0*3?r   )	r   r1   r|   rU   r   r   rp   r}   rf   r&   r&   r   r'   bceloss_weights_no_reduce_test   s    $
r   c                
      sf   t ddt j t jdt jdtdt fdddd d	d
d  fdddt jdS )Nr&   r   r   Z BCELoss_weights_no_reduce_scalarc                    s   t j|  | | ddS r   r   rl   r   r&   r'   r:      s    z7bceloss_weights_no_reduce_scalar_test.<locals>.<lambda>zF::binary_cross_entropy(
            i, t.to(i.options()),
            F::BinaryCrossEntropyFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))rq   r   c                   S   s   t dddS r   r   r&   r&   r&   r'   r:      r;   c                    s&    |    d  d|        S rM   r   rO   r   r&   r'   r:      r;   F)ru   rv   rw   ry   rx   rA   rz   rD   )r1   r|   r   r   rU   rp   r}   rf   r&   r&   r   r'   %bceloss_weights_no_reduce_scalar_test   s    
r   c                      sb   t tdddtjt  tdt	fddddd d	d
 fddddtjd	S )Nr   r,   r   ZBCEWithLogitsLoss_legacy_enumc                    s   t j|  | ddS )NF)r   rj   Z binary_cross_entropy_with_logitsrk   rl   rm   r&   r'   r:     r;   z4bce_with_logistic_legacy_enum_test.<locals>.<lambda>F::binary_cross_entropy_with_logits(
            i, t.to(i.options()), F::BinaryCrossEntropyWithLogitsFuncOptions().reduction(torch::kNone))c                   S   s   t ddddS r   r   r&   r&   r&   r'   r:     r;   rq   rr   c                    s*    |    d d |       S rM   r   rO   sigmoidr3   r&   r'   r:     r;   F	ru   rv   rw   rx   ry   rA   rR   rz   rD   
r   r1   r|   r   r   rU   rd   Sigmoidr}   rf   r&   r&   r   r'   "bce_with_logistic_legacy_enum_test  s    
r   c                      sb   t tdddtjt  tdt	fddddd d	d
 fddddtjd	S )Nr   r,   r   ZBCEWithLogitsLoss_no_reducec                    s   t j|  | ddS rg   r   rl   rm   r&   r'   r:     r;   z2bce_with_logistic_no_reduce_test.<locals>.<lambda>r   c                   S   s   t ddddS r   r   r&   r&   r&   r'   r:     r;   rq   rr   c                    s*    |    d d |       S rM   r   rO   r   r&   r'   r:     r;   Fr   r   r&   r&   r   r'    bce_with_logistic_no_reduce_test  s    
r   c                      s\   t ddt jt  tdtfddddd dd	 fd
dddt jd	S )Nr&   r   Z"BCEWithLogitsLoss_no_reduce_scalarc                    s   t j|  | ddS rg   r   rl   rm   r&   r'   r:   ,  r;   z9bce_with_logistic_no_reduce_scalar_test.<locals>.<lambda>r   c                   S   s   t dddS r   r   r&   r&   r&   r'   r:   /  r;   rq   rr   c                    s*    |    d d |       S rM   r   rO   r   r&   r'   r:   1  r;   Fr   )	r1   r|   r   r   rU   rd   r   r}   rf   r&   r&   r   r'   'bce_with_logistic_no_reduce_scalar_test&  s    
r   c                      sL   t jddt jd tdt fddddd d d	 fd
dddt jd	S )Nr,   r   ZKLDivLoss_with_target_no_reducec                    s   t j|  | ddS rg   rj   Zkl_divrk   rl   rm   r&   r'   r:   =  r;   z6kldivloss_with_target_no_reduce_test.<locals>.<lambda>NF::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone))c                   S   s   t dd S rn   r1   rp   r   r&   r&   r&   r'   r:   ?  r;   rq   rr   c                    s   t d |  | ddS N	KLDivLossrh   ri   loss_reference_fnsrk   rO   rm   r&   r'   r:   A  s    TF	ru   rv   rw   rx   ry   rA   supports_forward_adrz   rD   r1   rp   rU   r}   rf   r&   r&   rm   r'   $kldivloss_with_target_no_reduce_test8  s    

r   c                      sL   t jddt jd tdt fddddd d d	 fd
dddt jd	S )Nr,   r   ZKLDivLoss_no_reducec                    s   t j|  | ddS rg   r   rl   rm   r&   r'   r:   M  r;   z*kldivloss_no_reduce_test.<locals>.<lambda>r   c                   S   s   t dd S rn   r   r&   r&   r&   r'   r:   O  r;   rq   rr   c                    s   t d |  | ddS r   r   rO   rm   r&   r'   r:   Q  s    TFr   r   r&   r&   rm   r'   kldivloss_no_reduce_testH  s    

r   c                      sJ   t jdt jd tdt fddddd d d	 fd
dddt jd	S )Nr&   r   ZKLDivLoss_no_reduce_scalarc                    s   t j|  | ddS rg   r   rl   rm   r&   r'   r:   ^  r;   z1kldivloss_no_reduce_scalar_test.<locals>.<lambda>r   c                   S   s   t d S Nr&   r   r&   r&   r&   r'   r:   `  r;   rq   rr   c                    s   t d |  | ddS r   r   rO   rm   r&   r'   r:   b  s    TFr   r   r&   r&   rm   r'   kldivloss_no_reduce_scalar_testY  s    

r   c                      sP   t jddt jd  tdt fddddd d d	 fd
dddt jd	S )Nr,   r   Z#KLDivLoss_with_log_target_no_reducec                    s   t j|  | dddS Nrh   T)r   
log_targetr   rl   rm   r&   r'   r:   n  r;   z:kldivloss_with_log_target_no_reduce_test.<locals>.<lambda>_F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone).log_target(true))c                   S   s   t dd S rn   r   r&   r&   r&   r'   r:   p  r;   rq   rr   c                    s   t d |  | ddS NKLDivLoss_log_targetrh   ri   r   rO   rm   r&   r'   r:   r  s    TFr   r1   rp   rU   r   r}   rf   r&   r&   rm   r'   (kldivloss_with_log_target_no_reduce_testi  s    

r   c                      sP   t jddt jd  tdt fddddd d d	 fd
dddt jd	S )Nr,   r   ZKLDivLoss_no_reduce_log_targetc                    s   t j|  | dddS r   r   rl   rm   r&   r'   r:   ~  r;   z5kldivloss_no_reduce_log_target_test.<locals>.<lambda>r   c                   S   s   t dd S rn   r   r&   r&   r&   r'   r:     r;   rq   rr   c                    s   t d |  | ddS r   r   rO   rm   r&   r'   r:     s    TFr   r   r&   r&   rm   r'   #kldivloss_no_reduce_log_target_testy  s    

r   c                      sN   t jdt jd  tdt fddddd d d	 fd
dddt jd	S )Nr&   r   Z%KLDivLoss_no_reduce_scalar_log_targetc                    s   t j|  | dddS r   r   rl   rm   r&   r'   r:     r;   z<kldivloss_no_reduce_scalar_log_target_test.<locals>.<lambda>r   c                   S   s   t d S r   r   r&   r&   r&   r'   r:     r;   rq   rr   c                    s   t d |  | ddS r   r   rO   rm   r&   r'   r:     s    TFr   r   r&   r&   rm   r'   *kldivloss_no_reduce_scalar_log_target_test  s    

r   c                      sN   t jdddt jd tdt fdddd	d d
 d fddddt jd	S )NrG   rL   r.   r   ZL1Loss_no_reducec                    s   t j|  | ddS rg   rj   Zl1_lossrk   rl   rm   r&   r'   r:     r;   z'l1loss_no_reduce_test.<locals>.<lambda>PF::l1_loss(i, t.to(i.options()), F::L1LossFuncOptions().reduction(torch::kNone))c                   S   s   t dddS NrG   rL   r.   r1   r|   r&   r&   r&   r'   r:     r;   rq   rr   c                    s   |   |   S rY   rk   absrO   rm   r&   r'   r:     r;   TFr   r1   r|   rU   r}   rf   r&   r&   rm   r'   l1loss_no_reduce_test  s    

r   c                
      sJ   t jdddt jd tdt fdddd	d d
 d fdddddS )NrG   rL   r.   r   ZL1Loss_no_reduce_complexc                    s   t j|  | ddS rg   r   rl   rm   r&   r'   r:     r;   z/l1loss_no_reduce_complex_test.<locals>.<lambda>r   c                   S   s   t jdddt jdS )NrG   rL   r.   r   )r1   r|   cdoubler&   r&   r&   r'   r:     r;   rq   rr   c                    s   |   |   S rY   r   rO   rm   r&   r'   r:     r;   TF)ru   rv   rw   rx   ry   rA   r   rz   )r1   r|   r   r}   rf   r&   r&   rm   r'   l1loss_no_reduce_complex_test  s    

r   c                      sJ   t jdt jd tdt fddddd d d	 fd
dddt jd	S )Nr&   r   ZL1Loss_no_reduce_scalarc                    s   t j|  | ddS rg   r   rl   rm   r&   r'   r:     r;   z.l1loss_no_reduce_scalar_test.<locals>.<lambda>r   c                   S   s
   t dS r   r   r&   r&   r&   r'   r:     r;   rq   rr   c                    s   |   |   S rY   r   rO   rm   r&   r'   r:     r;   TFr   r   r&   r&   rm   r'   l1loss_no_reduce_scalar_test  s    

r   c                     sL   d} t j| dt ji tdt fddd| d d fd	dd
dt jd	S )NrK   r   ZMSELoss_no_reducec                    s   t j|  | ddS rg   rj   Zmse_lossrk   rl   targetr&   r'   r:     r;   z(mseloss_no_reduce_test.<locals>.<lambda>WF::mse_loss(i, target.to(i.options()), F::MSELossFuncOptions().reduction(torch::kNone))rq   r7   r   c                    s   |    dS NrG   powrO   r   r&   r'   r:     r;   TF	ru   rv   rw   r@   ry   rA   r   rz   rD   r   r@   r&   r   r'   mseloss_no_reduce_test  s    

r   c                     sJ   d} t j| t jd tdt fddd| d d fd	dd
dt jd	S )Nr&   r   ZMSELoss_no_reduce_scalarc                    s   t j|  | ddS rg   r   rl   r   r&   r'   r:     r;   z/mseloss_no_reduce_scalar_test.<locals>.<lambda>r   rq   r   c                    s   |    dS r   r   rO   r   r&   r'   r:     r;   TFr   r   r   r&   r   r'   mseloss_no_reduce_scalar_test  s    

r   c                
      sd   t td d  ddi tdt fdddd	d d
d fdddtj	dS )Nr   r,   r   rh   ZNLLLoss_no_reducec                    s   t j| |   d dS Nr   ri   rj   nll_lossrk   longrl   r^   r3   r&   r'   r:     r;   z(nllloss_no_reduce_test.<locals>.<lambda>pF::nll_loss(
            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().reduction(torch::kNone))c                   S   s   t dd S Nr   r,   r   r&   r&   r&   r'   r:     r;   rq   rr   c                    s   t d | |  f S NNLLLossr   rk   r   rO   r   r&   r'   r:     s    Frt   
r   r1   emptyuniform_r   floorr   r}   rf   rU   r&   r&   r   r'   nllloss_no_reduce_test  s     r   c                
      sf   t td d  ddd tdt fddd	d
d dd fdddtj	dS )Nr   r,   rG   rh   ignore_indexr   ZNLLLoss_no_reduce_ignore_indexc                    s,   t j| |  t d t d dS Nr   r   r   rj   r   rk   r   intstrrl   r   r&   r'   r:     s   
z5nllloss_no_reduce_ignore_index_test.<locals>.<lambda>zF::nll_loss(
            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().ignore_index(2).reduction(torch::kNone))c                   S   s   t dd S r   r   r&   r&   r&   r'   r:     r;   rq   rr   c                    s   t d | |  f S r   r   rO   r   r&   r'   r:     s    Frt   r   r&   r&   r   r'   #nllloss_no_reduce_ignore_index_test  s     
r   c                
      st   t td d  tdfdd tdt	 fdddd	d d
d fdddtj
dS )Nr   r,   c                    s     | ddS r   rk   rl   r)   r&   r'   r^     s    z.nllloss_no_reduce_weights_test.<locals>.kwargsZNLLLoss_no_reduce_weightsc                    s   t j| |  f | S rY   r   rl   r   r&   r'   r:     r;   z0nllloss_no_reduce_weights_test.<locals>.<lambda>F::nll_loss(
            i, t.to(i.options()).to(torch::kLong),
            F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone))c                   S   s   t ddd S Nr   r,   {Gz?r1   rp   addr   r&   r&   r&   r'   r:     r;   rq   r7   r3   r)   c                    s    t d | |  f | S r   r   rO   r   r&   r'   r:     s    Frt   r   r1   r   r   r   r   r   rp   r}   rf   rU   r&   r&   r^   r3   r)   r'   nllloss_no_reduce_weights_test	  s     

r   c                
      st   t td d  tdfdd tdt	 fdddd	d d
d fdddtj
dS )Nr   r,   c                    s     | dddS )Nrh   rG   r)   r   r   r   rl   r   r&   r'   r^   #  s    
z;nllloss_no_reduce_weights_ignore_index_test.<locals>.kwargsZ&NLLLoss_no_reduce_weights_ignore_indexc                    s    t j| |  f | jS rY   )rj   r   rk   r   datarl   r   r&   r'   r:   *  r;   z=nllloss_no_reduce_weights_ignore_index_test.<locals>.<lambda>zF::nll_loss(
            i, t.to(i.options()).to(torch::kLong),
            F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone).ignore_index(2))c                   S   s   t ddd S r   r   r&   r&   r&   r'   r:   .  r;   rq   r   c                    s    t d | |  f | S r   r   rO   r   r&   r'   r:   0  s    Frt   r   r&   r&   r   r'   +nllloss_no_reduce_weights_ignore_index_test  s     

r   c                
      s   t td d  tdfdd tdt	 fdddtjddtj
d	d
 dd fdddtj
dS )Nr   r,   c                    s     | dddS )Nrh   r0   r   r   rl   r   r&   r'   r^   :  s    
z?nllloss_no_reduce_weights_ignore_index_neg_test.<locals>.kwargsZ*NLLLoss_no_reduce_weights_ignore_index_negc                    s   t j| |  f | S rY   r   rl   r   r&   r'   r:   A  r;   zAnllloss_no_reduce_weights_ignore_index_neg_test.<locals>.<lambda>zF::nll_loss(
            i, t.to(i.options()).to(torch::kLong),
            F::NLLLossFuncOptions().weight(weight.to(i.options())).reduction(torch::kNone).ignore_index(-1))r   r   rq   r   c                    s    t d | |  f | S r   r   rO   r   r&   r'   r:   G  s    F)ru   rv   rw   inputry   rA   rz   rD   )r   r1   r   r   r   r   r   rp   r}   rf   rU   r   r   r&   r&   r   r'   /nllloss_no_reduce_weights_ignore_index_neg_test6  s     

r   c                
      sd   t tdddd  ddi tdt fddd	d
d dd fdddtjdS )NrG   rJ   rL   r   rh   ZNLLLoss2d_no_reducec                    s   t j| |   d dS r   r   rl   r   r&   r'   r:   S  r;   z*nllloss2d_no_reduce_test.<locals>.<lambda>r   c                   S   s   t dddd S NrG   rL   rJ   r   r&   r&   r&   r'   r:   V  r;   rq   rr   c                    s   t d | |  f S N	NLLLossNdr   rO   r   r&   r'   r:   X  s    Frt   	r   r1   rp   r   r   r   r}   rf   rU   r&   r&   r   r'   nllloss2d_no_reduce_testM  s     r   c                
      sf   t tdddd  ddd tdt fdd	d
dd	 dd fdd	dtjdS )NrG   rJ   rL   r/   rh   r   Z NLLLoss2d_no_reduce_ignore_indexc                    s,   t j| |  t d t d dS r   r   rl   r   r&   r'   r:   d  s   
z7nllloss2d_no_reduce_ignore_index_test.<locals>.<lambda>F::nll_loss(
            i, t.to(i.options()).to(torch::kLong), F::NLLLossFuncOptions().ignore_index(1).reduction(torch::kNone))c                   S   s   t dddd S r   r   r&   r&   r&   r'   r:   h  r;   rq   rr   c                    s   t d | |  f S r   r   rO   r   r&   r'   r:   j  s    Frt   r   r&   r&   r   r'   %nllloss2d_no_reduce_ignore_index_test^  s     
r   c                
      st   t tdddd  tdfdd tdt fddd	d
d dd fdddtjdS )NrG   rJ   rL   c                    s     | ddS r   r   rl   r   r&   r'   r^   t  s    z0nllloss2d_no_reduce_weights_test.<locals>.kwargsZNLLLoss2d_no_reduce_weightsc                    s   t j| |  f | S rY   r   rl   r   r&   r'   r:   z  r;   z2nllloss2d_no_reduce_weights_test.<locals>.<lambda>r   c                   S   s   t dddd S r   r   r&   r&   r&   r'   r:   ~  r;   rq   r   c                    s    t d | |  f | S r   r   rO   r   r&   r'   r:     s    Frt   r   r&   r&   r   r'    nllloss2d_no_reduce_weights_testp  s     

r   c                
      sh   t tdddddd  ddi tdt fddd	d
d dd fdddtjdS )NrG   rJ   rL   r   rh   ZNLLLossNd_no_reducec                    s   t j| |   d dS r   r   rl   r   r&   r'   r:     r;   z*nlllossNd_no_reduce_test.<locals>.<lambda>r   c                   S   s   t dddddd S r   r   r&   r&   r&   r'   r:     r;   rq   rr   c                    s   t d | |  f S r   r   rO   r   r&   r'   r:     s    Frt   r   r&   r&   r   r'   nlllossNd_no_reduce_test  s    $r   c                
      sj   t tdddddd  ddd tdt fdd	d
dd	 dd fdd	dtjdS )NrG   rJ   rL   r/   rh   r   Z NLLLossNd_no_reduce_ignore_indexc                    s,   t j| |  t d t d dS r   r   rl   r   r&   r'   r:     s   
z7nlllossNd_no_reduce_ignore_index_test.<locals>.<lambda>r   c                   S   s   t dddddd S r   r   r&   r&   r&   r'   r:     r;   rq   rr   c                    s   t d | |  f S r   r   rO   r   r&   r'   r:     s    Frt   r   r&   r&   r   r'   %nlllossNd_no_reduce_ignore_index_test  s    $
r   c                
      sx   t tdddddd  tdfdd tdt fddd	d
d dd fdddtjdS )NrG   rJ   rL   c                    s     | ddS r   r   rl   r   r&   r'   r^     s    z0nlllossNd_no_reduce_weights_test.<locals>.kwargsZNLLLossNd_no_reduce_weightsc                    s   t j| |  f | S rY   r   rl   r   r&   r'   r:     r;   z2nlllossNd_no_reduce_weights_test.<locals>.<lambda>r   c                   S   s   t dddddd S r   r   r&   r&   r&   r'   r:     r;   rq   r   c                    s    t d | |  f | S r   r   rO   r   r&   r'   r:     s    Frt   r   r&   r&   r   r'    nlllossNd_no_reduce_weights_test  s    $

r  c                      sN   t jdddt jd tdt fdddd	d d
 d fddddt jd	S )NrG   rL   r.   r   ZSmoothL1Loss_no_reducec                    s   t j|  | ddS rg   rj   Zsmooth_l1_lossrk   rl   rm   r&   r'   r:     r;   z-smoothl1loss_no_reduce_test.<locals>.<lambda>jF::smooth_l1_loss(
            i, t.to(i.options()), F::SmoothL1LossFuncOptions().reduction(torch::kNone))c                   S   s   t dddS r   r   r&   r&   r&   r'   r:     r;   rq   rr   c                    s   t d |  | ddS NSmoothL1Lossrh   ri   r   rO   rm   r&   r'   r:     s    TFr   r   r&   r&   rm   r'   smoothl1loss_no_reduce_test  s    

r  c                      sJ   t jdt jd tdt fddddd d d	 fd
dddt jd	S )Nr&   r   ZSmoothL1Loss_no_reduce_scalarc                    s   t j|  | ddS rg   r  rl   rm   r&   r'   r:     r;   z4smoothl1loss_no_reduce_scalar_test.<locals>.<lambda>r  c                   S   s
   t dS r   r   r&   r&   r&   r'   r:     r;   rq   rr   c                    s   t d |  | ddS r  r   rO   rm   r&   r'   r:     s    TFr   r   r&   r&   rm   r'   "smoothl1loss_no_reduce_scalar_test  s    

r  c                      sN   t jdddt jd tdt fdddd	d d
 d fddddt jd	S )NrG   rL   r.   r   ZSmoothL1Loss_betac                    s   t j|  | dddS )Nrh         ?r   betar  rl   rm   r&   r'   r:     r;   z(smoothl1loss_beta_test.<locals>.<lambda>zoF::smooth_l1_loss(
            i, t.to(i.options()), F::SmoothL1LossFuncOptions().reduction(torch::kNone), 0.5)c                   S   s   t dddS r   r   r&   r&   r&   r'   r:     r;   rq   rr   c                    s   t d |  | dddS )Nr  rh   r  r	  r   rO   rm   r&   r'   r:     s    TFr   r   r&   r&   rm   r'   smoothl1loss_beta_test  s    

r  c                      sN   t jdddt jd tdt fdddd	d d
 d fddddt jd	S )NrG   rL   r.   r   ZSmoothL1Loss_zero_betac                    s   t j|  | dddS )Nrh   r   r	  r  rl   rm   r&   r'   r:     r;   z-smoothl1loss_zero_beta_test.<locals>.<lambda>zmF::smooth_l1_loss(
            i, t.to(i.options()), F::SmoothL1LossFuncOptions().reduction(torch::kNone), 0)c                   S   s   t dddS r   r   r&   r&   r&   r'   r:     r;   rq   rr   c                    s   t d |  | dddS )Nr  rh   r   r	  r   rO   rm   r&   r'   r:     s    TFr   r   r&   r&   rm   r'   smoothl1loss_zero_beta_test  s    

r  c                      sH   t ddd tdt fddddd d	 d
 fddddt jd	S )NrG   rL   r.   ZHuberLoss_deltac                    s   t j|  | dddS )Nrh   r  r   delta)rj   Z
huber_lossrk   rl   rm   r&   r'   r:     r;   z&huberloss_delta_test.<locals>.<lambda>znF::huber_loss(
            i, t.to(i.options()), F::HuberLossFuncOptions().reduction(torch::kNone).delta(0.5))c                   S   s   t dddS r   r   r&   r&   r&   r'   r:     r;   rq   rr   c                    s   t d |  | dddS )N	HuberLossrh   r  r  r   rO   rm   r&   r'   r:     s    TFr   r{   r&   r&   rm   r'   huberloss_delta_test  s    

r  c                      sF   t d  tdt fddddd d d fd	dd
ddd	S )Nr&   Z!MultiLabelMarginLoss_0d_no_reducec                    s   t j|  |  ddS rg   rj   Zmultilabel_margin_lossrk   r   rl   rm   r&   r'   r:     r;   z8multilabelmarginloss_0d_no_reduce_test.<locals>.<lambda>F::multilabel_margin_loss(
            i, t.to(i.options()).to(torch::kLong), F::MultilabelMarginLossFuncOptions().reduction(torch::kNone))c                   S   s
   t dS r   r   r&   r&   r&   r'   r:     r;   rq   rr   c                    s   t d |  j|  ddS NMultiLabelMarginLossrh   ri   r   r   rk   r   rO   rm   r&   r'   r:     s    TF)	ru   rv   rw   rx   ry   rA   check_sum_reductionrR   rz   )r1   zerosr   r}   rf   r&   r&   rm   r'   &multilabelmarginloss_0d_no_reduce_test  s    

r  c                      sX   t tdd   tdt fddddd d d fd	dd
ddtjd
S )Nr,   Z!MultiLabelMarginLoss_1d_no_reducec                    s   t j|  |  ddS rg   r  rl   rm   r&   r'   r:   *  r;   z8multilabelmarginloss_1d_no_reduce_test.<locals>.<lambda>r  c                   S   s
   t dS rn   r   r&   r&   r&   r'   r:   -  r;   rq   rr   c                    s   t d |  j|  ddS r  r  rO   rm   r&   r'   r:   /  s    TF
ru   rv   rw   rx   ry   rA   r  rR   rz   rD   r   r&   r&   rm   r'   &multilabelmarginloss_1d_no_reduce_test%  s    

r  c                      sj   t tjtdddd  dd tdt	 fdd	d
dd	 d d fdd	dddtj
d
S )NrJ   r,   g         r0   minZMultiLabelMarginLoss_index_negc                    s   t j|  |  ddS rg   r  rl   rm   r&   r'   r:   <  r;   z5multilabelmarginloss_index_neg_test.<locals>.<lambda>r  c                   S   s   t ddS NrJ   r,   r   r&   r&   r&   r'   r:   ?  r;   rq   rr   c                    s   t d |  j|  ddS r  r  rO   rm   r&   r'   r:   A  s    TFr  )r   r1   clamprp   r   r   r   r   r}   rf   rU   r&   r&   rm   r'   #multilabelmarginloss_index_neg_test7  s    .

r   c                      sZ   t tddd   tdt fddddd d d	 fd
ddddtjd
S )NrJ   r,   ZMultiLabelMarginLoss_no_reducec                    s   t j|  |  ddS rg   r  rl   rm   r&   r'   r:   N  r;   z5multilabelmarginloss_no_reduce_test.<locals>.<lambda>r  c                   S   s   t ddS r  r   r&   r&   r&   r'   r:   Q  r;   rq   rr   c                    s   t d |  j|  ddS r  r  rO   rm   r&   r'   r:   S  s    TFr  r   r&   r&   rm   r'   #multilabelmarginloss_no_reduce_testI  s    

r!  c                      sb   t tddtjdd tdt	 fdddd	d d
 d fddddtjd	S )Nr,   r   rG   r/   ZHingeEmbeddingLoss_no_reducec                    s   t j|  | ddS rg   rj   Zhinge_embedding_lossrk   rl   rm   r&   r'   r:   `  r;   z3hingeembeddingloss_no_reduce_test.<locals>.<lambda>zvF::hinge_embedding_loss(
            i, t.to(i.options()), F::HingeEmbeddingLossFuncOptions().reduction(torch::kNone))c                   S   s
   t dS rn   r   r&   r&   r&   r'   r:   c  r;   rq   rr   c                    s   t d |  | ddS )NHingeEmbeddingLossrh   ri   r   rO   rm   r&   r'   r:   e  s    TF	ru   rv   rw   rx   ry   rA   r  rz   rD   
r   r1   r|   r   r   rU   Zmul_subr}   rf   r&   r&   rm   r'   !hingeembeddingloss_no_reduce_test[  s    (

r'  c                      sb   t tddtjdd tdt	 fdddd	d d
 d fddddtjd	S )Nr,   r   rG   r/   Z#HingeEmbeddingLoss_margin_no_reducec                    s   t j|  | dddS Nr  rh   marginr   r"  rl   rm   r&   r'   r:   q  r;   z:hingeembeddingloss_margin_no_reduce_test.<locals>.<lambda>zF::hinge_embedding_loss(
            i, t.to(i.options()), F::HingeEmbeddingLossFuncOptions().margin(0.5).reduction(torch::kNone))c                   S   s
   t dS rn   r   r&   r&   r&   r'   r:   t  r;   rq   rr   c                    s   t d |  | dddS )Nr#  r  rh   r)  r   rO   rm   r&   r'   r:   v  s    TFr$  r%  r&   r&   rm   r'   (hingeembeddingloss_margin_no_reduce_testl  s    (

r+  c                      sL   t jddt jd tdt fddddd d d	 fd
dddt jd	S )NrJ   r   ZSoftMarginLoss_no_reducec                    s   t j|  | ddS rg   )rj   Zsoft_margin_lossrk   rl   rm   r&   r'   r:     r;   z/softmarginloss_no_reduce_test.<locals>.<lambda>znF::soft_margin_loss(
            i, t.to(i.options()), F::SoftMarginLossFuncOptions().reduction(torch::kNone))c                   S   s   t ddS )NrJ   r   r&   r&   r&   r'   r:     r;   rq   rr   c                    s   t d |  | ddS )NSoftMarginLossrh   ri   r   rO   rm   r&   r'   r:     s    TFr   r   r&   r&   rm   r'   softmarginloss_no_reduce_test}  s    

r-  c                      sP   t ddd  tdt fddddd d	 d
 fddddt jd	S )NrJ   r,   rG   Z"MultiLabelSoftMarginLoss_no_reducec                    s   t j|  | ddS rg   rj   Zmultilabel_soft_margin_lossrk   rl   rm   r&   r'   r:     r;   z9multilabelsoftmarginloss_no_reduce_test.<locals>.<lambda>zF::multilabel_soft_margin_loss(
            i, t.to(i.options()), F::MultilabelSoftMarginLossFuncOptions().reduction(torch::kNone))c                   S   s   t ddS r  r   r&   r&   r&   r'   r:     r;   rq   rr   c                    s:    |     d  |        jdd| d S Nr/   dimr   r   sumrV   rO   rm   r&   r'   r:     s    Fr   r1   rp   r   r   r}   rf   rU   r&   r&   rm   r'   'multilabelsoftmarginloss_no_reduce_test  s    

r5  c                      sb   t ddd  t dtdt fddddd d	 d
 fdddddt jd
S )NrJ   r,   rG   Z*MultiLabelSoftMarginLoss_weights_no_reducec                    s   t j|  | | ddS r   r.  rl   r   r&   r'   r:     s    zAmultilabelsoftmarginloss_weights_no_reduce_test.<locals>.<lambda>zF::multilabel_soft_margin_loss(
            i, t.to(i.options()),
            F::MultilabelSoftMarginLossFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))c                   S   s   t ddS r  r   r&   r&   r&   r'   r:     r;   rq   r   c                    s>    |     d  |         jdd| d S r/  r2  rO   r   r&   r'   r:     s    TFr  r4  r&   r&   r   r'   /multilabelsoftmarginloss_weights_no_reduce_test  s     

r6  c                      sT   t dd   tdt fddddd d d	 fd
ddddt jd
S )NrJ   r-   ZMultiMarginLoss_no_reducec                    s   t j|  |  ddS rg   rj   Zmulti_margin_lossrk   r   rl   rm   r&   r'   r:     r;   z0multimarginloss_no_reduce_test.<locals>.<lambda>F::multi_margin_loss(
            i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().reduction(torch::kNone))c                   S   s   t ddS r  r   r&   r&   r&   r'   r:     r;   rq   rr   c                    s   t d |  j|  ddS NMultiMarginLossrh   ri   r  rO   rm   r&   r'   r:     s    TFr  r1   rp   r   r   r   r}   rf   rU   r&   r&   rm   r'   multimarginloss_no_reduce_test  s    

r<  c                      sT   t dd   tdt fddddd d d	 fd
ddddt jd
S )Nr/   r-   ZMultiMarginLoss_1d_no_reducec                    s   t j|  |  ddS rg   r7  rl   rm   r&   r'   r:     r;   z3multimarginloss_1d_no_reduce_test.<locals>.<lambda>r8  c                   S   s
   t dS rn   r   r&   r&   r&   r'   r:     r;   rq   rr   c                    s   t d |  j|  ddS r9  r  rO   rm   r&   r'   r:     s    TFr  r;  r&   r&   rm   r'   !multimarginloss_1d_no_reduce_test  s    

r=  c                      sT   t dd   tdt fddddd d d	 fd
ddddt jd
S )Nr&   r-   Z,multimarginloss_1d_input_0d_target_no_reducec                    s   t j|  |  ddS rg   r7  rl   rm   r&   r'   r:     r;   zCmultimarginloss_1d_input_0d_target_no_reduce_test.<locals>.<lambda>r8  c                   S   s
   t dS rn   r   r&   r&   r&   r'   r:     r;   rq   rr   c                    s   t d |  j|  ddS r9  r  rO   rm   r&   r'   r:     s    TFr  r;  r&   r&   rm   r'   1multimarginloss_1d_input_0d_target_no_reduce_test  s    

r>  c                      sT   t dd   tdt fddddd d d	 fd
ddddt jd
S )NrJ   r-   ZMultiMarginLoss_p_no_reducec                    s   t j|  |  dddS )NrG   rh   r8   r   r7  rl   rm   r&   r'   r:     r;   z2multimarginloss_p_no_reduce_test.<locals>.<lambda>zF::multi_margin_loss(
            i, t.to(i.options()).to(torch::kLong), F::MultiMarginLossFuncOptions().p(2).reduction(torch::kNone))c                   S   s   t ddddS )NrJ   r,   r   gGz?)r1   r|   r   r&   r&   r&   r'   r:     r;   rq   rr   c                    s    t d |  j|  dddS )Nr:  rG   rh   r?  r  rO   rm   r&   r'   r:     s    TFr  r;  r&   r&   rm   r'    multimarginloss_p_no_reduce_test  s    

r@  c                      sT   t dd   tdt fddddd d d	 fd
ddddt jd
S )NrJ   r-   Z MultiMarginLoss_margin_no_reducec                    s   t j|  |  dddS r(  r7  rl   rm   r&   r'   r:     r;   z7multimarginloss_margin_no_reduce_test.<locals>.<lambda>zF::multi_margin_loss(
            i, t.to(i.options()).to(torch::kLong),
            F::MultiMarginLossFuncOptions().margin(0.5).reduction(torch::kNone))c                   S   s   t ddS r  r   r&   r&   r&   r'   r:     r;   rq   rr   c                    s    t d |  j|  dddS )Nr:  r  rh   r)  r  rO   rm   r&   r'   r:     s     TFr  r;  r&   r&   rm   r'   %multimarginloss_margin_no_reduce_test  s    

rA  c                      sj   t dd   t jdt jdtdt fdddd	d d
 d fdddddt jd
S )NrJ   r-   r,   r   Z!MultiMarginLoss_weights_no_reducec                    s"   t j|  |  | ddS r   r7  rl   r   r&   r'   r:     s   z8multimarginloss_weights_no_reduce_test.<locals>.<lambda>zF::multi_margin_loss(
            i, t.to(i.options()).to(torch::kLong),
            F::MultiMarginLossFuncOptions().weight(weights.to(i.options())).reduction(torch::kNone))c                   S   s   t ddS r  r   r&   r&   r&   r'   r:     r;   rq   r   c                    s    t d |  j|  ddS )Nr:  rh   r   r  rO   r   r&   r'   r:     s     TFr  )r1   rp   r   r   r   rU   r}   rf   r&   r&   r   r'   &multimarginloss_weights_no_reduce_test  s     
rB  c              
   C   sR   dd }|| }t |tjr"|gn|}t  || dW  5 Q R  S Q R X dS )zReference function for modules supporting no batch dimensions.

    The module is passed the input and target in batched form with a single item.
    The output is squeezed to compare with the no-batch input.
    c                 S   s&   t | ttfrdd | D S | dS )Nc                 S   s   g | ]}| d qS r   	unsqueeze.0r3   r&   r&   r'   
<listcomp>.  s     zDsingle_batch_reference_fn.<locals>.unsqueeze_inp.<locals>.<listcomp>r   
isinstancelisttuplerE  inpr&   r&   r'   unsqueeze_inp,  s    z0single_batch_reference_fn.<locals>.unsqueeze_inpr   N)rJ  r1   Tensorr   squeeze)r   
parametersmodulerO  Zsingle_batch_inputr&   r&   r'   single_batch_reference_fn&  s
    rT  Conv1d)r.   rJ   rL   z!torch::nn::Conv1dOptions(4, 5, 3))rG   r.   r,   )r=   r>   r?   r@   cudnnrB   rC   rD   )r.   rJ   rL   rG   z+torch::nn::Conv1dOptions(4, 5, 3).stride(2)stride)	r=   r>   r?   r@   rV  rE   rB   rC   rD   )r.   rJ   rL   r/   r/   z6torch::nn::Conv1dOptions(4, 5, 3).stride(1).padding(1)Zpad1r   )r.   rJ   rJ   r/   rG   z6torch::nn::Conv1dOptions(4, 5, 5).stride(1).padding(2)Zpad2)r.   r.   rL   r/   r/   z6torch::nn::Conv1dOptions(4, 4, 3).stride(1).padding(1))r/   r.   r/   Z	pad1size1)r.   r.   rJ   r/   rG   z6torch::nn::Conv1dOptions(4, 4, 5).stride(1).padding(2)Z	pad2size1)r   r.   r,   Z
zero_batch)r=   r>   r?   r@   rV  rE   rB   rC   ZConv1d_dilatedc                   C   s   t jdddddS )Nr.   rJ   rL   rG   kernel_sizedilationrd   rU  r&   r&   r&   r'   r:     r;   z-torch::nn::Conv1dOptions(4, 5, 3).dilation(2))ru   rv   r?   r@   rB   rC   rD   ZConv1d_groupsc                   C   s   t jdddddS )Nr.   rQ   rL   rG   rY  groupsr[  r&   r&   r&   r'   r:     r;   z+torch::nn::Conv1dOptions(4, 6, 3).groups(2))rG   r.   rQ   )ru   rv   r?   r@   rV  rB   rC   rD   ZConv1d_pad_validc                   C   s   t jdddddS )Nr.   rJ   rL   validpaddingr[  r&   r&   r&   r'   r:     r;   z8torch::nn::Conv1dOptions(4, 5, 3).padding(torch::kValid)ZConv1d_pad_samec                   C   s   t jdddddS )Nr.   rJ   rL   samer_  r[  r&   r&   r&   r'   r:     r;   z7torch::nn::Conv1dOptions(4, 5, 3).padding(torch::kSame)ZConv1d_pad_same2c                   C   s   t jdddddS )Nr.   rJ   ra  r_  r[  r&   r&   r&   r'   r:     r;   z7torch::nn::Conv1dOptions(4, 5, 4).padding(torch::kSame)ZConv1d_pad_same_dilatedc                   C   s   t jddddddS )Nr.   rJ   ra  rG   r`  rZ  r[  r&   r&   r&   r'   r:     r;   zCtorch::nn::Conv1dOptions(4, 5, 3).padding(torch::kSame).dilation(2)ConvTranspose1dc                   C   s   t jdddddddS )NrL   r.   rL   r/   r/   )rY  rW  r`  output_paddingrd   rc  r&   r&   r&   r'   r:     r;   zQtorch::nn::ConvTranspose1dOptions(3, 4, 3).stride(3).padding(1).output_padding(1))r/   rL      )ru   rv   r?   rV  r@   rB   rC   rD   )rL   r.   rL   rG   r/   r/   r/   Fztorch::nn::ConvTranspose1dOptions(3, 4, 3)
                                .stride(2).padding(1).output_padding(1).groups(1).bias(false))r/   rL   rQ   )	rL   r.   rL   rG   r/   r/   r/   TrG   ztorch::nn::ConvTranspose1dOptions(3, 4, 3)
                                .stride(2).padding(1).output_padding(1).groups(1).bias(true).dilation(2)ZdilatedZConvTranspose1d_groupsc                	   C   s   t jddddddddS )	Nr.   rQ   rL   rd  r/   re  rG   )rW  r`  rf  r]  rg  r&   r&   r&   r'   r:     r;   z|torch::nn::ConvTranspose1dOptions(4, 6, 3)
                                .stride(3).padding(1).output_padding(1).groups(2))rG   r.   rh  Conv2d)rL   r.   rL   rG   z&torch::nn::Conv2dOptions(3, 4, {3, 2}))rG   rL   rh  rJ   )	r=   r>   r?   r@   rV  check_with_long_tensorrB   rC   rD   )rL   r.   rL   rL   rG   rG   z5torch::nn::Conv2dOptions(3, 4, {3, 3}).stride({2, 2})Zstrided)
r=   r>   r?   r@   rV  rE   rk  rB   rC   rD   )rL   r.   rl  rm  r/   r/   zEtorch::nn::Conv2dOptions(3, 4, {3, 3}).stride({2, 2}).padding({1, 1})r`  )rL   rG   rl  rm  rn  rm  zVtorch::nn::Conv2dOptions(3, 2, {3, 3}).stride({2, 2}).padding({1, 1}).dilation({2, 2}))rG   rL   r-   r-   )rL   r.   rj  r/   r   r/   r/   Fz~torch::nn::Conv2dOptions(3, 4, {3, 2})
                                .stride(1).padding(0).dilation(1).groups(1).bias(false))rG   rL   rQ   rJ   gQ?)r   rL   rh  rJ   )r=   r>   r?   r@   rV  rE   rk  rB   ZConv2d_groupsc                   C   s   t jdddddS Nr.   rQ   rj  rG   r]  rd   ri  r&   r&   r&   r'   r:   e  r;   z0torch::nn::Conv2dOptions(4, 6, {3, 2}).groups(2))rG   r.   rQ   rJ   )	ru   rv   r?   r@   rV  rk  rB   rC   rD   ZConv2d_groups_thnnc                   C   s   t jdddddS ro  rq  r&   r&   r&   r'   r:   p  r;   )ru   rv   r?   r@   rk  rB   rC   rD   ZConv2d_pad_validc                   C   s   t jdddddS )NrG   r.   rL   r.   r^  r_  rq  r&   r&   r&   r'   r:   z  r;   z=torch::nn::Conv2dOptions(2, 4, {3, 4}).padding(torch::kValid))rG   rG   rQ   rJ   ZConv2d_pad_samec                   C   s   t jdddddS )NrG   r.   rr  ra  r_  rq  r&   r&   r&   r'   r:     r;   z<torch::nn::Conv2dOptions(2, 4, {3, 4}).padding(torch::kSame)ZConv2d_pad_same_dilatedc                   C   s   t jddddddS )NrG   r.   rr  ra  rb  rq  r&   r&   r&   r'   r:     r;   zHtorch::nn::Conv2dOptions(2, 4, {3, 4}).padding(torch::kSame).dilation(2)ConvTranspose2d)rL   r.   rL   rj  r/   rn  z|torch::nn::ConvTranspose2dOptions(3, 4, 3)
                                .stride({3, 2}).padding(1).output_padding({1, 1}))r/   rL   rh  rQ   )	r=   r>   r?   rV  r@   rk  rB   rC   rD   )	rL   r.   rL   rG   rL   r/   rn  r/   Frm  aH  torch::nn::ConvTranspose2dOptions(3, 4, 3)
                                .stride({2, 3})
                                .padding(1)
                                .output_padding({1, 1})
                                .groups(1)
                                .bias(false)
                                .dilation({2, 2}))r/   rL   rQ   rh  )rL   r.   rL   rt  r/   rn  r/   Fztorch::nn::ConvTranspose2dOptions(3, 4, 3)
                                .stride({2, 3}).padding(1).output_padding({1, 1}).groups(1).bias(false)ZConvTranspose2d_groupsc                   C   s   t jdddddS )NrG   r.   rt  rp  )rd   rs  r&   r&   r&   r'   r:     r;   z9torch::nn::ConvTranspose2dOptions(2, 4, {2, 3}).groups(2))r/   rG   r.   rJ   ZConv2d_depthwisec                   C   s   t jdddddS )Nr.   rl  rp  rq  r&   r&   r&   r'   r:     r;   z0torch::nn::Conv2dOptions(4, 4, {3, 3}).groups(4))rG   r.   rQ   rQ   Z Conv2d_depthwise_with_multiplierc                   C   s   t jdddddS )Nr.   r-   rl  rp  rq  r&   r&   r&   r'   r:     r;   z0torch::nn::Conv2dOptions(4, 8, {3, 3}).groups(4)ZConv2d_depthwise_stridedc                   C   s   t jddddddS )Nr.   rl  rm  )rW  r]  rq  r&   r&   r&   r'   r:     r;   z?torch::nn::Conv2dOptions(4, 4, {3, 3}).stride({2, 2}).groups(4)ZConv2d_depthwise_paddedc                   C   s   t jddddddS )Nr.   rl  rn  )r`  r]  rq  r&   r&   r&   r'   r:     r;   z@torch::nn::Conv2dOptions(4, 4, {3, 3}).padding({1, 1}).groups(4)ZConv2d_depthwise_dilatedc                   C   s   t jddddddS )Nr.   rm  )rZ  r]  rq  r&   r&   r&   r'   r:     r;   zAtorch::nn::Conv2dOptions(4, 4, {2, 2}).dilation({2, 2}).groups(4))rG   r.   rJ   rJ   Conv3d)rG   rL   rG   rL   rG   z)torch::nn::Conv3dOptions(2, 3, {2, 3, 2}))r/   rG   r.   rJ   r.   g?)rG   rL   rG   rL   r.   r/   r   r/   r/   Fztorch::nn::Conv3dOptions(2, 3, {2, 3, 4})
                                .stride(1).padding(0).dilation(1).groups(1).bias(false))r/   rG   rL   r.   rJ   )rG   rL   )r/   r/   r/   r/   r   r/   r/   FZ1x1x1_no_bias)rL   r.   rG   rG   z+torch::nn::Conv3dOptions(3, 4, 2).stride(2))rG   rL   rJ   rJ   rJ   )rL   r.   rG   rG   r/   z6torch::nn::Conv3dOptions(3, 4, 2).stride(2).padding(1)Zstride_padding)rL   r.   rw  z)torch::nn::Conv3dOptions(3, 4, {2, 3, 4}))r   rL   rL   r.   rJ   )r=   r>   r?   r@   rV  rk  rE   rB   ZConv3d_groupsc                   C   s   t jdddddS )NrG   r.   rL   r\  rd   ru  r&   r&   r&   r'   r:   B  r;   z+torch::nn::Conv3dOptions(2, 4, 3).groups(2)ZConv3d_dilatedc                   C   s   t jdddddS )NrL   r.   rG   rX  rx  r&   r&   r&   r'   r:   M  r;   z-torch::nn::Conv3dOptions(3, 4, 2).dilation(2)ZConv3d_dilated_stridedc                   C   s   t jddddddS )NrL   r.   rG   )rY  rZ  rW  rx  r&   r&   r&   r'   r:   V  r;   z7torch::nn::Conv3dOptions(3, 4, 2).dilation(2).stride(2)ZConv3d_pad_validc                   C   s   t jdddddS )NrL   r.   rw  r^  r_  rx  r&   r&   r&   r'   r:   _  r;   z@torch::nn::Conv3dOptions(3, 4, {2, 3, 4}).padding(torch::kValid))rG   rL   rQ   rJ   r.   ZConv3d_pad_samec                   C   s   t jdddddS )NrL   r.   rw  ra  r_  rx  r&   r&   r&   r'   r:   i  r;   z?torch::nn::Conv3dOptions(3, 4, {2, 3, 4}).padding(torch::kSame)ZConv3d_pad_same_dilatedc                   C   s   t jddddddS )NrL   r.   rw  ra  rG   rb  rx  r&   r&   r&   r'   r:   s  r;   zKtorch::nn::Conv3dOptions(3, 4, {2, 3, 4}).padding(torch::kSame).dilation(2)ZConvTranspose3dz2torch::nn::ConvTranspose3dOptions(2, 3, {2, 3, 2}))r=   r>   r?   rV  r@   rB   rC   rD   )	rG   rL   rv  r/   r   r   r/   T)rG   rG   rG   ztorch::nn::ConvTranspose3dOptions(2, 3, {2, 3, 2})
                                .stride(1).padding(0).output_padding(0).groups(1).bias(true).dilation({2, 2, 2}))	r=   r>   r?   rV  r@   rE   rB   rC   rD   ZReplicationPad3d))r/   rG   rL   rL   rG   r/   z6torch::nn::ReplicationPad3dOptions({1, 2, 3, 3, 2, 1}))rG   rL   rG   rG   rG   )r=   r>   r?   r@   rD   )rL   rG   rG   rG   Zno_batch_dim)r=   r>   r?   r@   rA   rE   rD   c                	   C   s   t jdddddt jddS )NrG   rL   Tr   requires_grad)r1   rp   
complex128r&   r&   r&   r'   r:     r;   complex)r=   r>   r?   rx   	skip_halfrE   	Embedding)r.   rL   z!torch::nn::EmbeddingOptions(4, 3)c                   C   s   t jddt jddS NrG   rL   r   r.   r1   r   r   random_r&   r&   r&   r'   r:     r;   z0https://github.com/pytorch/pytorch/issues/117971)r=   r>   r?   rx   rR   rD   	decoratorc                   C   s    t jddt jddddS Nr/   i   r   r.   rh  r1   r   r   r  r5   r&   r&   r&   r'   r:     r;   Zdiscontiguous)r=   r>   r?   rx   rR   rE   rD   r  EmbeddingBagz$torch::nn::EmbeddingBagOptions(4, 3)c                   C   s   t jddt jddS r  r  r&   r&   r&   r'   r:     r;   mean)r=   r>   r?   rx   rR   rE   rD   c                   C   s    t jddt jddddS r  r  r&   r&   r&   r'   r:     r;   )r.   rL   N       @Fr3  ztorch::nn::EmbeddingBagOptions(4, 3)
                                .max_norm(c10::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kSum)c                   C   s   t jddt jddS r  r  r&   r&   r&   r'   r:     r;   r3  )r.   rL   Nr  Fmaxztorch::nn::EmbeddingBagOptions(4, 3)
                                .max_norm(c10::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kMax)c                   C   s   t jddt jddS r  r  r&   r&   r&   r'   r:     r;   r  ZEmbeddingBag_mean_padding_idxc                   C   s   t jddddS )Nr.   rL   r/   Zpadding_idxrd   r  r&   r&   r&   r'   r:     r;   z3torch::nn::EmbeddingBagOptions(4, 3).padding_idx(1)c                   C   s   t t dt dgS NrL   r1   stackrT   r&   r&   r&   r'   r:     r;   )ru   rv   r?   rx   rR   rD   ZEmbeddingBag_sum_padding_idxc                	   C   s   t jddd dddddS )Nr.   rL   r  Fr3  r/   r  r  r&   r&   r&   r'   r:     r;   ztorch::nn::EmbeddingBagOptions(4, 3)
                                .max_norm(c10::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kSum).padding_idx(1)c                   C   s   t t dt dgS r  r  r&   r&   r&   r'   r:     r;   ZEmbeddingBag_max_padding_idxc                	   C   s   t jddd dddddS )Nr.   rL   r  Fr  r/   r  r  r&   r&   r&   r'   r:     r;   ztorch::nn::EmbeddingBagOptions(4, 3)
                                .max_norm(c10::nullopt).norm_type(2.).scale_grad_by_freq(false).mode(torch::kMax).padding_idx(1)c                   C   s   t t dt dgS r  r  r&   r&   r&   r'   r:     r;   ZEmbeddingBag_sparsec                   C   s   t jdddtjdS )Nr.   rL   T)sparser   )rd   r  r1   rU   r&   r&   r&   r'   r:     r;   zbtorch::nn::EmbeddingBagOptions(4, 3).sparse(true)._weight(torch::rand({4, 3}).to(torch::kFloat64))c                   C   s   t dddS NrG   r/   r1   rT   repeatr&   r&   r&   r'   r:      r;   )ru   rv   r?   rx   rR   has_sparse_gradientsc                   C   s   t jddtjddS )Nr.   rL   T)r   r  )rd   r~  r1   rU   r&   r&   r&   r'   r:     r;   z_torch::nn::EmbeddingOptions(4, 3).sparse(true)._weight(torch::rand({4, 3}).to(torch::kFloat64))c                   C   s   t dddS r  r  r&   r&   r&   r'   r:     r;   ZEmbedding_sparse)rv   r?   rx   ru   rR   r  ZPixelShufflerd  z!torch::nn::PixelShuffleOptions(3))r/   	   r.   r.   ZPixelUnshufflez#torch::nn::PixelUnshuffleOptions(3))r/   r/      r  r  Znearest)rV   scale_factormodezF::InterpolateFuncOptions()
                            .size(std::vector<int64_t>({12})).scale_factor(c10::nullopt).mode(torch::kNearest))r/   rG   r.   Zinterpolate_nearest_1d)rv   cpp_options_argsr@   ru   rz   rD   )r   rG   r.   Zinterpolate_nearest_1d_zero_dim)rv   r  r@   ru   rz   )r  )r/   rG   rL   Zinterpolate_nearest_tuple_1dg      @zF::InterpolateFuncOptions()
                            .size(c10::nullopt).scale_factor(std::vector<double>({4.})).mode(torch::kNearest)Zinterpolate_nearest_scale_1dZlinear)rV   r  r  Zalign_cornerszF::InterpolateFuncOptions()
                            .size(std::vector<int64_t>({12}))
                            .scale_factor(c10::nullopt)
                            .mode(torch::kLinear)
                            .align_corners(false)Zinterpolate_linear_1d)r.   zF::InterpolateFuncOptions()
                            .size(std::vector<int64_t>({4}))
                            .scale_factor(c10::nullopt)
                            .mode(torch::kLinear)
                            .align_corners(false)Zinterpolate_linear_tuple_1dzF::InterpolateFuncOptions()
                            .size(c10::nullopt)
                            .scale_factor(std::vector<double>({4.}))
                            .mode(torch::kLinear)
                            .align_corners(false)Zinterpolate_linear_scale_1dZinterpolate_linear_1d_zero_dimzF::InterpolateFuncOptions()
                            .size(std::vector<int64_t>({12}))
                            .scale_factor(c10::nullopt)
                            .mode(torch::kLinear)
                            .align_corners(true)Z#interpolate_linear_1d_align_cornerszF::InterpolateFuncOptions()
                            .size(c10::nullopt)
                            .scale_factor(std::vector<double>({4.}))
                            .mode(torch::kLinear)
                            .align_corners(true)Z)interpolate_linear_scale_1d_align_cornersrG   zF::InterpolateFuncOptions()
                            .size(std::vector<int64_t>({2, 2}))
                            .scale_factor(c10::nullopt)
                            .mode(torch::kNearest))r/      r/   r/   Z%interpolate_nearest_2d_launch_configszF::InterpolateFuncOptions()
                            .size(std::vector<int64_t>({12, 12}))
                            .scale_factor(c10::nullopt)
                            .mode(torch::kNearest))r/   rG   r.   r.   Zinterpolate_nearest_2d)r     zF::InterpolateFuncOptions()
                            .size(std::vector<int64_t>({12, 16}))
                            .scale_factor(c10::nullopt)
                            .mode(torch::kNearest))r/   rG   rL   r.   Zinterpolate_nearest_tuple_2dzF::InterpolateFuncOptions()
                            .size(c10::nullopt)
                            .scale_factor(std::vector<double>({4., 4.}))
                            .mode(torch::kNearest)Zinterpolate_nearest_scale_2d)r   rG   r.   r.   Zinterpolate_nearest_2d_zero_dimZbilinearzF::InterpolateFuncOptions()
                            .size(std::vector<int64_t>({12, 12}))
                            .scale_factor(c10::nullopt)
                            .mode(torch::kBilinear)
                            .align_corners(false)Zinterpolate_bilinear_2dZ interpolate_bilinear_2d_zero_dim)r.   rQ   zF::InterpolateFuncOptions()
                            .size(std::vector<int64_t>({4, 6}))
                            .scale_factor(c10::nullopt)
                            .mode(torch::kBilinear)
                            .align_corners(false))r/   rG   rG   rL   Zinterpolate_bilinear_tuple_2dzF::InterpolateFuncOptions()
                            .size(c10::nullopt)
                            .scale_factor(std::vector<double>({4., 4.}))
                            .mode(torch::kBilinear)
                            .align_corners(false)Zinterpolate_bilinear_scale_2d)r  r  zF::InterpolateFuncOptions()
                            .size(c10::nullopt)
                            .scale_factor(std::vector<double>({2., 2.}))
                            .mode(torch::kBilinear)
                            .align_corners(false)Z*interpolate_bilinear_scale_tuple_shared_2d)r        ?zF::InterpolateFuncOptions()
                            .size(c10::nullopt)
                            .scale_factor(std::vector<double>({2., 1.}))
                            .mode(torch::kBilinear)
                            .align_corners(false)Z*interpolate_bilinear_scale_tuple_skewed_2dzF::InterpolateFuncOptions()
                            .size(std::vector<int64_t>({4, 6}))
                            .scale_factor(c10::nullopt)
                            .mode(torch::kBilinear)
                            .align_corners(true)Z+interpolate_bilinear_tuple_2d_align_cornerszF::InterpolateFuncOptions()
                            .size(c10::nullopt)
                            .scale_factor(std::vector<double>({2., 1.}))
                            .mode(torch::kBilinear)
                            .align_corners(true)Z8interpolate_bilinear_scale_tuple_skewed_2d_align_cornersZbicubiczF::InterpolateFuncOptions()
                            .size(std::vector<int64_t>({12, 12}))
                            .scale_factor(c10::nullopt)
                            .mode(torch::kBicubic)
                            .align_corners(false)Zinterpolate_bicubic_2dZinterpolate_bicubic_2d_zero_dimzF::InterpolateFuncOptions()
                            .size(std::vector<int64_t>({4, 6}))
                            .scale_factor(c10::nullopt)
                            .mode(torch::kBicubic)
                            .align_corners(false)Zinterpolate_bicubic_tuple_2dzF::InterpolateFuncOptions()
                            .size(c10::nullopt)
                            .scale_factor(std::vector<double>({4., 4.}))
                            .mode(torch::kBicubic)
                            .align_corners(false)Zinterpolate_bicubic_scale_2dzF::InterpolateFuncOptions()
                            .size(c10::nullopt)
                            .scale_factor(std::vector<double>({2., 2.}))
                            .mode(torch::kBicubic)
                            .align_corners(false)Z)interpolate_bicubic_scale_tuple_shared_2dzF::InterpolateFuncOptions()
                            .size(c10::nullopt)
                            .scale_factor(std::vector<double>({2., 1.}))
                            .mode(torch::kBicubic)
                            .align_corners(false)Z)interpolate_bicubic_scale_tuple_skewed_2dzF::InterpolateFuncOptions()
                            .size(std::vector<int64_t>({4, 6}))
                            .scale_factor(c10::nullopt)
                            .mode(torch::kBicubic)
                            .align_corners(true)Z*interpolate_bicubic_tuple_2d_align_cornerszF::InterpolateFuncOptions()
                            .size(c10::nullopt)
                            .scale_factor(std::vector<double>({2., 1.}))
                            .mode(torch::kBicubic)
                            .align_corners(true)Z7interpolate_bicubic_scale_tuple_skewed_2d_align_cornerszF::InterpolateFuncOptions()
                            .size(std::vector<int64_t>({12, 12, 12}))
                            .scale_factor(c10::nullopt)
                            .mode(torch::kNearest))r/   rG   r.   r.   r.   Zinterpolate_nearest_3d)r   rG   r.   r.   r.   Zinterpolate_nearest_3d_zero_dim)r  r  r  zF::InterpolateFuncOptions()
                            .size(std::vector<int64_t>({12, 16, 16}))
                            .scale_factor(c10::nullopt)
                            .mode(torch::kNearest))r/   rG   rL   r.   r.   Zinterpolate_nearest_tuple_3dzF::InterpolateFuncOptions()
                            .size(c10::nullopt)
                            .scale_factor(std::vector<double>({4., 4., 4.}))
                            .mode(torch::kNearest)Zinterpolate_nearest_scale_3dZ	trilineara   F::InterpolateFuncOptions()
                            .size(std::vector<int64_t>({12, 12, 12}))
                            .scale_factor(c10::nullopt)
                            .mode(torch::kTrilinear)
                            .align_corners(false)Zinterpolate_trilinear_3dZ!interpolate_trilinear_3d_zero_dim)r.   rQ   rQ   zF::InterpolateFuncOptions()
                            .size(std::vector<int64_t>({4, 6, 6}))
                            .scale_factor(c10::nullopt)
                            .mode(torch::kTrilinear)
                            .align_corners(false))r/   rG   rG   rL   rL   Zinterpolate_trilinear_tuple_3dg      @zF::InterpolateFuncOptions()
                            .size(c10::nullopt)
                            .scale_factor(std::vector<double>({3., 3., 3.}))
                            .mode(torch::kTrilinear)
                            .align_corners(false)Zinterpolate_trilinear_scale_3dr   )rv   r  r@   ru   r   rz   rD   zF::InterpolateFuncOptions()
                            .size(std::vector<int64_t>({4, 6, 6}))
                            .scale_factor(c10::nullopt)
                            .mode(torch::kTrilinear)
                            .align_corners(true)Z,interpolate_trilinear_tuple_3d_align_cornerszF::InterpolateFuncOptions()
                            .size(c10::nullopt)
                            .scale_factor(std::vector<double>({3., 3., 3.}))
                            .mode(torch::kTrilinear)
                            .align_corners(true)Z,interpolate_trilinear_scale_3d_align_cornersr0   r0  zF::SoftmaxFuncOptions(-1))rG   r  Zsoftmax_lastdimr/   )r1  r   z/F::SoftmaxFuncOptions(1).dtype(torch::kFloat64)Zsoftmax_lastdim_dtype)rv   r  r@   ru   rz   rH   rD   zF::SoftmaxFuncOptions(1))rG   r  rG   rG   Zsoftmax_spatial_special)rG   rG   r.   r.   Zsoftmax_spatialZsoftmax_spatial_dtypezF::SoftmaxFuncOptions(0)Zsoftmax_functional_dim0)rv   r  r@   ru   rH   rz   rD   rL   zF::SoftmaxFuncOptions(3)Zsoftmax_functional_dim3r&   Zsoftmax_functional_scalar)rv   r  r@   ru   rH   rz   zF::LogSoftmaxFuncOptions(-1)Zlog_softmax_lastdimzF::LogSoftmaxFuncOptions(1)Zlog_softmax_spatial_specialZlog_softmax_spatialzF::LogSoftmaxFuncOptions(0)Zlog_softmax_dim0zF::LogSoftmaxFuncOptions(3)Zlog_softmax_dim3Zlog_softmax_scalarUnfoldc                   C   s   t ddddS )Nrm  rn  r   r   rd   r  r&   r&   r&   r'   r:   n	  r;   zPtorch::nn::UnfoldOptions({2, 2}).dilation({1, 1}).padding({0, 0}).stride({1, 1}))rG   r.   rL   rL   )ru   rv   r?   r@   rR   rH   rD   Foldc                   C   s   t dddddS Nrl  rm  rn  r  rd   r  r&   r&   r&   r'   r:   w	  r;   zVtorch::nn::FoldOptions({3, 3}, {2, 2}).dilation({1, 1}).padding({0, 0}).stride({1, 1}))rG   r  r.   ZFold_no_batch_dim_inputc                   C   s   t dddddS r  r  r&   r&   r&   r'   r:   	  r;   )r  r.   )ru   rv   r?   r@   rR   refrH   rD   ZUnfold_int_inputc                   C   s   t ddddS )NrG   r/   r   r  r&   r&   r&   r'   r:   	  r;   z<torch::nn::UnfoldOptions(2).dilation(1).padding(0).stride(1)ZFold_int_inputc                   C   s   t dddddS NrL   rG   r/   r   r  r&   r&   r&   r'   r:   	  r;   z=torch::nn::FoldOptions(3, 2).dilation(1).padding(0).stride(1)ZFold_no_batch_dim_int_inputc                   C   s   t dddddS r  r  r&   r&   r&   r'   r:   	  r;   )ru   rv   r?   r@   r  rR   rH   rD   Zwith_up_down_scalarZPairwiseDistancec                   C   s   t ddt ddfS Nr,   r-   r   r&   r&   r&   r'   r:   	  r;   )r=   rx   rD   c                   C   s   t ddt ddfS )Nr,   r/   r-   r   r&   r&   r&   r'   r:   	  r;   Zbroadcast_lhs)r=   rx   rE   rD   c                   C   s   t ddt ddfS )Nr,   r-   r/   r   r&   r&   r&   r'   r:   	  r;   Zbroadcast_rhs)g      ?r   TzDtorch::nn::PairwiseDistanceOptions().p(1.5).eps(1e-05).keepdim(true)c                   C   s   t ddt ddfS r  r   r&   r&   r&   r'   r:   	  r;   Zwith_non_default_args)r=   r>   r?   rx   rE   rD   c                   C   s   t dt dfS )Nr-   r   r&   r&   r&   r'   r:   	  r;   )r=   rx   rA   rE   rD   ZTransformerEncoderLayer)r.   rG   r          ztorch::nn::TransformerEncoderLayerOptions(4, 2)
                                .dim_feedforward(16)
                                .dropout(0.0)rw  Zrelu_activationrI   )
r=   r>   r?   r@   rE   rB   rC   rS   rR   rD   r.   r-   r  ztorch::nn::TransformerEncoderLayerOptions(4, 2)
                                .dim_feedforward(8)
                                .dropout(0.0)
                                .activation(torch::kGELU)Zgelu_activationg{Gz?)	r=   r>   r?   r@   rR   rE   rB   rC   rD   ZTransformerDecoderLayer)r.   rG   r-   r  ztorch::nn::TransformerDecoderLayerOptions(4, 2)
                                .dim_feedforward(8)
                                .dropout(0.0)c                   C   s   t dddt dddfS NrL   r.   rG   ro   r&   r&   r&   r'   r:   	  r;   )	r=   r>   r?   rx   rR   rE   rB   rC   rD   ztorch::nn::TransformerDecoderLayerOptions(4, 2)
                                .dim_feedforward(8)
                                .dropout(0.0)
                                .activation(torch::kGELU)c                   C   s   t dddt dddfS r  ro   r&   r&   r&   r'   r:    
  r;   ZTransformera  torch::nn::TransformerOptions()
                                .d_model(4)
                                .nhead(2)
                                .num_encoder_layers(2)
                                .num_decoder_layers(2)
                                .dim_feedforward(8)
                                .dropout(0.0)
                                .activation(torch::kReLU)c                   C   s&   t dddt dddt ddfS r  ro   r&   r&   r&   r'   r:   
  r;   Zmultilayer_codergQ?)rL   rJ   ztorch::nn::LinearOptions(3, 5)c                   C   s
   t dS r  ro   r&   r&   r&   r'   r:   
  r;   c                 C   s*   t | dd|d  d|d  S )Nr/   r0   r   )r1   r2   r4   r3   r6   r&   r&   r'   r:   
  r;   )	r=   r>   r?   rx   rA   rE   rB   rC   rD   z5torch::nn::FlattenOptions().start_dim(-3).end_dim(-1))r0   )rL   r.   rJ   )r=   r?   r>   r@   rA   rE   rD   Z	Unflattenz'torch::nn::UnflattenOptions(-2, {2, 2})Z	LayerNorm8   zMtorch::nn::LayerNormOptions({56, 56, 56}).eps(1e-5).elementwise_affine(false))r.   r  r  r  Z3d_no_affine_large_feature)	r=   r>   r?   r@   rV  Z
check_evalgradcheck_fast_mode
check_halfrE   ZreflectZcircularZ	replicater  ztorch::kReflectztorch::kCircularztorch::kReplicateztorch::kZeros{z, }rm  rt  c                 c   s   | ]}|d  V  qdS r/   Nr&   rG  r8   r&   r&   r'   	<genexpr>S
  s     r  ZConvdztorch::nn::ConvzvdOptions(2, 3, 3)
                                        .stride(2)
                                        .padding(z)
                                        .dilation(1)
                                        .groups(1)
                                        .bias(true)
                                        .padding_mode()Z_stride2_pad2)
r=   r>   r?   r@   output_sizerV  rE   rB   rC   rD   ELU
HardshrinkHardsigmoidHardtanh	Hardswish	LeakyReLU
LogSigmoidPReLUReLUReLU6SELUCELUGELUGLUr   SiLUMishSoftplus
SoftshrinkSoftsignTanh
Tanhshrink	Threshold)r  )r>   rD   r>   )rR   Z	check_jitrD   )rH   rD   rD   )r  r  r  r  rF   r  r  r  r  r  r  r  r  r  r  r  r  r  r   r  r  r  r  r  !non_linear_activations_extra_info)r=   r@   rA   rE   test_cpp_api_parityc                 C   sr   |rt |||   }n|| |   }|dkr8| S |dkrH| S |dkrn| dkrn| |d S |S )Nr  r3  Z	batchmeanr   )r1   rs   r   r  r3  r1  rV   )r   r   r   r   r%   r&   r&   r'   kldivloss_reference
  s    r  c                 C   s   |   dkst| d}| d}|f|  dd   }t|| }|d krbt|| }d}	tdd |D  D ]V}
||
 }||krdn
||  }t	|
}|
d| | t|  | ||
< |	|7 }	qx|dkr| |	 S |d	kr| S |S )
NrL   r   r/   rG   c                 S   s   g | ]}t |qS r&   )range)rG  rV   r&   r&   r'   rH  
  s     z'nlllossNd_reference.<locals>.<listcomp>r  r  r3  )r1  r"   rV   r1   r  rk   onesr   itemrK  insertrL  r3  )r   r   r)   r   r   NCZout_sizeoutputZtotal_weighttupZt_nxnormZinput_indexr&   r&   r'   nlllossNd_reference
  s(    


r  c                 C   s   |   dkstt| d} | d}|d kr>t|| }|jd|fdd | jdd  D  }|dkr|dksxt|d|  ||  }| | | j	dd }|dkr|
 S |d	kr|	 S |S )
NrG   r/   c                 s   s   | ]
}d V  qdS r  r&   rG  r9   r&   r&   r'   r  
  s     z;cross_entropy_loss_prob_target_reference.<locals>.<genexpr>r  r  r0  r  r3  )r1  r"   r1   log_softmaxrV   r  rk   r4   shaper3  r  )r   r   r)   r   label_smoothingr  r  r&   r&   r'   (cross_entropy_loss_prob_target_reference
  s    
&r  c                 C   s4  t | d}tj|||||d}|dkr,|S d|  k r@dksFn tt | d} | d}|d k	r| |jd|fdd | jdd  D   } t | d }	||k}
|		|
d |dkr|d k	rt |	|
d	||
    }nt |	|
 }n|d
krt |	}n|	}d| | |||   S )Nr/   r   r  r  c                 s   s   | ]
}d V  qdS r  r&   r  r&   r&   r'   r  
  s     z>cross_entropy_loss_indices_target_reference.<locals>.<genexpr>rG   r  r   r3  )r1   r  rj   r   r"   rV   r4   r  r3  Zmasked_fill_ZgatherZmasked_selectZlogical_notrN   r  )r   r   r)   r   r   r  Zlog_softmax_inputZnlllossr  Zsmooth_lossZignore_maskretr&   r&   r'   +cross_entropy_loss_indices_target_reference
  s4    
*
*
r  c                 C   s6   | j |j krt| ||||dS t| |||||dS d S )N)r)   r   r  )r)   r   r   r  )r  r  r  )r   r   r)   r   r   r  r&   r&   r'   cross_entropy_loss_reference
  s         r  c           	         sj   dd  fddt | |D }t | \}}| |}|dkrRt|t| S |dkrbt|S |S d S )Nc                 S   s6   ||krdS |d krdn|| }| |  | }||fS )Nr  r/   r&   )r   r   r)   r   r  r%   r&   r&   r'   nll_loss_helper  s
    z*nllloss_reference.<locals>.nll_loss_helperc                    s   g | ]\}}|| qS r&   r&   )rG  r7   r3   r   r  r)   r&   r'   rH    s   z%nllloss_reference.<locals>.<listcomp>r  r3  )zipZ
new_tensorr3  )	r   r   r)   r   r   Zlosses_and_weightslossesr*   Zlosses_tensorr&   r  r'   nllloss_reference  s    
r  r  c                 C   s~   | |   }||k|}||k |}|dkr6|}n$||d|   |d |d  |  }|dkrj| S |dkrz| S |S )Nr   r  rG   r  r3  )r   rk   r  r3  )r   r   r   r
  abs_diffZge_beta_maskZlt_beta_maskr  r&   r&   r'   smoothl1loss_reference%  s    $r  c                 C   sd   | |   }||k}||k }|| |d|   |d |d   }|dkrP| S |dkr`| S |S )Nr  rG   r  r3  )r   r  r3  )r   r   r   r  r  Zge_delta_maskZlt_delta_maskr  r&   r&   r'   huberloss_reference5  s    $r  c                 C   sp   g }|D ]}|dk r q$| | qd}|D ]>}tdt| D ]*}||kr>|tdd| |  | |  7 }q>q,|S )Nr   r/   )appendr  lenr  )r   r   targetsZtarget_indexr3  r7   r&   r&   r'   _multilabelmarginloss_referenceA  s    "r  c                 C   s   |   }|   dk rp|  dk s$t|   dkr:| dn| dd} |  dkr`|dn|dd}| d}| d}| | }td|D ]}t| | || ||< q|dkr| | S |dkr|	 | S |dk r|
 | S || S d S NrG   r/   r   r  r3  )r1  r"   rE  rV   newzero_r  r  r  r3  rQ  )r   r   r   Z	input_dimnr1  r  r7   r&   r&   r'   multilabelmarginloss_referenceQ  s"    &&

r  c                 C   sL   ||  j dd| }t|dk| |}|dkr8| S |dkrH| S |S )Nr   r  r/   r  r3  )r  rk   r1   wherer  r3  )r   r   r*  r   Zmargin_clampr  r&   r&   r'   hingeembeddingloss_referencek  s    r  c                 C   s:   d|  |     }|dkr&| S |dkr6| S |S )Nr/   r  r3  )rs   r   r  r3  )r   r   r   r  r&   r&   r'   softmarginloss_referencev  s    r  c                 C   sj   |d kr|  t| d}d}tdt| D ]6}||kr.||| td|| |  | |  |  7 }q.|S )Nr/   r   )r  r  Zfill_r  r  )r   Z
target_idxr8   r*  r)   r  r7   r&   r&   r'   _multimarginloss_reference  s    ,r  c                 C   s   |   dk r2|   dkr"| dn| dd} |  }|  dkrP|d}| d}| d}| |}	td|D ] }
t| |
 ||
 ||||	|
< qx|dkr|	 | S |dkr|	 | S |dkr|	d| S |	| S r  )	r1  rE  rV   r  r  r  r  r3  rQ  )r   r   r8   r*  r)   r   Z
target_dimr  r1  r  xr&   r&   r'   multimarginloss_reference  s"    &



r  c                 C   sZ   dd }t |dkd|| | || || jdd}|dkrF| S |dkrV| S |S )Nc                 S   sv   |  | d}td| dD ]P}| | ||   | | | |   d || ||   d  d  ||< q |S )Nr   g-q=r  )r  rV   r  r3  )abcosr7   r&   r&   r'   _cos  s    Nz+cosineembeddingloss_reference.<locals>._cosr/   r   r  r  r3  )r1   r  r  r  r3  )input1input2r   r*  r   r  r  r&   r&   r'   cosineembeddingloss_reference  s    .r  ư>c                 C   sz   t | |||}t | |||}	|r@t ||||}
t |	|
}	t j|| |	 dd}|dkrf| S |dkrv| S |S )Nr  r  r  r3  )r1   Zpairwise_distancer  r  r  r3  )anchorZpositivenegativer*  r8   epsZswapr   Zd_pZd_nZd_sr  r&   r&   r'   tripletmarginloss_reference  s    r  c                 C   s>   | | |  | j dd}|dkr*| S |dkr:| S |S )Nr   r  r  r3  )r  r  r3  )r  r  r   r*  r   r  r&   r&   r'   marginrankingloss_reference  s    r  c                 C   s  t j|t jd}t j|t jd}| j}|  } | }|d}g }t| dD ]v}	||	  }
||	  }||	  }|	d| d f|}|
 dkr||	d |f |dd d< n||| | |dd d< | d |
|	f  }| |d d f}|d|f |d< |d|d f |d< |d d |dd  k}td|
D ]f}| }|dd   |d d 7  < |dd   t ||d d |d7  < |||f | }qB||dd    d    qRt |d}|dkr||j|j|jd  }n|d	kr| }||}|S )
Nr   r   r/   rG   r  r0   r  r   devicer3  )r1   Z	as_tensorr   r   rU   Zcumsumr  rV   r  Znew_fullr1  rs   Z	new_zeroscloner  r  r3  r   catr   r  r  )Z	log_probsr  Zinput_lengthsZtarget_lengthsblankr   dtZcum_target_lengthsr  r7   Zinput_lengthZtarget_lengthZcum_target_lengthZtargets_primeZprobsalphaZ
mask_thirdr3   Z
alpha_nextr  r&   r&   r'   ctcloss_reference  sB    
,"


r
  )r   )r   r   r   r   r  r  r  r#  r,  r:  CosineEmbeddingLossTripletMarginLossMarginRankingLossZCTCLossZCrossEntropyLossr   r   c                     s`   | d }dd  fdd  fdd| dd D }|| }t |}|d	kr\|d
S |S )zReference function for criterion supporting no batch dimensions.

    The criterion is passed the input and target in batched form with a single item.
    The output is squeezed to compare with the no-batch input.
    r0   c                 S   s&   t | ttfrdd | D S | dS )Nc                 S   s   g | ]}| d qS rC  rD  rF  r&   r&   r'   rH    s     zNsingle_batch_reference_criterion_fn.<locals>.unsqueeze_inp.<locals>.<listcomp>r   rI  rM  r&   r&   r'   rO    s    z:single_batch_reference_criterion_fn.<locals>.unsqueeze_inpc                    s:   g }t | ttfr,| D ]}| | qn
||  |S rY   )rJ  rK  rL  extendr  )Zxsr%   r  )rN   r&   r'   rN     s    
z4single_batch_reference_criterion_fn.<locals>.flattenc                    s   g | ]} |qS r&   r&   )rG  r   )rO  r&   r'   rH    s     z7single_batch_reference_criterion_fn.<locals>.<listcomp>Nrh   r   )r(   rQ  )r[   	criterionZsingle_batch_input_argsr  r   r&   )rN   rO  r'   #single_batch_reference_criterion_fn  s    	
r  ZL1LossZMSELossZPoissonNLLLossr  r  rh   Z_no_batch_dim_)namec                 G   s   t t| tdS Nri   r    rd   r   r  r[   r&   r&   r'   r:   0  r;   )ru   rv   r@   Ztarget_sizerA   r  rD   ZKLDivLoss_no_batch_dim_c                   C   s   t jtdS r  )rd   r   r   r&   r&   r&   r'   r:   =  r;   c                   C   s   t d S Nrd  r   r&   r&   r&   r'   r:   >  r;   c                   C   s
   t dS r  ro   r&   r&   r&   r'   r:   ?  r;   )ru   rv   rx   	target_fnrA   r  rD   BCELossc                   C   s   t t jdt jdS Nr  r   )r1   r   r|   rU   r&   r&   r&   r'   r:   L  r;   c                   C   s   t jdt jddt jS )Nr  r   r   )r1   r|   rU   r   r   r&   r&   r&   r'   r:   M  r;   BCEWithLogitsLossc                   C   s   t jdt jdS r  r1   r|   rU   r&   r&   r&   r'   r:   O  r;   r#  c                   C   s   t jdt jdS r  r  r&   r&   r&   r'   r:   P  r;   c                   C   s   t dddgd S Nr0   r/   rL   r1   tensorr&   r&   r&   r'   r:   P  r;   r  c                   C   s   t jdt jdS )Nr.   r   r  r&   r&   r&   r'   r:   Q  r;   c                   C   s   t ddddgS )NrL   r   r0   r/   r  r&   r&   r&   r'   r:   Q  r;   r,  c                   C   s   t jdt jdS r  r  r&   r&   r&   r'   r:   R  r;   c                   C   s   t dddgd S r  r  r&   r&   r&   r'   r:   R  r;   r   c                   C   s   t jtjdtjdddS )NrL   r   r   r0  )rj   r  r1   r|   rU   r&   r&   r&   r'   r:   S  r;   c                   C   s
   t dS rM   r  r&   r&   r&   r'   r:   S  r;   r  c                   C   s    t jdt jdt jdt jdfS r  r  r&   r&   r&   r'   r:   V  r;   c                   C   s   t jdt jdS )Nr/   r   )r1   r  rU   r&   r&   r&   r'   r:   W  r;   r  c                   C   s   t dt dfS r   r   r&   r&   r&   r'   r:   Z  r;   c                   C   s   t d S r   )r1   r|   signr&   r&   r&   r'   r:   Z  r;   r  c                   C   s    t jdt jdt jdt jdfS r  r  r&   r&   r&   r'   r:   ^  r;   c                   C   s   t jdt jdS r  r  r&   r&   r&   r'   r:   _  r;   ZMultiLabelSoftMarginLossc                   C   s   t jdt jdS r  r  r&   r&   r&   r'   r:   a  r;   c                   C   s
   t dS )Nr  r   r&   r&   r&   r'   r:   a  r;   rR   ,classification_criterion_no_batch_extra_info)r  r  r#  r   r,  c                 G   s   t t| tdS r  r  r  r&   r&   r'   r:   s  r;   c                 C   s   |  S rY   r&   fr&   r&   r'   r:   t  r;   c                 C   s   |  S rY   r&   r   r&   r&   r'   r:   u  r;   )ru   rv   rx   r  rA   r  Z
has_parityc                	   @   s   e Zd Zedd Zeejeeej	 eej	 f dddZ
eejddddZedejeejeejeej f ed
ddZdd Zdd Zdd ZdedddZdedddZdedddZdS )
NNTestCasec                 O   s   t d S rY   NotImplementedErrorrZ   r[   r^   r&   r&   r'   _forward  s    zNNTestCase._forward)rS  returnc                 C   s   t d S rY   r#  rZ   rS  r&   r&   r'   _get_parameters  s    zNNTestCase._get_parametersNc                 C   s   t d S rY   r#  r(  r&   r&   r'   _zero_grad_parameters  s    z NNTestCase._zero_grad_parametersF)rS  r   r  grad_outputcreate_graphc                 C   s   t d S rY   r#  )rZ   rS  r   r  r+  r,  r&   r&   r'   	_backward  s    zNNTestCase._backwardc                    sT   t |tr"t fdd|D S t |tr@ fdd|D S t|  S d S )Nc                 3   s   | ]} | V  qd S rY   	_jacobianrG  elemnum_outrZ   r&   r'   r    s     z'NNTestCase._jacobian.<locals>.<genexpr>c                    s   g | ]} | qS r&   r.  r0  r2  r&   r'   rH    s     z(NNTestCase._jacobian.<locals>.<listcomp>)rJ  rL  rK  r1   r  nelement)rZ   r   r3  r&   r2  r'   r/    s
    

zNNTestCase._jacobianc                    sF   t |tjr,|jr | dS |dS nt fdd|D S d S )Nr0   c                 3   s   | ]}  |V  qd S rY   )_flatten_tensors)rG  r  rZ   r&   r'   r    s     z.NNTestCase._flatten_tensors.<locals>.<genexpr>)rJ  r1   rP  Z	is_sparseZto_denser4   rL  )rZ   r  r&   r6  r'   r5    s
    zNNTestCase._flatten_tensorsc                 C   sJ   t |tjr2|jrF|jd k	rF|j  |j  n|D ]}| | q6d S rY   )rJ  r1   rP  rz  gradr  Zdetach__zero_grad_input)rZ   r   r7   r&   r&   r'   r8    s    
zNNTestCase._zero_grad_inputTr   c                 C   sX  |  ||}| }|r0| ||}tt|}|r\tdd | |d D }	t|	|}
t	|D ]}| |\}}dd t
||D }t|}|d}d||< |r| | |r| | | ||||}|r
t
|t|D ]"\}}| d|d d |f< q|rdt| |d|
d d |f< qdt }|rD||f7 }|rT||
f7 }|S )Nc                 s   s   | ]}|  V  qd S rY   )numelr  r&   r&   r'   r    s     z2NNTestCase._analytical_jacobian.<locals>.<genexpr>r   c                 S   s&   g | ]\}}|d krt |n|qS rY   )r1   
zeros_like)rG  r8   r  r&   r&   r'   rH    s     z3NNTestCase._analytical_jacobian.<locals>.<listcomp>r0   r/   )r&  r4  r/  rK  r   r3  r)  r1   r  r  r  r;  r4   r*  r8  r-  
contiguousr  r5  rL  )rZ   rS  r   jacobian_inputjacobian_parametersr  r  Zjacobian_inpZflat_jacobian_inputZ	num_paramZjacobian_paramr7   paramd_paramZd_outZ
flat_d_outd_inputZ
jacobian_xZd_xresr&   r&   r'   _analytical_jacobian  s<    



 

zNNTestCase._analytical_jacobianc                    s    fdd}t  }|r,|t||ddf7 }|r \}}g }	|D ]&}
t|||
dd}|	|d d  qF|t|	df7 }|S )Nc                     s     |  S rY   )r&  detachr9  rS  rZ   r&   r'   fw  s    z*NNTestCase._numerical_jacobian.<locals>.fwr  )r   )r   r   r   )rL  r   r)  r  r1   r  )rZ   rS  r   r=  r>  rF  rB  r?  r9   Zto_catr8   Zjacobianr&   rE  r'   _numerical_jacobian  s    zNNTestCase._numerical_jacobianc                 C   s   t | |d }| ||||}| ||||}tt|}tt|}g }	t||D ]0\}
}|
 dkrX|	|
j	|dd
   qXt|	dkr| t|	t d S )Nr   r0   )r	  )boolr)  rC  rG  rK  r   r  r:  r  r   r   r  r  ZassertLessEqual	PRECISION)rZ   rS  r   r=  r>  Z
analyticalZ	numericalZanalytical_tZnumerical_tZdifferencesr  r  r&   r&   r'   check_jacobian  s    zNNTestCase.check_jacobian)F)TT)TT)T)r`   ra   rb   r   r&  rd   re   r   r   	Parameterr)  r*  r   r1   rP  r   r   rH  r-  r/  r5  r8  rC  rG  rJ  r&   r&   r&   r'   r"    s(   
(  		*r"  c                   @   sb   e Zd ZdddhZdddZdd	 Zd
d Zedd Zedd Z	dd Z
dddZdd ZdS )TestBaser>   r   
extra_args Nc                 K   s   || _ || _|| _|| _| jD ]P}||kr|d |kr|d |kr|dkrVt ||< qt|   d| dq|| _i | _	d S )N_fn_size>   r>   rM  z
: Specify z5 by a value, a function to generate it, or it's size!)
rE   ru   rv   rA   _required_arg_namesrL  
ValueErrorget_name_extra_kwargs
_arg_cache)rZ   rv   rE   rA   ru   r^   r  r&   r&   r'   __init__  s    
 zTestBase.__init__c                 C   s8   | j d k	rd| j  S d| jj }| jr4|d| j 7 }|S )NZtest_r9   )ru   rv   r`   rE   )rZ   Z	test_namer&   r&   r'   rS    s    

zTestBase.get_namec                    s:   t |tjr|S t|r2t| fdd|D S |S d S )Nc                 3   s   | ]}  |V  qd S rY   )_unpack)rG  vr6  r&   r'   r    s     z#TestBase._unpack.<locals>.<genexpr>)rJ  r1   rP  r   type)rZ   valuer&   r6  r'   rW    s
    zTestBase._unpackc                 C   s   |  ddS )Nr>   T_get_argr6  r&   r&   r'   r>     s    zTestBase.constructor_argsc                 C   s   |  ddS )NrM  Tr[  r6  r&   r&   r'   rM  !  s    zTestBase.extra_argsc              
      s   || j kst|| jkr|d }|d }|| jkrD| j| | j|< nl|| jkrb| j|  | j|< nN|| jkstd| d| d| d|    fdd  | j| | j|< |r| | j| S | j| S )	NrO  rP  z	Missing `z`, `z` or `z` for c                    s>   t | tr fdd| D S t | tjr0|  S t| S d S )Nc                    s   g | ]} |qS r&   r&   )rG  smap_tensor_sizesr&   r'   rH  6  s     z?TestBase._get_arg.<locals>.map_tensor_sizes.<locals>.<listcomp>)rJ  rK  r1   rP  rU   r|   )sizesr^  r&   r'   r_  4  s
    
z+TestBase._get_arg.<locals>.map_tensor_sizes)rQ  r"   rU  rT  rS  rW  )rZ   r  unpackfn_nameZ	size_namer&   r^  r'   r\  %  s    


zTestBase._get_argTc                 C   s   |  d|S )Nr   r[  )rZ   ra  r&   r&   r'   
_get_input@  s    zTestBase._get_inputc                 C   s   t d S rY   r#  )rZ   	test_caser&   r&   r'   __call__C  s    zTestBase.__call__)rN  NN)T)r`   ra   rb   rQ  rV  rS  rW  propertyr>   rM  r\  rc  re  r&   r&   r&   r'   rL    s   

	


rL  c                       sV   e Zd ZeeejeedddZ fddZdd Z	dd	 Z
d
d Zdd Z  ZS )
ModuleTest)rd  rS  r   r'  c                 C   s   t d S rY   r#  )rZ   rd  rS  r   r&   r&   r'   _do_testI  s    zModuleTest._do_testc                    s   t  j|| |dd| _|dd| _|dd| _|dd| _|dd| _|dd	| _|d
d| _	|dd | _
| j
d krt | _
d S )Nr=  TrH   rz   rR   !FIXME_no_cuda_gradgrad_comparisonFr   g-C6*?check_forward_onlyrD   )superrV  getr=  should_test_cudashould_test_picklerR   ri  r   rj  rD   r1   get_default_dtyper%  	__class__r&   r'   rV  M  s    

zModuleTest.__init__c           
   
   C   s  t | j | j| j }|  }| jd k	rn|||}t|}t|}| |||d |}|j	||dd | j
rW 5 Q R  d S | ||| | jrt N}||| t|| |d t|}	|	|||||	| W 5 Q R X | ||| W 5 Q R X d S )Nr   F)exact_dtype)r   rD   rv   r>   rc  rA   r&  r   r)  assertEqualrj  test_noncontigrn  tempfileTemporaryFiler1   saveseekloadrh  )
rZ   rd  rS  r   outZ	ref_inputZ
ref_moduleexpected_outr!  Zmodule_copyr&   r&   r'   re  [  s(    



&zModuleTest.__call__c                    s   t |tr fdd|D S t |tr<t fdd|D S |}| }|}t|D ]}||dkrT|d } qtqTtt||g|	|d
 }| dks| dks| rt|j|_|S )Nc                    s   g | ]}  |qS r&   noncontiguizerG  or6  r&   r'   rH  w  s     z,ModuleTest.noncontiguize.<locals>.<listcomp>c                 3   s   | ]}  |V  qd S rY   r|  r~  r6  r&   r'   r  y  s     z+ModuleTest.noncontiguize.<locals>.<genexpr>r/   r   )rJ  rK  rL  r1  r  rV   r1   r  Z
empty_likeselectrD  r:  Zis_contiguousr"   rz  )rZ   objr  ndimr1  r  Z	noncontigr&   r6  r'   r}  u  s    

"$zModuleTest.noncontiguizec              
   C   s  t |tjr| dkrd S tdd |D r2d S || || t d |||}t	|ddrn|d }|
|j }| }t|||||}t||d }W 5 Q R X | |}| |}	tddd	D ]\}
}|
r|n|}t|r|n|	}|| || t l |||}t	|ddr8|d }|||||}||| |j||d
dd |||d | W 5 Q R X qd S )Nr   c                 s   s&   | ]}t |tjr| d kV  qdS )r   N)rJ  r1   rP  r1  )rG  r7   r&   r&   r'   r    s      z,ModuleTest.test_noncontig.<locals>.<genexpr>return_indicesFr/   )TFrG   )r  g-C6?atolrtol)rJ  r1   rP  r1  anyr*  r8  r   r&  r    r  r  normal_r  r   r-  r)  r}  r   rs  )rZ   rd  rS  r   r  r+  rA  r@  Znc_inputZnc_grad_outputZcontig_iZcontig_gr7   gorz  r7  r&   r&   r'   rt    s:    





zModuleTest.test_noncontigc              	   C   s  t r
| jstdt| j |  }tjtj	i}t
|trD|n|f}tdd |D }t||d}| j| j }| j| j 	  }||}	||}
t|	d |
d D ]\}}|j| q|| || || || |||}|||}t|ddr"|d }|d }|j||| jddd td	D ]}|  }||}|||||}|||||}|j||| jddd t|	d
 |
d
 D ]\}}|j||| jdd qq@| j r| j!s|| }|| }t|ddr|d }|d }tj"|dd}||# }d|_$tj%j&||t|'  |dd}tj%j&||t|'  |dd}t||D ] \}}|j||| jddd qp|r|( ) t(dd |D  }|( ) t(dd |D  }n4|( t(dd |D  }|( t(dd |D  }tj%j&|||f t|'  dd}tj%j&|||f t|'  dd}|j||| jddd t||D ] \}}|j||| jddd qr| *||| W 5 Q R X d S )NExcluded from CUDA testsc                 s   s"   | ]}t |tjo|jjV  qd S rY   )rJ  r1   rP  r   
is_complexrF  r&   r&   r'   r    s     z'ModuleTest.test_cuda.<locals>.<genexpr>)type_mapr   r  Fr  r  rr  rJ   r/   r  T)rz  )r,  c                 s   s   | ]}|   V  qd S rY   r3  r   rG  r  r&   r&   r'   r    s     c                 s   s   | ]}|   V  qd S rY   r  r  r&   r&   r'   r    s     c                 s   s   | ]}|  V  qd S rY   r3  r  r&   r&   r'   r    s     c                 s   s   | ]}|  V  qd S rY   r  r  r&   r&   r'   r    s     )Zretain_graph)+r   rm  unittestSkipTestr   rD   rc  r1   rU   floatrJ  rL  r  r
   rv   r>   cudar)  r  r   Zcopy_r8  r*  r&  r    rs  r   r  r  r  rk   r-  rR   ri  
randn_likerD  rz  Zautogradr7  rR  r3  r   rt  )rZ   rd  	cpu_inputr  Zcpu_input_tupleZis_any_input_complexZgpu_input_tuple
cpu_module
gpu_moduleZ	cpu_paramZ	gpu_paramZcpu_pZgpu_p
cpu_output
gpu_outputr9   Zcpu_gradOutputZgpu_gradOutputcpu_gradInputgpu_gradInputZcpu_d_pZgpu_d_pZcpu_gradInputsZgpu_gradInputsZcpu_d_iZgpu_d_iZoutputs_cpuZoutputs_gpuZcpu_ggZgpu_ggr&   r&   r'   rH     s    








 zModuleTest.test_cuda)r`   ra   rb   r   r   rd   re   rh  rV  re  r}  rt  rH   __classcell__r&   r&   rp  r'   rg  G  s   %rg  c                   @   s   e Zd Zdd ZdS )InputVariableMixinc                    s    t | d} fdd  |S )NFc                    sD   t | tjr&|  s|  r"d| _| S t|  fdd| D S d S )NTc                 3   s   | ]} |V  qd S rY   r&   r0  map_variablesr&   r'   r    s     zGInputVariableMixin._get_input.<locals>.map_variables.<locals>.<genexpr>)rJ  r1   rP  is_floating_pointr  rz  rY  rl   r  r&   r'   r    s
    z4InputVariableMixin._get_input.<locals>.map_variables)rL  rc  )rZ   r   r&   r  r'   rc    s    zInputVariableMixin._get_inputN)r`   ra   rb   rc  r&   r&   r&   r'   r    s   r  c                       s@   e Zd Z fddZdd Zdd Zdd Zed	d
 Z  Z	S )NewModuleTestc                    s   t  j|| |dd| _|dd| _|dd| _|dd| _|dd| _|dd| _|d	d
| _	|dd| _
|dd| _|dd| _|dd | _|dd| _|dd| _d S )NrV  Fcheck_inplacerR   Tskip_doubler}  rB   rC   rP   test_cpur  rS   r  r   supports_fwgrad_bwgrad)rk  rV  rl  rV  r  rR   r  r}  rB   rC   r  r  rS   r  r   r  r%  rp  r&   r'   rV    s    zNewModuleTest.__init__c              	      s   t dd   D }t| fdd}| jrbdks@tt|d } |d | n"t	||| | j
| j| jd | jrt||| | j
| j| jd d S )	Nc                 s   s   | ]
}|V  qd S rY   r&   r  r&   r&   r'   r  .  s     z1NewModuleTest._check_gradients.<locals>.<genexpr>c                     s   |rt  | d  S rY   )r"   r&  )Zinputs_and_paramsr^   rS  Z
num_inputsrd  r&   r'   fn_to_gradcheck1  s    z7NewModuleTest._check_gradients.<locals>.fn_to_gradcheckr/   r   )rS   	fast_modeZcheck_forward_ad)rS   r  Zcheck_fwd_over_rev)rL  rR  r  r  r"   r1   r  rJ  Z
assertTruer   rS   r  r   rR   r   r  )rZ   rd  rS  input_tupleparamsr  Ztest_input_jacobianr&   r  r'   _check_gradients-  s$    zNewModuleTest._check_gradientsc              	      sr  t  }t d t|tr |n|f}|  |    | jrft|dksTt	|d }| j
| jddi}|j}t   |}W 5 Q R X |j| t|}	|	 }
t  ||
}W 5 Q R X |
j| || |j  }|jd k	rt   |j  W 5 Q R X |	jd k	rBt   |	j  W 5 Q R X || || |j|	j d fdd	}tdd |D r"tr"td	d |D }     |  |t jjd t j dkrdtd
d |D } d t jd  |  W 5 Q R X |t jjd nBdd fddfddfddtfdd|D }    |  |t j tfdd|D }    |  |t j  trd| j!rdtfdd|D }     |  |t jjd tdd |D } "   |  |t j tdd |D }    |  |t jjd | j#rt j$j#j%dd  |  |t jjd W 5 Q R X t j dkrtdd |D } d t jd  |  W 5 Q R X |t jjd | j&s$tfdd|D }     |  |t jj d | j'sdtfdd|D } (    |  |t jj)d t | d S )Nr/   r   ZinplaceTc                    s6      D ](}||  |d k	r| | qd S rY   )rR  ZassertIsInstancers  Z
get_device)Ztensor_typeZ	device_idr8   )rS  rd  r&   r'   assert_module_parameters_aret  s    z<NewModuleTest._do_test.<locals>.assert_module_parameters_arec                 s   s   | ]}t |tjV  qd S rY   )rJ  r1   Z
LongTensorrF  r&   r&   r'   r  z  s     z)NewModuleTest._do_test.<locals>.<genexpr>c                 s   s   | ]}|  V  qd S rY   r  rF  r&   r&   r'   r  }  s     c                 s   s   | ]}| d V  qdS r  r  rF  r&   r&   r'   r    s     c                 S   s,   |   r| |S |  r$| |S | S d S rY   )r  r   r  )r  realr|  r&   r&   r'   to_type  s
    

z'NewModuleTest._do_test.<locals>.to_typec                    s    | t jd S rY   )r1   Zfloat16r  r  r&   r'   to_half  s    z'NewModuleTest._do_test.<locals>.to_halfc                    s    | t jt jS rY   )r1   float32Z	complex64r  r  r&   r'   	to_single  s    z)NewModuleTest._do_test.<locals>.to_singlec                    s    | t jt jS rY   )r1   float64r{  r  r  r&   r'   	to_double  s    z)NewModuleTest._do_test.<locals>.to_doublec                 3   s   | ]} |V  qd S rY   r&   rF  r  r&   r'   r    s     c                 3   s   | ]} |V  qd S rY   r&   rF  r  r&   r'   r    s     c                 3   s   | ]} |  V  qd S rY   r  rF  r  r&   r'   r    s     c                 s   s   | ]}|  V  qd S rY   )cpurF  r&   r&   r'   r    s     c                 s   s   | ]}|  V  qd S rY   r  rF  r&   r&   r'   r    s     F)ZenabledrG   c                 s   s   | ]}| d V  qdS r  r  rF  r&   r&   r'   r    s     c                 3   s   | ]} |  V  qd S rY   r  rF  r  r&   r'   r    s     c                 3   s   | ]} |  V  qd S rY   r  rF  )r  r&   r'   r    s     )N)*r1   Zget_num_threadsZset_num_threadsrJ  rL  r  __repr__r  r  r"   rv   r>   _versionr   rs  r   r  ZassertNotEqualr   r  r7  Zno_gradr  backwardallr   r  r  ZFloatTensorZdevice_countr  rU   ZDoubleTensorrm  r  rV  backendsflagsr  r}  halfZ
HalfTensor)rZ   rd  rS  r   Znum_threadsr  Z	module_ipZinput_versionr  Zinput_ipZinput_ip_cloneZ	output_ipr7  r  r&   )rS  rd  r  r  r  r  r'   rh  I  s    









zNewModuleTest._do_testc                 C   s   |  ddS Nr   Fr[  r6  r&   r&   r'   _get_target  s    zNewModuleTest._get_targetc                 C   s   |  ddS Nr>   Fr[  r6  r&   r&   r'   r>     s    zNewModuleTest.constructor_args)
r`   ra   rb   rV  r  rh  r  rf  r>   r  r&   r&   rp  r'   r    s    r  c                       s\   e Zd ZejdhZ fddZdd ZdddZd	d
 Z	e
dd Ze
dd Z  ZS )CriterionTestr   c                    s   t  j|| |dd| _|dd| _|dd| _|dd| _|dd| _|dd| _|d	d| _	|d
d| _
|dd| _|dd| _|dd | _| jd krt | _d S )NrH   Trj  FrR   r  check_bfloat16check_complexr  rB   rC   rP   rS   rD   )rk  rV  rl  rm  rj  rR   r  r  r  r  rB   rC   rS   rD   r1   ro  r%  rp  r&   r'   rV    s    
zCriterionTest.__init__c           
   	      s,  t | j | j| j  |  }   t  |  }| jd k	r|j	 ||| j
d}t|t|f| j
  f }| j| }||| | jrW 5 Q R  d S tdd   D }t|ts|f| |f } fdd}	n|| |f } fdd}	t|	|| jd | jrt|	|| jd W 5 Q R X d S )NrM  c                 s   s   | ]
}|V  qd S rY   r&   r  r&   r&   r'   r    s     z)CriterionTest.__call__.<locals>.<genexpr>c                    s
    | |S rY   r&   )r   r   r  rS  r&   r'   apply_fn  s    z(CriterionTest.__call__.<locals>.apply_fnc                    s    | ||S rY   r&   )r  r  r   r  r  r&   r'   r    s    )rS   )r   rD   rv   r>   rc  r  r   r  rA   _forward_criterionrM  r   rs  rj  rL  rR  rJ  r   rS   rR   r   )
rZ   rd  r   r   rz  Zref_argsr{  r  inputsr  r&   r  r'   re    s,    


zCriterionTest.__call__Nc              	      s  d fdd	 t r| js"tdt| jL |  }|  }| j| j	 }| j| j	 } ||d}|
 st| r~ ||}|| || t|}t|}	|  |tjtjhkr|  }|  }| j| j	 }|j||||d}
|j|||	|d}|j|
||tjtjhkrdndd	dd
 |j|||
||d}|j||||	|d}|j|||tjtjhkrjdndd	dd
 W 5 Q R X d S )NFc                    sN   t | tjr"|  j dS t | trFt fdd| D S | S d S )Nr   c                 3   s   | ]} |V  qd S rY   r&   r~  )convert_dtyper   rz  r&   r'   r  #  s     zACriterionTest.test_cuda.<locals>.convert_dtype.<locals>.<genexpr>)rJ  r1   rP  rD  r   requires_grad_rL  )r  r   rz  r  ry  r'   r    s
    
z.CriterionTest.test_cuda.<locals>.convert_dtyper  Tr  rI   g-C6:?r   r  )F)r   rm  r  r  r   rD   rc  r  rv   r>   r  r  rY  r
   r  r1   r  bfloat16r  rs  Z_backward_criterion)rZ   rd  r   rM  r  Z
cpu_targetr  r  Z	gpu_inputZ
gpu_targetr  r  r  r  r&   r  r'   rH     s\    




            zCriterionTest.test_cudac                 C   s   |  ddS r  r[  r6  r&   r&   r'   r  Q  s    zCriterionTest._get_targetc                 C   s   |  ddS r  r[  r6  r&   r&   r'   r>   T  s    zCriterionTest.constructor_argsc                 C   s   |  ddS )NrM  Fr[  r6  r&   r&   r'   rM  X  s    zCriterionTest.extra_args)N)r`   ra   rb   rL  rQ  unionrV  re  rH   r  rf  r>   rM  r  r&   r&   rp  r'   r    s   %
3
r  c                 C   s   t j|t j|dd}|d k	r:t j|t j|d|   }||}t j||d}|| | }	|	   }
| }|	|
}|| | j
||||dd | j
|jj|
jj||dd d S )NT)r   r  rz  r  )r  Fr  )r1   r|   r  rp   r  r  r  r  r  rD  rs  r7  r   )rd  opr  Zinp_dimsprecr  r  Zout1Zgrad_input1Zop_bfp16r  Zgrad_input2Zout2r&   r&   r'   _test_bfloat16_ops]  s    

r  c                 C   s   |s| d ||}|s.t|}|| |rF| | |  |s| D ] }|jrR| |jt	|j qR| |jt	| d S )NT)
r  r1   Z	rand_liker  rs  rV   rR  rz  r7  r;  )rd  rS  rN  Z
check_sizeZ	inferencerz  ZgOr8   r&   r&   r'   _test_module_empty_inputp  s    


r  c                     sJ   G dd dt j G  fdddt j}   }|  }t ||}|||fS )Nc                       s   e Zd Z fddZ  ZS )z _create_basic_net.<locals>.Layerc              	      s:   t    ttdd| _| dtdddd d S )NrL   rJ   Zlayer_dummy_bufr/   rh  )	rk  rV  rd   rK  r1   r   Zlayer_dummy_paramregister_bufferr  r6  rp  r&   r'   rV    s    
z)_create_basic_net.<locals>.Layer.__init__r`   ra   rb   rV  r  r&   r&   rp  r'   Layer  s   r  c                       s   e Zd Z fddZ  ZS )z_create_basic_net.<locals>.Netc              	      sB   t      | _ttdd| _| dt	dddd d S )NrL   rJ   Z	dummy_bufrh  r/   )
rk  rV  l1rd   rK  r1   r   Zdummy_paramr  r  r6  )r  rq  r&   r'   rV    s    
z'_create_basic_net.<locals>.Net.__init__r  r&   r  rp  r'   Net  s   r  )rd   re   Z
Sequential)r  lr  r]  r&   r  r'   _create_basic_net  s    r  )r  F)Nr  r  )Nr  r  )Nr  r  r  )Nr  r  r  )Nr  r  )r  r  )r  r  )r  )r  r  )r  )r/   r/   Nr  )r   r  )r  rG   r  Fr  )r   r  )r   r  )r&   r   N)TF)abcr   ru  r  copyr   	functoolsr   r   	itertoolsr   operatorr   r1   Z
torch.cudaZtorch.nnrd   Ztorch.nn.functionalZ
functionalrj   r   r!   Z$torch.testing._internal.common_utilsr	   r
   r   r   r   r   r   r   Z#torch.testing._internal.common_cudar   r   Ztorch.autograd.gradcheckr   r   Ztorch.autogradr   Ztorch.typesr   Ztorch.backends.cudnntypingr   r   r   r   r   r   r   rv  rI  r(   r+   r}   rU   Zmodule_testsrX   rf   r~   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r  r  r  r  r  r  r  r  r   r!  r'  r+  r-  r5  r6  r<  r=  r>  r@  rA  rB  rT  ZinterpolateZsoftmaxr  r  ZgeluZreluSizeZnew_module_testsr  Zpadding_modeZcpp_padding_moder  rL  r  r`  joinmapr   Zcpp_paddingr@   r  r  Znon_linear_activations_no_batchr  __annotations__Znon_linear_activationZactivation_test_inforl  Z
extra_infoupdater  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r
  r   Zcriterion_testsr  Zregression_criterion_no_batchZ
reductionsr  r   Zregression_test_infoZ!classification_criterion_no_batchr  Zclassification_cpp_parityrx   r  Zclassification_test_infor"  rL  rg  r  r  r  r  r  r  r&   r&   r&   r'   <module>   s   ($>
	<

	










					
		



		
		

					
         

   						
		
			                       
 
                    




















     (  
 
&#    	
	 
	
{M G Iy