U
    zhh                     @   s  d dl mZmZmZ d dlZd dlm  mZ d dl	m
Z
mZmZmZ d dlmZ d dlmZ d dlmZmZmZmZmZ d dlmZ d dlmZ G d	d
 d
eZe Zeejeej dd Z!dd Z"dZ#e#fddZ$G dd deZ%dd Z&dd Z'eej(dd Z)dd Z*eej+dd Z,dd Z-d d! Z.G d"d# d#Z/G d$d% d%e/Z0G d&d' d'e/Z1d0d(d)Z2d1d*d+Z3d,d- Z4G d.d/ d/eZ5e5 Z6dS )2    )Any
NamedTupleTupleN)_unwrap_for_grad_wrap_for_gradcurrent_levelTransformType)vmap)%enable_single_level_autograd_function)_add_batch_dim_broadcast_to_and_flattenrestore_vmapunwrap_batchedwrap_batched)HigherOrderOperator)_set_fwd_grad_enabledc                       s(   e Zd Z fddZ fddZ  ZS )!CustomFunctionHigherOrderOperatorc                    s   t  d d S )Ncustom_function_callsuper__init__self	__class__ T/var/www/html/venv/lib/python3.8/site-packages/torch/_functorch/autograd_function.pyr   !   s    z*CustomFunctionHigherOrderOperator.__init__c                    s*   t j rt j|f||S |j||S N)torchZ_CZ _are_functorch_transforms_activer   __call__apply)r   autograd_functionargskwargsr   r   r   r   $   s    
z*CustomFunctionHigherOrderOperator.__call____name__
__module____qualname__r   r   __classcell__r   r   r   r   r       s   r   c              	   G   s*   t | |}t  |j| }W 5 Q R X |S r   )generate_single_level_functionr
   r    )interpreterr!   operands	GeneratedZflat_outr   r   r   custom_function_call_gradX   s    
r-   c                    sz      fdd} fdd} fdd} fdd} j d	}t|tjjjft|t|t|t|d
}|S )Nc                     s|   t tjfdd| }t < td(   t f| }W 5 Q R X W 5 Q R X W 5 Q R X fdd}t||| |S )Nc                    s
   t |  S r   )r   )xlevelr   r   <lambda>f       zAgenerate_single_level_function.<locals>.forward.<locals>.<lambda>Tc                    s
   t |  S r   )r   )outputr/   r   r   wrap_fnq   s    z@generate_single_level_function.<locals>.forward.<locals>.wrap_fn)	pytreeZtree_map_onlyr   TensorZenable_gradr   lowerr   !wrap_outputs_maintaining_identity)r+   unwrapped_operandsunwrapped_outputr4   r!   r*   r0   r   r   forwardd   s$     
 $   z/generate_single_level_function.<locals>.forwardc                    s     | ||S r   )setup_context)ctxinputsr3   r!   r   r   r=   x   s    z5generate_single_level_function.<locals>.setup_contextc                    s    j | f| }|S r   )backward)r>   Zgradsresultr@   r   r   rA   |   s    z0generate_single_level_function.<locals>.backwardc                    s    j | f| }|S r   )jvp)r>   tangentsrB   r@   r   r   rC      s    z+generate_single_level_function.<locals>.jvpr,   )r<   rA   rC   r=   )r0   r%   typer   autogradfunctionZ_SingleLevelFunctionstaticmethod)r*   r!   r<   r=   rA   rC   namer,   r   r;   r   r)   a   s     

