U
    Mh                 "   @   s0  U d dl Z d dlZd dlZd dlZd dlmZ d dlmZ d dlm	Z	m
Z
mZmZmZ d dlZd dlmZ d dlmZ d dlmZmZmZmZmZmZmZmZmZmZmZmZmZm Z  d dl!m"Z"m#Z#m$Z$m%Z%m&Z&m'Z' d d	l(m)Z)m*Z* d d
l+m,Z, d dl-m.Z.m/Z/m0Z0m1Z1 d dl2m3Z3 G dd dZ4G dd deZ5G dd dZ6G dd dZ7G dd de.Z8dd Z9dddZ:dd Z;dddZ<dd  Z=dd!d"Z>d#d$ Z?dd%d&Z@d'd( ZAdd)d*ZBd+d, ZCdd-d.ZDd/d0 ZEdd1d2ZFd3d4 ZGdd5d6ZHd7d8 ZIdd9d:ZJd;d< ZKdd=d>ZLd?d@ ZMddAdBZNdCdD ZOddEdFZPdGdH ZQddIdJZRdKdL ZSeeTejUf eTdMdNdOZVdee4 dQdRdSZWe7ee:e;dTe,e0dUdVdWejXdXk oejXdYkdZe,e0d[dVd\e,e0d]dVd^e,e*ejYe)d_d`daidbdce,e0dddVdefdfe7ee<e=dgdhdidjd d dkdldm dndm gfe,e*ejZe)dododpej[e)dododpidVdqfe,e/dVdrdsdm dZe,e0dUdVdWejXdXk oejXdYkdZe,e0d[dVd\e,e0d]dVd^e,e0dddVdefdt	e7ee>dudm gdvdm gdwdm dxdm gdydm dzdm gd{dm gd|dm gd}dm d~dm gfe?dgde,e*ej\e)dddpidVdqddm dZe,e*ejZe)dododpej[e)dododpidVdqe,e*ejYe)dddpiddfe,e/dVdrddm dZe,e0dUdVdWejXdXk oejXdYkdZe,e0ddVd\e,e0d]dVd^e,e0dddVdefde7ee@eAdTe,e/dVdrddm dZe,e0dUdVdWejXdXk oejXdYkdZe,e0ddVde,e0d[dVd\e,e0d]dVd^e,e]ddVde,e0dddVdefdfe7eeBeCdgde,e*ej\e)dddpidVdqddm dZe,e*ejZe)dododpej[e)dododpidVdqe,e*ejYe)dddpiddfe,e/dVdrddm dZe,e0dUdVdWejXdXk oejXdYkdZe,e0ddVd\e,e0d]dVd^e,e0dddVdefde7eeDeEdTe,e0dUdVdWejXdXk otejXdYkdZe,e0ddVd\e,e0d]dVd^e,e*ejYe)dddpidVde,e0dddVdee,e]ddVdfdfe7eeFeGdPdidde,e/dVde,e*ej^e)dd`daidVd^e,e]ddVde,e]ddVde,e]ddVde,e]ddVde,e0dUdVdWejXdXk ojejXdYkdZfde7eeHeIdTe,e/dVdrddm dZe,e0dUdVdWejXdXk oejXdYkdZe,e0ddVd\e,e0d]dVd^e,e0ddVde,e0dddVdefdfe7eeJeKdTe,e0dUdVdWejXdXk oejXdYkdZe,e0ddVd\e,e0d]dVd^e,e*ej\e)dddpidVde,e0dddVdefdfe7eeLeMdTe,e/dVdrddm dZe,e0dUdVdWejXdXk oejXdYkdZe,e0d[dVd\e,e0d]dVd^e,e*ejYe)dddpidVde1dZe,e0dddVdefdfe7eeNeOdTe,e/dVdrddm dZe,e0dUdVdWejXdXk 	o8ejXdYkdZe,e0d[dVd\e,e0d]dVd^e,e0dddVdefdfe7eePddm gddm gddm ddm gddm ddm ddm gddm gddm gddm ddm gfeQdgdiddd dd dddm gfde,e0dUdVdWejXdXk 
oejXdYkdZe,e0ddVd\e,e0d]dVd^e,e*ejYe)dddpidVde1dZe,e0dddVdefd	e7e eReSdPdiddig fde,e/dVe,e]ddVde,e0ddVde,e0ddVdWe,e]ddVde,e0ddVde,e0ddVdre,e0ddVde,e0ddVde,e0ddVde,e0ddVdfdgZ_ee7 e`d< G ddń dŃZadS )    N)deepcopy)Enum)AnyDictListTupleUnion)Tensor)	Parameter)AdadeltaAdagradAdamAdamaxAdamWASGDLBFGSNAdam	OptimizerRAdamRMSpropRpropSGD
SparseAdam)
ConstantLRExponentialLRLinearLRPolynomialLRReduceLROnPlateauStepLR)toltoleranceOverride)DecorateInfo)_TestParametrizer	skipIfMpsskipIfTorchDynamoTEST_WITH_TORCHDYNAMO)&_get_foreach_kernels_supported_devicesc                   @   sX   e Zd ZdZdddgZdeee ee e	e
e
f f e	ee
f edddZd	d
 ZdS )OptimizerInputz@Contains args / kwargs to be passed to an optimizer constructor.paramskwargsdesc r(   r)   r*   c                 C   s   || _ || _|| _d S Nr,   )selfr(   r)   r*    r/   [/var/www/html/venv/lib/python3.8/site-packages/torch/testing/_internal/common_optimizers.py__init__6   s    zOptimizerInput.__init__c                 C   s   d| j  d| j d| j S )Nzparams=z	, kwargs=z, desc=r,   r.   r/   r/   r0   __repr__A   s    zOptimizerInput.__repr__N)r+   )__name__
__module____qualname____doc__	__slots__r   r   r
   r	   r   r   strr1   r3   r/   r/   r/   r0   r'   1   s   
 
r'   c                   @   s   e Zd ZdZdZdZdS )OptimizerErrorEnumz;Enumerates when an error is raised when testing optimizers.r      N)r4   r5   r6   r7   CONSTRUCTION_ERRORZ
STEP_ERRORr/   r/   r/   r0   r:   E   s   r:   c                   @   s0   e Zd ZdZddddgZejedddd	Zd
S )ErrorOptimizerInputz
    An OptimizerInput that will cause the optimizer to throw an error when constructed.
    Includes the type and string of the resulting error.
    optimizer_error_inputerror_on
error_typeerror_regexr+   )r?   r@   rA   c                C   s   || _ || _|| _|| _d S r-   )r>   r?   r@   rA   )r.   r>   r?   r@   rA   r/   r/   r0   r1   T   s    zErrorOptimizerInput.__init__N)	r4   r5   r6   r7   r8   r:   r<   RuntimeErrorr1   r/   r/   r/   r0   r=   L   s   r=   c                   @   sz   e Zd ZdZdd dd gfdddi g fdddddd	d	dd
eee eeeeeeee d	ddZdd Z	e
dd Zd	S )OptimizerInfoz,Optimizer information to be used in testing.c                 C   s   t | dddS N?
   gammaZ	step_sizer   optr/   r/   r0   <lambda>t       zOptimizerInfo.<lambda>c                 C   s   t | S r-   r   rJ   r/   r/   r0   rL   u   rM   foreachdifferentiableFTr/   N)scheduler_inputssupported_implssupports_sparseonly_supports_sparse_gradsmetadata_for_sparsesupports_complexstep_requires_closuresupports_param_groupssupports_multiple_devicesskips
decoratorsoptim_error_inputs_funcsupports_fused_on)		optim_clsrS   rT   rU   rW   rX   rY   rZ   r^   c                C   sl   || _ || _|| _|| _|| _|| _|| _|| _|	| _|
| _	|| _
|rJ|ng |rT|ng | _|| _|| _d S r-   )r_   optim_inputs_funcrR   rS   rT   rV   rU   rW   rX   rY   rZ   r\   r]   r^   )r.   r_   r`   rR   rS   rT   rU   rV   rW   rX   rY   rZ   r[   r\   r]   r^   r/   r/   r0   r1   e   s     +