r)   znot specifiedc                 C   s   t j| }t j| }dd t||D }t | \}}	g }
|tk}|r~t||	}|d kr~td| dt |d  d|	 dt|D ]h\}}t|t	j
s|
| qt||kr|
|t|  q|r|
||||  q|
|| qt |
|	S )Nc                 S   s   i | ]\}}t ||qS r   )id).0Z	unwrappedorigr   r   r   
<dictcomp>   s    z5wrap_outputs_maintaining_identity.<locals>.<dictcomp>zoThe autograd.Function's vmap staticmethod returned an incompatible (output, out_dims) tuple. Expected out_dims=zI to be compatible with the structure of `output`. out_dims has structure    z but output has structure zV. For more details, please see https://pytorch.org/docs/main/notes/extending.func.html)r5   arg_tree_leavesziptree_flattenNO_OUT_DIMSr   RuntimeError	enumerate
isinstancer   r6   appendrJ   tree_unflatten)outputsZunwrapped_inputsZorig_inputsr4   out_dimsZflat_unwrapped_inputsZflat_orig_inputsZunwrapped_input_to_orig_inputZflat_outputsspecrB   Zout_dims_specifiedZflat_out_dimsir3   r   r   r   r8      s2    


 
r8   c                   @   s   e Zd ZU eed< eed< dS )VmapInfo
batch_size
randomnessN)r%   r&   r'   int__annotations__strr   r   r   r   r\      s   
r\   c                 C   s   | j tjjj k	S r   )r	   r   rF   Functionr@   r   r   r   has_overriden_vmap_rule  s    rc   c                 C   sN   d}t | ts&t|dt|  d t| dksJt|dt|  d d S )Nz}Expected the vmap staticmethod to have two returns, an output and out_dims with pytree structure compatible with the output. zGot a z instead   zGot z returns instead)rU   tuplerS   rE   len)rB   Zbase_error_msgr   r   r   +validate_vmap_returns_tuple_of_two_elements	  s    
rg   c           
   
      s  |j r0t|r td|j dt| |f| S t|sJtd|j d|   t|  |  d}t	| \}}t
dd |r|   t|f| W  5 Q R  S Q R X |   |j||f| }W 5 Q R X t| |\}} fdd}	t||||	|d	S )
NzYou tried to vmap over a  , but it has both generate_vmap_rule=True and an overriden vmap staticmethod. Please set generate_vmap_rule=False or delete the overriden vmap staticmethod to avoid ambiguity. For more details, please see https://pytorch.org/docs/main/notes/extending.func.htmlz, but it does not have vmap support. Please override and implement the vmap staticmethod or set generate_vmap_rule=True. For more details, please see https://pytorch.org/docs/main/notes/extending.func.html)r]   r^   c                 S   s   | d kS r   r   )dimr   r   r   r1   ;  r2   z+custom_function_call_vmap.<locals>.<lambda>c                    s   |d kr| S t | | S r   )r   )r3   Zout_dimr   r   r   r4   E  s    
z*custom_function_call_vmap.<locals>.wrap_fn)rY   )generate_vmap_rulerc   rS   r%   'custom_function_call_vmap_generate_ruler0   r\   r]   r^   r   r5   Ztree_allr7   r   r	   rg   r8   )
r*   r!   r+   infor9   in_dimsrB   r:   rY   r4   r   ri   r   custom_function_call_vmap  sH     
 
    rn   c           	   	   G   sd   t ||  \}}t|||  |  \}}|   t|f| }W 5 Q R X | }t|||  S r   )r   r0   vmapify_autograd_functionr]   r^   r7   r   r   )	r*   r!   r+   r9   rm   Zvmapped_functionget_out_dimsr3   rY   r   r   r   rk   Q  s       
rk   c                 G   s   t dd S )Nz0NYI: Functionalize rule for custom_function_call)rS   )r*   r!   rj   r+   r   r   r   "custom_function_call_functionalize^  s    rq   c              	      s   d fdd} fdd} fdd} fdd	}d
 j  }t|tjjft|t|t|t|dd}	fdd}