zOptimizerInfo.__init__c                 C   sH   g }| j D ]8}t|tr8||||||rB||j  q
|| q
|S r-   )r\   
isinstancer!   Z	is_activeextendappend)r.   Z
test_class	test_namedevicedtypeparam_kwargsresult	decoratorr/   r/   r0   get_decorators   s    

    zOptimizerInfo.get_decoratorsc                 C   s   | j jS r-   )r_   r4   r2   r/   r/   r0   name   s    zOptimizerInfo.name)r4   r5   r6   r7   r   r   r9   boolr1   rj   propertyrk   r/   r/   r/   r0   rC   b   s>   =rC   c                   @   s"   e Zd ZdZdddZdd ZdS )optimszGDecorator for specifying a list of optimizers over which to run a test.Nc                 C   s$   t || _|d k	r|ntjg| _d S r-   )listoptim_info_listtorchfloat32dtypes)r.   Zoptim_info_iterablers   r/   r/   r0   r1      s    
zoptims.__init__c                 #   s   |d krt dt| j| jD ]\}}|j}||d}zBt  fdd}t|j	|j
 j
|j|}	||||	fV  W q  tk
r }
 z td| d|j d |
W 5 d }
~
X Y q X q d S )NzThe @optims decorator is only intended to be used in a device-specific context; use it with instantiate_device_type_tests() instead of instantiate_parametrized_tests())
optim_inforf   c                     s
    | |S r-   r/   )argsr)   testr/   r0   test_wrapper   s    z.optims._parametrize_test.<locals>.test_wrapperzFailed to instantiate z for module !)rB   	itertoolsproductrp   rs   rk   	functoolswrapspartialrj   r4   Zdevice_type	Exceptionprint)r.   rw   Zgeneric_clsZ
device_clsrt   rf   rd   rg   rx   Zdecorator_fnexr/   rv   r0   _parametrize_test   s.    
zoptims._parametrize_test)N)r4   r5   r6   r7   r1   r   r/   r/   r/   r0   rn      s   
rn   c                 C   s~   t | dkrvttjd| |d}tt|i ddtddtt||gi ddtd	dttd
|id
|igi ddtddgS g S d S )Ncpur;   re   rf   zinvalid param typer,   zPparams argument given to the optimizer should be an iterable of Tensors or dictsr@   rA   z.a param group cannot have duplicate parametersz/.*a parameter group with duplicate parameters.*r(   z@duplicate parameters should not occur across param groups eitherz7some parameters appear in more than one parameter group)	r9   r
   rq   Zrandnr=   r'   	TypeErrorUserWarning
ValueError)re   rf   Zsample_paramr/   r/   r0   get_error_inputs_for_all_optims   s>    		r   c              	   C   s   t d ddiddt d dddddt d tdddd	dg}t d i d
dt d ddiddt d ddiddt d dddddt d dddddgdt| kr|ng  S )N
capturableTr,   皙?weight_decayr   zcapturable with weight decayMbP?lrr   Tensor lr with capturabledefaultr   {Gz?non-default lrr   nonzero weight_decayr   maximizer   gffffff?rE   )rhor   r   cudar'   rq   tensorr9   re   rf   cuda_supported_configsr/   r/   r0   optim_inputs_func_adadelta$  s@        r   c                 C   s@   t | |}t| dkr<|ttd tdddddtddg7 }|S )	Nr   r   g?)r   r   zrho should be between 0 and 1r,   zInvalid rho value: 1.1r   r   r9   r=   r'   dictr   re   rf   error_inputsr/   r/   r0    optim_error_inputs_func_adadeltaD  s    