|	|
fS )Nznot populatedc                     s   t  j|  \}|S r   )r   r<   )r+   rX   )r!   r]   rm   rY   r^   r   r   r<   s  s       z*vmapify_autograd_function.<locals>.forwardc                    s>   d d  fdd}t |f|| 	d S )Nc                    s6   t t } || | tdd | D |jd S )Nc                 s   s$   | ]}t |tjr|jnd V  qd S r   )rU   r   r6   shape)rK   Zinpr   r   r   	<genexpr>  s    zRvmapify_autograd_function.<locals>.setup_context.<locals>.inner.<locals>.<genexpr>)CtxCustomSaver   r=   re   _pt_saved_tensors_bdims)r?   rX   wrapped_ctx)r!   r>   input_shapes_saved_tensors_bdims_r   r   inner~  s    z?vmapify_autograd_function.<locals>.setup_context.<locals>.inner)r   )r>   r?   rX   ry   )r!   r]   rm   input_shapesrY   r^   saved_tensors_bdims)r>   rw   rx   r   r=   z  s     z0vmapify_autograd_function.<locals>.setup_contextc                    s`   kst kst  fdd}t|}t||f j|\}}t||}|S )Nc                    s   t | } j|f| S r   )CtxWithSavedTensorsrC   )saved_tensorsrD   rv   r!   r>   r   r   jvp_no_context  s    
z>vmapify_autograd_function.<locals>.jvp.<locals>.jvp_no_context)AssertionErrorget_tangents_in_dimsr   r}   	reductify)r>   rD   r   Ztangent_in_dimsZout_tangentsZout_tangents_dimsrB   )r!   r]   rm   init_valrY   r^   r{   r>   r   rC     s    
 z&vmapify_autograd_function.<locals>.jvpc                    sh   kst kst ks$t  fdd}t|ff j|f\}}t||}|S )Nc                    s"   | \}}t |} j|f| S r   )r|   rA   )r?   r}   grad_outputsrv   r~   r   r   backward_no_context  s    
zHvmapify_autograd_function.<locals>.backward.<locals>.backward_no_context)r   r   r}   r   )r>   r   r   Zgrad_insZgrad_ins_dimsrB   r!   r]   rm   r   rz   rY   r^   r{   r   r   rA     s    z+vmapify_autograd_function.<locals>.backwardZVmappedT)r<   rA   rC   r=   rj   c                      s    kst S r   )r   r   )r   rY   r   r   rp     s    z/vmapify_autograd_function.<locals>.get_out_dims)r%   rE   r   rF   rb   rH   )r!   rm   r]   r^   r<   r=   rC   rA   rI   r,   rp   r   r   r   ro   e  s*    	$ro   c                 C   s8   t | \}}t j| }dd t||D }t ||S )Nc                 S   s    g | ]\}}|d krd n|qS r   r   )rK   Zin_dimZtangentr   r   r   
<listcomp>  s   z(get_tangents_in_dims.<locals>.<listcomp>)r5   rQ   rO   rP   rW   )Z
input_dimsrD   Zflat_in_dimsrZ   Zflat_tangentsrB   r   r   r   r     s    
r   c                   @   s:   e Zd ZU dZeedf ed< dd Zdd Zdd	 Z	d
S )
WrappedCtx)_pt_reserved_attrs_pt_inner_ctx.r   c                 C   sD   t |ts:t| j}|D ] }t||s(qtd| dq|| _d S )NzPyTorch reserves the zU field on ctx. Please name your fields on ctx something else to avoid name collision.)rU   r   rE   r   hasattrrS   r   )r   r>   Zreserved_attrsrI   r   r   r   r     s    



zWrappedCtx.__init__c                 C   s   t | j|S r   )getattrr   )r   rI   r   r   r   __getattr__)  s    zWrappedCtx.__getattr__c                 C   s*   |t | jkr|| j|< d S t| j||S r   )rE   r   __dict__setattrr   )r   rI   valuer   r   r   __setattr__,  s    
zWrappedCtx.__setattr__N)
r%   r&   r'   r   r   ra   r`   r   r   r   r   r   r   r   r     s   
r   c                       s2   e Zd ZdejZ fddZedd Z  ZS )r|   _pt_new_saved_tensorsc                    s   t  | || _d S r   )r   r   r   )r   r>   Znew_saved_tensorsr   r   r   r   7  s    zCtxWithSavedTensors.__init__c                 C   s   | j S r   )r   r   r   r   r   r}   ;  s    z!CtxWithSavedTensors.saved_tensors)r   )	r%   r&   r'   r   r   r   propertyr}   r(   r   r   r   r   r|   4  s   
r|   c                       s6   e Zd Zd	ejZ fddZdd Zdd Z  ZS )
rt   ru   _pt_current_levelc                    s   t  | d| _|| _d S )Nr   )r   r   ru   r   )r   r>   r   r   r   r   r   G  s    zCtxCustomSave.__init__c                 G   s&   t || j\}}| jj|  || _d S r   )r   r   r   save_for_backwardru   r   ZtensorsZunwrapped_tensorsZbdimsr   r   r   r   L  s    zCtxCustomSave.save_for_backwardc                 G   s&   t || j\}}| jj|  || _d S r   )r   r   r   save_for_forwardru   r   r   r   r   r   Q  s    zCtxCustomSave.save_for_forward)ru   r   )	r%   r&   r'   r   r   r   r   r   r(   r   r   r   r   rt   @  s     rt   c                    sh   t | ts| f} t |ts |f}t |ts0|f}|d krDt| d }t fddt| |||D }|S )Nr   c                 3   s&   | ]\}}}}t ||| |V  qd S r   )reductify_leaf)rK   giZgi_bdimZi_bdimZmaybe_ishaper]   r   r   rs   g  s   
zreductify.<locals>.<genexpr>)rU   re   rf   rP   )
grad_inputgrad_input_bdim
input_bdimr]   &target_shape_without_bdim_to_reduce_torB   r   r   r   r   W  s"    


	r   c                 C   s   | d krd S |d kr |d kr | S |d k	r:|d kr:|  |S |d k	sFt|d krx| |} t| j}|||< | |} |}|d k	rttjj	|d f|d| |S ||kr| 