r   c                 C   s~   t d i ddt d ddiddt d dddddt d d	did
dt d dddddt d ddddddt d d	tdiddgS )Nr   r,   r   r   r   Tr   r   r   r   )initial_accumulator_valuer   r         ?)r   lr_decayr   r   r   z	Tensor lrr'   rq   r   r   r/   r/   r0   optim_inputs_func_adagradU  s8      
r   c                 C   s@   t | |}t| dkr<|ttd tdddddtddg7 }|S )	Nr   r         )r   r   zlr_decay must be bigger than 0r,   zInvalid lr_decay value: -0.5r   r   r   r/   r/   r0   optim_error_inputs_func_adagrads  s    

r   c              	   C   s   t d ddiddt d ddddddt d tddddd	dg}t d i d
dt d ddiddt d ddiddt d dddddt d dddddgdt| kr|ng  }|tjfkr|D ]}d|jd< q|S )Nr   Tr,   r   )r   amsgradr   zcapturable, amsgradr   )r   r   r   z%Tensor lr with capturable and amsgradr   r   r   r   r   r   r   r   )r   r   r   r   eps)r'   rq   r   r9   float16r)   )re   rf   r   totalinputr/   r/   r0   optim_inputs_func_adam  sH    
    
r   c              
   C   s   t | |}t| dkr~|ttd tdddddtddttd tdd	d
ddtddttd ttdddddtddg7 }dt| krtjd| |d}|tt|gdddddt	ddtt|gdddddt	ddg7 }|S )Nr   r         ?        r   Zbetasbeta1 should be between 0 and 1r,   &Invalid beta parameter at index 0: 1.0r   r   r   weight_decay should > 0Invalid weight_decay value: -1r   T)r   rP   z7lr as Tensor doesn't work with foreach & not capturablezElr as a Tensor is not supported for capturable=False and foreach=Truer   r/   r   )rP   fusedz/`fused` and `foreach` cannot be `True` together)r   rQ   z)`fused` does not support `differentiable`)
r   r9   r=   r'   r   r   rq   r   emptyrB   )re   rf   r   Zsample_tensorr/   r/   r0   optim_error_inputs_func_adam  sl    

	
		r   c                 C   s   t d ddiddt d ddddddt d ddddddt d dd	ddd
dt d tddd	ddddg}t d i ddt d ddiddt d ddiddt d dddddgdt| kr|ng  S )Nr   Tr,   rE   r   r   r   z"capturable, maximize, weight_decayr   capturable, maximizeFcapturable, weight_decayr   r   r   r   r   z#capturable, weight_decay, tensor LRr   r   r   r   r   r   r   r   r   r   r   r/   r/   r0   optim_inputs_func_adamax  sR    


  r   c                 C   s@   t | |}t| dkr<|ttd tdddddtddg7 }|S )	Nr   r   )r   r   r   zbeta2 should be between 0 and 1r,   z&Invalid beta parameter at index 1: 1.0r   r   r   r/   r/   r0   optim_error_inputs_func_adamax  s    

r   c                 C   s
   t | |S r-   )r   r   r/   r/   r0   optim_inputs_func_adamw)  s    r   c                 C   s
   t | |S r-   )r   r   r/   r/   r0   optim_error_inputs_func_adamw-  s    r   c                 C   s   t d ddiddt d dddddt d dddddt d dddd	d
dt d tdddddddg}t d i ddt d ddiddt d ddiddt d ddiddt d ddiddt d ddiddt d dddddgdt| kr|ng  S )Nr   Tr,   )r   r   zmaximize, capturabler   r   weight_decay, capturabler   z"maximize, weight_decay, capturabler   r   z-maximize, weight_decay, capturable, tensor LRr   lambdznon-default lambdr   g{Gz?r   t0d   r   r   r   r   zmaximize, nonzero weight_decayr   r   r   r/   r/   r0   optim_inputs_func_asgd1  sX    
  r   c                 C   s@   t | |}t| dkr<|ttd tdddddtddg7 }|S )	Nr   r   r   r   r   r,   z Invalid weight_decay value: -0.5r   r   r   r/   r/   r0   optim_error_inputs_func_asgd_  s    

r   c                 C   s@   t d i ddt d ddiddt d ddiddt d dd	id	dgS )
Nr   r,   r   r   r   Ztolerance_gradư>Zline_search_fnZstrong_wolfer'   r   r/   r/   r0   optim_inputs_func_lbfgsp  s      r   c                 C   s   t | |}|S r-   )r   r   r/   r/   r0   optim_error_inputs_func_lbfgs  s    
r   c                 C   s   t d ddiddt d ddddddt d dddddd	dt d td
dddddd	dg}t d i ddt d dd
iddt d ddiddt d dddddt d ddddddt d dddddgdt| kr|ng  S )Nr   Tr,   rE   g~jtx?)r   momentum_decayr   r   )r   r   decoupled_weight_decayr   z"decoupled_weight_decay, capturabler   )r   r   r   r   r   r   r   r   r   znon-zero momentum_decayr   )r   r   r   )r   r   r   r   r   r   r   r   r   r/   r/   r0   optim_inputs_func_nadam  sl    

	r   c              	   C   s^   t | |}t| dkrZ|ttd tdddddtddttd tdd	d
ddtddg7 }|S )Nr   r   r   r   r   r,   r   r   gɿ)r   r   zmomentum_decay should > 0z"Invalid momentum_decay value: -0.2r   r   r/   r/   r0   optim_error_inputs_func_nadam  s.    

	
r   c              
   C   s   t d ddiddt d dddddt d ddddddt d td	dddd