||} | S )N)rm   rY   )sumr   Z	unsqueezelistrr   expandr	   r   r6   Zsum_to_sizeZmovedim)r   r   r   r]   r   Z	new_shaper   r   r   r   s  s2    



 r   c                    s    fdd}|S )Nc                    s    ||}| || |S r   r   )r>   r"   r#   r3   original_forwardoriginal_setup_contextr   r   new_forward  s    
z8autograd_function_forward_rewritten.<locals>.new_forwardr   )r   r   r   r   r   r   #autograd_function_forward_rewritten  s    r   c                       s$   e Zd Z fddZdd Z  ZS )AutogradFunctionApplyc                    s   t  d d S )Nautograd_function_applyr   r   r   r   r   r     s    zAutogradFunctionApply.__init__c           	         sH   d |d }t |}d | }G  fdddtjj}|j| S )Nargs_tensor_maskc                       s2   e Zd ZefddZe fddZdS )z5AutogradFunctionApply.__call__.<locals>.ApplyTemplatec                    s    d \}|S N)Nr   )r>   r"   r3   )fwdfwd_argssaved_valuesr   r   r<     s    z=AutogradFunctionApply.__call__.<locals>.ApplyTemplate.forwardc                    s    d| S r   r   )r>   Zgrad)bwdr   r   r   rA     s    z>AutogradFunctionApply.__call__.<locals>.ApplyTemplate.backwardN)r%   r&   r'   rH   r<   rA   r   r   r   r   r   r   r   ApplyTemplate  s   r   )r   r   rF   rb   r    )	r   r   r   r   Z
fwd_kwargsr   Zlength_of_tensor_argsZnew_fwd_argsr   r   r   r   r     s    zAutogradFunctionApply.__call__r$   r   r   r   r   r     s   r   )N)N)7typingr   r   r   r   Ztorch.utils._pytreeutilsZ_pytreer5   Ztorch._C._functorchr   r   r   r   Ztorch._functorch.apisr	   Ztorch._functorch.utilsr
   Ztorch._functorch.vmapr   r   r   r   r   Z
torch._opsr   Ztorch.autograd.forward_adr   r   r   Zpy_implZGradZJvpr-   r)   rR   r8   r\   rc   rg   ZVmaprn   rk   ZFunctionalizerq   ro   r   r   r|   rt   r   r   r   r   r   r   r   r   r   <module>   sJ   


?
Y

<

u? 
! 
;	