ddg}t d i ddt d ddiddt d ddiddt d ddiddt d dddddt d dddddgdt| kr|ng  S )Nr   Tr,   r   )r   r   r   )r   r   r   z0capturable, weight_decay, decoupled_weight_decayr   )r   r   r   r   z;capturable, weight_decay, decoupled_weight_decay, tensor LRr   r   gMb`?r   r   r   znon-default epsr   r   )r   r   r   r   r   r   r   r   r/   r/   r0   optim_inputs_func_radam  s^    	  r   c              	   C   s^   t | |}t| dkrZ|ttd tdddddtddttd tdd	d
ddtddg7 }|S )Nr   r   r   r   r   r,   r   r   r   r   r   r   r   r   r/   r/   r0   optim_error_inputs_func_radam  s.    

	
r   c                 C   s   t d ddiddt d ddddddt d tdddd	dg}t d i d
dt d ddiddt d ddiddt d dddddt d ddddddt d dddddddgdt| kr|ng  S )Nr   Tr,   r   r   r   r   r   r   r   r   r   r   r   )r   centeredr   )r   r   momentumr   )r   r   r   r   r   r   r   r   r/   r/   r0   optim_inputs_func_rmsprop(  sR    
  
r   c                 C   s@   t | |}t| dkr<|ttd tdddddtddg7 }|S )	Nr   r   g      r   r   "momentum should be between 0 and 1r,   zInvalid momentum value: -1.0r   r   r   r/   r/   r0   optim_error_inputs_func_rmspropT  s    

r   c              	   C   s   t d ddiddt d tdddddg}t d i ddt d dd	id
dt d ddiddt d ddiddt d ddiddgdt| kr|ng  S )Nr   Tr,   r   r   r   r   r   g-C6*?r   etas)r   g      ?znon-default etasZ
step_sizes)g>r   znon-default step_sizesr   r   r   r   r/   r/   r0   optim_inputs_func_rprope  s.    
  r   c                 C   s@   t | |}t| dkr<|ttd tdddddtddg7 }|S )	Nr   r   )r   r   )r   r   z0 < eta1 < 1 < eta2r,   zInvalid eta values: 1.0, 0.5r   r   r   r/   r/   r0   optim_error_inputs_func_rprop~  s    

r   c                 C   s   t d i ddt d ddiddt d dtdiddt d dd	iddt d d	d
dddt d d	ddddt d d	dddddt d dddddgS )Nr   r,   r   r   r   r   z	tensor lrr   rE   r   )r   	dampeningr   r   )r   r   znon-zero weight_decayT)r   nesterovr   r   r   r   r   r   r/   r/   r0   optim_inputs_func_sgd  s:      
r   c                 C   s@   t | |}t| dkr<|ttd tdddddtddg7 }|S )	Nr   r   r   r   r   r,   zInvalid momentum value: -0.5r   r   r   r/   r/   r0   optim_error_inputs_func_sgd  s    

r   c                 C   s0   t d i ddt d ddiddt d ddiddgS )Nr   r,   r   r   r   r   Tr   r   r/   r/   r0   optim_inputs_func_sparseadam  s      r   c                 C   s   t | |}t| dkr|ttd tdddddtddtttjd	tj| |d
gi ddtddttdtjd	tj| |d
gigi ddtddtttj	dd	| tj
dgt ddtddg7 }|S )Nr   r   r   r   r   r,   r   r      )Zlayoutre   rf   zdense params requiredz+SparseAdam requires dense parameter tensorsr(   z%dense params required in param_groups   r   zcomplex not supportedz.SparseAdam does not support complex parameters)r   r9   r=   r'   r   r   rq   ZzerosZ
sparse_cooZrand	complex64r   r/   r/   r0   "optim_error_inputs_func_sparseadam  sp    

	   	5r   )re   returnc                 C   s2   t | tjrt| j} t | ts$t| dd S )N:r   )ra   rq   re   r9   typeAssertionErrorsplit)re   r/   r/   r0   _get_device_type  s    
r   r/   )r   c              
      s   t dd D std }t fddjD }g }|D ]}t|j}t|dkr|D ]}	d||	< qf|t	d||j
d n
|| |D ]4}	t|}
d	|
|	< |t	d|
|j
 d
|	 d qqH|S )a~  
    Return a list of all configs for a given optimizer as a list of OptimizerInputs,
    including configs that have supported global cliquey kwargs (foreach, fused,
    differentiable) based on optim_info.supported_impls.

    The configs (optim_inputs) returned by optim_info.optim_inputs_func(...)
    intentionally do NOT include global cliquey kwargs to give flexibility to tests.
    For example, testing correctness between toggling foreach on and off is now
    trivial. That said, we sometimes want to test for all possible configs on an
    optimizer including all supported flags, so this helper returns all optim inputs.
    c                 s   s   | ]}|d kV  qdS ))rP   r   rQ   Nr/   .0xr/   r/   r0   	<genexpr>  s    zD_get_optim_inputs_including_global_cliquey_kwargs.<locals>.<genexpr>z?skip must be a subset of ['foreach', 'fused', 'differentiable']c                 3   sF   | ]>}|krt  jks$|d krt  t ks:|dkr|V  qdS )r   rP   N)r   r^   r&   r   re   rt   skipr/   r0   r   !  s    r   FNr,   Tz & )allr   r`   tuplerS   r   r)   lenrc   r'   r*   )re   rf   rt   r   Zoptim_inputsrS   Zall_optim_inputsZoptim_inputZbase_kwargsflagZ
new_kwargsr/   r   r0   1_get_optim_inputs_including_global_cliquey_kwargs  s<    




  r   rO   z,Fails fix point assertion on 3.8, see #97811ZTestOptimRenewedZtest_tensor_lr)r   	   )r      )Z	active_ifzSee #116028Z)test_set_default_dtype_works_with_foreachzPAccessing grad.real errors, see https://github.com/pytorch/pytorch/issues/117184Ztest_complex_2dg/nB?g-C6
?)rtolatolZCompiledOptimizerParityTestsZtest_correctnessz3This test uses mocks, which dynamo does not supportZ test_defaults_changed_to_foreach)r`   r]   rS   r[   )rP   rQ   r   )r   Tr   )r   r   r   c                 C   s   t | dddS )NwJ?i  rG   rI   rJ   r/   r/   r0   rL     rM   rL   c                 C   s   t | ddS )Ng-C6?)	thresholdrN   rJ   r/   r/   r0   rL     rM   g{Gzt?)r   r   Ztest_fused_matches_forloopZ!test_forloop_goes_right_directionc                 C   s
   | d  S N
contiguousr/   r)   r/   r/   r0   rL     rM   )r`   r]   rS   r^   rT   rV   r\   r[   c                 C   s   t | ddS NrE   rH   r   rJ   r/   r/   r0   rL     rM   c                 C   s   t | dddS )N皙?   )start_factortotal_itersr   rJ   r/   r/   r0   rL     rM   c                 C   s   t | dddS Nr  r  )factorr  r   rJ   r/   r/   r0   rL     rM   c                 C   s   t | ddS r  r  rJ   r/   r/   r0   rL     rM   c                 C   s   t | ddS r  r  rJ   r/   r/   r0   rL     rM   c                 C   s   t | S r-   rN   rJ   r/   r/   r0   rL     rM   c                 C   s   t | dddS r
  r  rJ   r/   r/   r0   rL     rM   c                 C   s   t | dddS NrE   r  )powerr  r   rJ   r/   r/   r0   rL     rM   c                 C   s   t | dddS rD   rI   rJ   r/   r/   r0   rL     rM   c                 C   s   t | S r-   rN   rJ   r/   r/   r0   rL     rM   )r   r   g1E2>gavt>c                 C   s   t o| d tjkS Nrf   r%   rq   float64r  r/   r/   r0   rL     s   giUMu>gkNuϵ>ZTestCudaOptimsZ+test_grad_scaling_autocast_fused_optimizersc                 C   s
   | d  S r   r/   r  r/   r/   r0   rL     rM   zTErrors w/ Global state changed, see https://github.com/pytorch/pytorch/issues/116028)r`   rR   r]   rS   r^   r\   r[   c                 C   s
   | d  S r   r/   r  r/   r/   r0   rL   &  rM   z/Mismatched _foreach_addcdiv_ types, see #118159Ztest_complexz2Uses too much memory, even for H100, surprisingly.Ztest_foreach_large_tensorc                 C   s   t o| d tjkS r  r  r  r/   r/   r0   rL   Y  s   c                 C   s
   | d  S r   r/   r  r/   r/   r0   rL     rM   )r`   r]   rS   r^   r\   r[   giUMu>gh㈵>Z test_step_is_noop_for_zero_gradsz7ASGD internally changes the weights even with zero gradFZtest_can_load_older_state_dictg9̗?zDoes not support param groupsZtest_param_groups_lrZtest_param_groups_weight_decayz!LBFGS doesn't support multideviceZ*test_forloop_goes_right_direction_multigpuZ6test_param_group_with_lrscheduler_goes_right_direction)r`   r]   rS   rX   rY   rZ   r[   c                 C   s
   | d  S r   r/   r  r/   r/   r0   rL     rM   z8Errors, https://github.com/pytorch/pytorch/issues/117150Ztest_load_nontensor_stepgv!>gw$}>Ztest_foreach_matches_forloopc                 C   s
   | d  S r   r/   r  r/   r/   r0   rL   m  rM   gMb@?r   Ztest_mixed_device_dtypec                 C   s
   | d  S r   r/   r  r/   r/   r0   rL     rM   c                 C   s   t | dddS rD   rI   rJ   r/   r/   r0   rL     rM   c                 C   s   t | ddddS )Nr  g?r  r  Z
end_factorr  r	  rJ   r/   r/   r0   rL     s
      c                 C   s   t | dddS rD   rI   rJ   r/   r/   r0   rL     rM   c                 C   s   t | ddddS )Nr  g333333?r  r  r	  rJ   r/   r/   r0   rL     s
      c                 C   s   t | dddS )NGz?rF   rG   rI   rJ   r/   r/   r0   rL     rM   c                 C   s   t | ddS )Nr  r  r  rJ   r/   r/   r0   rL     rM   c                 C   s   t | S r-   rN   rJ   r/   r/   r0   rL     rM   c                 C   s   t | dddS r
  r  rJ   r/   r/   r0   rL     rM   c                 C   s   t | dddS r  r  rJ   r/   r/   r0   rL     rM   c                 C   s   t | dddS rD   rI   rJ   r/   r/   r0   rL     rM   c                 C   s   t | S r-   rN   rJ   r/   r/   r0   rL     rM   ga2U0*s?)r   r   r   r   r   c                 C   s   t | dddS )Nr   i,  rG   rI   rJ   r/   r/   r0   rL     rM   gy&1|?)r`   rR   r]   rS   rT   rV   r^   r[   r   g{Gz?z8SparseAdam does not support dense gradients, see #116507Ztest_state_dict_deterministicz,cannot call to_sparse on p.grad, see #117184Z test_state_dict_with_cuda_paramsZ%test_deepcopy_copies_all_public_attrs)r`   r]   rS   rU   rV   rW   r[   optim_dbc                   @   s2   e Zd ZdZdddZdd Zdd Zd	d
 ZdS )TensorTrackera0  
    A utility to track tensor clones in a list, with the expectation of popping them later (in
    order) to make fair comparisons between two multi-step computation. The intended use case is
    usually when comparing two supposed equal computations, such as an optimizer step that each
    individually consists of multiple steps, where numerical deviation could multiply.

    The goal is to be able to compare and align numbers at every milestone so as to minimize
    numerical discrepancies, and so when the test fails, it is likely a real problem.
    Nc                 C   s   |d kri }|| _ g | _d S r-   )assert_eq_kwargstensors)r.   r  r/   r/   r0   r1   c  s    zTensorTracker.__init__c                 C   s   | j |   dS )z@
        Add a clone().detach()'d version of the tensor
        N)r  rc   clonedetach)r.   r   r/   r/   r0   addi  s    zTensorTracker.addc              	   C   sp   | t| jdd | jd}|t|tdt| |j||f| j	 t
  || W 5 Q R X dS )z
        Pop the first element in the tensor tracker, assert equality between the popped tensor and
        the input tensor, and then set the input tensor to have the same values as the popped tensor
        (with copy_).
        r   zno tensors to popz
type(ref)=N)ZassertGreaterr   r  popZ
assertTruera   r	   r   ZassertEqualr  rq   Zno_gradZcopy_)r.   Ztensor_to_setZtestcaserefr/   r/   r0   pop_check_setp  s    
zTensorTracker.pop_check_setc                 C   s   t | jdkS )Nr   )r   r  r2   r/   r/   r0   
all_popped  s    zTensorTracker.all_popped)N)r4   r5   r6   r7   r1   r  r  r  r/   r/   r/   r0   r  X  s
   

r  )N)N)N)N)N)N)N)N)NN)N)N)N)N)r/   )br|   rz   sysZunittestcopyr   enumr   typingr   r   r   r   r   rq   r	   Ztorch.nnr
   Ztorch.optimr   r   r   r   r   r   r   r   r   r   r   r   r   r   Ztorch.optim.lr_schedulerr   r   r   r   r   r   Z*torch.testing._internal.common_device_typer   r    Z2torch.testing._internal.common_methods_invocationsr!   Z$torch.testing._internal.common_utilsr"   r#   r$   r%   Ztorch.utils._foreach_utilsr&   r'   r:   r=   rC   rn   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r9   re   r   r   version_inforr   Zbfloat16r   r  r   r   r  __annotations__r  r/   r/   r/   r0   <module>   s"   @ Q3>
 

-9
,

.

<
3
,



<	 ; 3
 
 
A 
 
 
%`1 
 
 (Q 
	3   50 

- 

1'
 

U
           