import dataclasses
import functools
import inspect
import sys
import typing
import weakref

from torchgen.model import (
    FunctionSchema,
    OperatorName,
    SchemaKind,
    BaseType,
    ListType,
    BaseTy,
)

import torch
import torch._C as _C
import torch.library as library
from torch._library.abstract_impl import AbstractImplCtx
from torch.library import get_ctx

from .autograd import autograd_kernel_indirection, construct_autograd_kernel
from torch._library.infer_schema import infer_schema

__all__ = ["custom_op", "CustomOp", "get_ctx"]


SUPPORTED_DEVICE_TYPE_TO_KEY = {
    "cpu": "CPU",
    "cuda": "CUDA",
}

# We will not let users register CustomOps with anything that could look like
# PyTorch internals to avoid confusion.
RESERVED_NS = {
    "prim",
    "prims",
    "aten",
    "at",
    "torch",
    "pytorch",
}


def custom_op(
    qualname: str, manual_schema: typing.Optional[str] = None
) -> typing.Callable:
    r"""Creates a new CustomOp object.

    WARNING: if you're a user, please do not use this directly
    (instead use the torch._custom_ops APIs).
    Also please see the following for a detailed guide on custom ops.
    https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk

    In PyTorch, defining an op (short for "operator") is a two-step process:
    - we need to define (create) the op
    - we need to implement behavior for how the operator interacts with
      various PyTorch subsystems, like CPU/CUDA Tensors, Autograd, etc.

    This entrypoint defines the CustomOp object (the first step);
    you must then perform the second step by calling various methods on
    the CustomOp object.

    This API is used as a decorator (see examples).

    Arguments:
        qualname (str): Should be a string that looks like
            "namespace::operator_name". Operators in PyTorch need a namespace to
            avoid name collisions; a given operator may only be created once.
            If you are writing a Python library, we recommend the namespace to
            be the name of your top-level module. The operator_name must be
            the same as the name of the function you pass to custom_op
            (see examples).
        manual_schema (Optional[str]): Each PyTorch operator needs a schema that
            tells PyTorch the types of the inputs/outputs. If None (default),
            we will infer the schema from the type annotations on the function
            (see examples). Otherwise, if you don't want to use type annotations,
            you may provide us the schema string.
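            For example (an illustrative sketch; ``numpy_add`` is a
            hypothetical op, not an operator defined by PyTorch)::

                >>> @custom_op("my_library::numpy_add", manual_schema="(Tensor x, Tensor y) -> Tensor")
                >>> def numpy_add(x, y):  # no type annotations when a manual schema is given
                >>>     ...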

    Example::
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> import numpy as np
        >>> from torch import Tensor
        >>>
        >>> # Step 1: define the CustomOp.
        >>> # We need to provide the decorator a "prototype function"
        >>> # (a function with Python ellipses as the body).
        >>> @custom_op("my_library::numpy_sin")
        >>> def numpy_sin(x: Tensor) -> Tensor:
        >>>     ...
        >>>
        >>> # numpy_sin is now an instance of class CustomOp
        >>> print(type(numpy_sin))
        >>>
        >>> # Step 2: Register an implementation for various PyTorch subsystems
        >>>
        >>> # Register an implementation for CPU tensors
        >>> @numpy_sin.impl('cpu')
        >>> def numpy_sin_impl_cpu(x):
        >>>     return torch.from_numpy(np.sin(x.numpy()))
        >>>
        >>> # Register an implementation for CUDA tensors
        >>> @numpy_sin.impl('cuda')
        >>> def numpy_sin_impl_cuda(x):
        >>>     return torch.from_numpy(np.sin(x.cpu().numpy())).to(x.device)
        >>>
        >>> x = torch.randn(3)
        >>> numpy_sin(x)  # calls numpy_sin_impl_cpu
        >>>
        >>> x_cuda = x.cuda()
        >>> numpy_sin(x_cuda)  # calls numpy_sin_impl_cuda

    """

    def inner(func):
        if not inspect.isfunction(func):
            raise ValueError(
                f"custom_op(...)(func): Expected `func` to be a Python "
                f"function, got: {type(func)}"
            )

        ns, name = parse_qualname(qualname)
        validate_namespace(ns)
        if func.__name__ != name:
            raise ValueError(
                f"custom_op(qualname='{qualname}', ...)(func): expected `func` "
                f"to have name '{name}' but got '{func.__name__}'. "
                f"Please either change the name of `func` or the qualname that "
                f"is passed to `custom_op`"
            )

        schema = infer_schema(func) if manual_schema is None else manual_schema
        schema_str = f"{name}{schema}"
        function_schema = FunctionSchema.parse(schema_str)
        validate_schema(function_schema)
        if manual_schema is not None:
            validate_function_matches_schema(function_schema, func)

        lib = library.Library(ns, "FRAGMENT")
        lib.define(schema_str)
        ophandle = find_ophandle_or_throw(ns, function_schema.name)
        result = CustomOp(lib, ns, function_schema, name, ophandle, _private_access=True)

        result.__name__ = func.__name__
        result.__module__ = func.__module__
        result.__doc__ = func.__doc__

        library.impl(lib, result._opname, "Autograd")(
            autograd_kernel_indirection(weakref.proxy(result))
        )

        torch._C._dispatch_set_report_error_callback(
            ophandle, functools.partial(report_error_callback, weakref.proxy(result))
        )

        return result

    return inner


global_registry: typing.Dict[str, "CustomOp"] = {}


class CustomOp:
    r"""Class for custom operators in PyTorch.

    Use the CustomOp API to create user-defined custom operators that behave
    just like regular PyTorch operators (e.g. torch.sin, torch.mm) when it
    comes to various PyTorch subsystems (like torch.compile).

    To construct a `CustomOp`, use `custom_op`.
    """

    def __init__(self, lib, cpp_ns, schema, operator_name, ophandle, *, _private_access=False):
        super().__init__()
        if not _private_access:
            raise RuntimeError(
                "The CustomOp constructor is private and we do not guarantee "
                "BC for it. Please use custom_op(...) to create a CustomOp object"
            )
        name = f"{cpp_ns}::{operator_name}"
        self._schema = schema
        self._cpp_ns = cpp_ns
        self._lib = lib
        self._ophandle = ophandle
        # The name of the op without the namespace, e.g. "foo".
        self._opname = operator_name
        # The name of the op with the namespace, e.g. "custom::foo".
        self._qualname = name
        self.__name__ = None
        # Maps an impl kind (e.g. "cpu", "abstract", "backward") to the
        # registered function and the location it was registered from.
        self._impls = {}
        self._registered_autograd_kernel_indirection = False

        global_registry[self._qualname] = self

    def _register_autograd_kernel_indirection(self):
        assert not self._registered_autograd_kernel_indirection
        self._lib.impl(
            self._opname, autograd_kernel_indirection(weakref.proxy(self)), "Autograd"
        )
        self._registered_autograd_kernel_indirection = True

    # Records the impl and the source location in self._impls.
    # Note that this alone does not make torch.library use the impl; that
    # happens in a separate self._lib.impl call.
    def _register_impl(self, kind, func, stacklevel=2):
        if self._has_impl(kind):
            func_and_location = self._impls[kind]
            assert func_and_location is not None
            location = func_and_location.location
            raise RuntimeError(
                f"Attempting to register a {kind} impl for operator {self._qualname} "
                f"that already has a {kind} impl registered from Python at "
                f"{location}. This is not supported."
            )
        frame = inspect.getframeinfo(sys._getframe(stacklevel))
        location = f"{frame.filename}:{frame.lineno}"
        self._impls[kind] = FuncAndLocation(func, location)

    def _get_impl(self, kind):
        return self._impls[kind]

    def _has_impl(self, kind):
        return kind in self._impls

    def _destroy(self):
        del self._lib

        opnamespace = getattr(torch.ops, self._cpp_ns)
        if hasattr(opnamespace, self._opname):
            delattr(opnamespace, self._opname)

        del global_registry[self._qualname]

    def __repr__(self):
        return f'<CustomOp(op="{self._qualname}")>'

    def __call__(self, *args, **kwargs):
        # Call the boxed kernel directly through the dispatcher handle.
        result = _C._dispatch_call_boxed(self._ophandle, *args, **kwargs)
        return result

    def impl(
        self, device_types: typing.Union[str, typing.Iterable[str]], _stacklevel=2,
    ) -> typing.Callable:
        r"""Register an implementation for a device type for this CustomOp object.

        WARNING: if you're a user, please do not use this directly
        (instead use the torch._custom_ops APIs).
        Also please see the following for a detailed guide on custom ops.
        https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk

        If the CustomOp is passed multiple Tensor inputs with different device
        types, it will dispatch to the registered implementation for the highest
        priority device type among those present.
        The supported device types, in order of priority, are {'cuda', 'cpu'}.

        This API is used as a decorator (see examples).

        Arguments:
            device_types (str or Iterable[str]): the device type(s) to register the function for.

        Examples::
            >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
            >>> import numpy as np
            >>> from torch import Tensor
            >>>
            >>> @custom_op("my_library::numpy_cos")
            >>> def numpy_cos(x: Tensor) -> Tensor:
            >>>     ...
            >>>
            >>> # Register an implementation for CPU Tensors
            >>> @numpy_cos.impl('cpu')
            >>> def numpy_cos_impl_cpu(x):
            >>>     return torch.from_numpy(np.cos(x.numpy()))
            >>>
            >>> # Register an implementation for CUDA Tensors
            >>> @numpy_cos.impl('cuda')
            >>> def numpy_cos_impl_cuda(x):
            >>>     return torch.from_numpy(np.cos(x.cpu().numpy())).to(x.device)
            >>>
            >>> x = torch.randn(3)
            >>> numpy_cos(x)  # calls numpy_cos_impl_cpu
            >>>
            >>> x_cuda = x.cuda()
            >>> numpy_cos(x_cuda)  # calls numpy_cos_impl_cuda

        """
        if isinstance(device_types, str):
            device_types = [device_types]
        for device_type in device_types:
            validate_device_type(device_type)

        def inner(f):
            for device_type in set(device_types):
                self._check_doesnt_have_library_impl(device_type)
                self._register_impl(device_type, f, stacklevel=_stacklevel)
                dispatch_key = SUPPORTED_DEVICE_TYPE_TO_KEY[device_type]
                library.impl(self._lib, self._opname, dispatch_key)(f)
            return f

        return inner

    def _check_doesnt_have_library_impl(self, device_type):
        if self._has_impl(device_type):
            return
        key = SUPPORTED_DEVICE_TYPE_TO_KEY[device_type]
        if _C._dispatch_has_computed_kernel_for_dispatch_key(self._qualname, key):
            raise RuntimeError(
                f"impl(..., device_types={device_type}): the operator {self._qualname} "
                f"already has an implementation for this device type via a "
                f"pre-existing torch.library or TORCH_LIBRARY registration."
            )

    def impl_factory(self) -> typing.Callable:
        r"""Register an implementation for a factory function."""

        def inner(f):
            self._register_impl("factory", f)
            library.impl(self._lib, self._opname, "BackendSelect")(f)
            return f

        return inner

    def impl_abstract(self, _stacklevel=2) -> typing.Callable:
        r"""Register an abstract implementation for this operator.

        WARNING: please do not use this directly (and instead use the torch._custom_ops
        APIs). Also please see the following for a detailed guide on custom ops.
        https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk

        An "abstract implementation" specifies the behavior of this operator on
        Tensors that carry no data. Given some input Tensors with certain properties
        (sizes/strides/storage_offset/device), it specifies what the properties of
        the output Tensors are.

        The abstract implementation has the same signature as the operator.
        It is run for both FakeTensors and meta tensors. To write an abstract
        implementation, assume that all Tensor inputs to the operator are
        regular CPU/CUDA/Meta tensors, but they do not have storage, and
        you are trying to return regular CPU/CUDA/Meta tensor(s) as output.
        The abstract implementation must consist of only PyTorch operations
        (and may not directly access the storage or data of any input or
        intermediate Tensors).

        This API is used as a decorator (see examples).

        Examples::
            >>> import numpy as np
            >>> from torch import Tensor
            >>>
            >>> # Example 1: an operator without data-dependent output shape
            >>> @custom_op('my_library::custom_linear')
            >>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
            >>>     ...
            >>>
            >>> @custom_linear.impl_abstract()
            >>> def custom_linear_abstract(x, weight, bias):
            >>>     assert x.dim() == 2
            >>>     assert weight.dim() == 2
            >>>     assert bias.dim() == 1
            >>>     assert x.shape[1] == weight.shape[1]
            >>>     assert weight.shape[0] == bias.shape[0]
            >>>     assert x.device == weight.device
            >>>
            >>>     return (x @ weight.t()) + bias
            >>>
            >>> # Example 2: an operator with data-dependent output shape
            >>> @custom_op('my_library::custom_nonzero')
            >>> def custom_nonzero(x: Tensor) -> Tensor:
            >>>     ...
            >>>
            >>> @custom_nonzero.impl_abstract()
            >>> def custom_nonzero_abstract(x):
            >>>     # Number of nonzero-elements is data-dependent.
            >>>     # Since we cannot peek at the data in an abstract impl,
            >>>     # we use the ctx object to construct a new symint that
            >>>     # represents the data-dependent size.
            >>>     ctx = torch._custom_op.get_ctx()
            >>>     nnz = ctx.create_unbacked_symint()
            >>>     shape = [x.dim(), nnz]
            >>>     result = x.new_empty(shape, dtype=torch.long)
            >>>     return result
            >>>
            >>> @custom_nonzero.impl(['cpu', 'cuda'])
            >>> def custom_nonzero_impl(x):
            >>>     x_np = to_numpy(x)
            >>>     res = np.stack(np.nonzero(x_np), axis=1)
            >>>     # unbacked symbolic ints in PyTorch must be >= 2, so we
            >>>     # constrain the range to at least 2
            >>>     if res.shape[0] <= 1:
            >>>         raise RuntimeError("not supported")
            >>>     return torch.tensor(res, device=x.device)

        """

        def inner(f):
            self._check_doesnt_have_library_meta_impl()
            self._register_impl("abstract", f, stacklevel=_stacklevel)
            location = self._get_impl("abstract").location

            qualname = self._qualname

            # Handle the DispatchKey::Meta registration.
            @functools.wraps(f)
            def f_with_ctx(*args, **kwargs):
                def error_on_ctx():
                    raise RuntimeError(
                        f"Attempted to call get_ctx() for the meta implementation "
                        f"for {qualname}."
                        f"You have presumably called get_ctx() because the operator "
                        f"has a data-dependent output shape; if so, there is no "
                        f"such meta implementation and this error is the correct "
                        f"behavior. Otherwise, please remove the call to get_ctx() "
                        f"in the implementation registered with impl_abstract "
                        f"at {location}"
                    )

                with torch._library.abstract_impl.set_ctx_getter(error_on_ctx):
                    return f(*args, **kwargs)

            self._lib.impl(self._opname, f_with_ctx, "Meta")
            return f

        return inner

    def _check_can_register_backward(self):
        def error(detail):
            raise RuntimeError(
                f"Cannot use torch._custom_ops APIs to register backward "
                f"formula for {detail}. Got operator "
                f"{self._qualname} with schema: {schema}"
            )

        schema = self._schema
        if schema.kind() != SchemaKind.functional:
            error("non-functional operator")

        rets = schema.returns
        if not schema.returns:
            error("operator with no returns")

        assert len(rets) > 0
        is_non_mutating_view = any(
            r.annotation is not None and not r.annotation.is_write for r in rets
        )
        if is_non_mutating_view:
            error("operator that returns views")

        allowed_return_types = {
            BaseType(BaseTy.int): "int",
            BaseType(BaseTy.SymInt): "SymInt",
            BaseType(BaseTy.bool): "bool",
            BaseType(BaseTy.float): "float",
            BaseType(BaseTy.Tensor): "Tensor",
            ListType(BaseType(BaseTy.Tensor), None): "List[Tensor]",
        }
        for ret in schema.returns:
            if ret.type in allowed_return_types:
                continue
            error(
                f"operator with return not in "
                f"{list(allowed_return_types.values())} (got {ret.type})"
            )

    def _check_doesnt_have_library_autograd_impl(self):
        if self._registered_autograd_kernel_indirection:
            return

        if _C._dispatch_has_kernel_for_dispatch_key(
            self._qualname, "CompositeImplicitAutograd"
        ):
            raise RuntimeError(
                f"impl_backward/impl_save_for_backward: the operator {self._qualname} "
                f"already has an implementation for this device type via a "
                f"pre-existing registration to "
                f"DispatchKey::CompositeImplicitAutograd."
                f"CompositeImplicitAutograd operators do not need an autograd formula; "
                f"instead, the operator will decompose into its constituents and those "
                f"can have autograd formulas defined on them."
            )

        for key in ["Autograd", "AutogradCPU", "AutogradCUDA"]:
            if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, key):
                raise RuntimeError(
                    f"impl_backward/impl_save_for_backward: "
                    f"the operator {self._qualname} already has an Autograd kernel "
                    f"registered to DispatchKey::{key} via a pre-existing "
                    f"torch.library or TORCH_LIBRARY registration. Please either "
                    f"remove those registrations or don't use the torch._custom_ops APIs"
                )

    def _check_doesnt_have_library_meta_impl(self):
        if self._has_impl("abstract"):
            return

        # If the user's operator is CompositeExplicitAutograd, allow them to
        # impl_abstract anyway (existing custom ops may have
        # CompositeExplicitAutograd registrations that don't work with Meta
        # kernels, so this gives them an escape hatch).
        if _C._dispatch_has_kernel_for_dispatch_key(
            self._qualname, "CompositeExplicitAutograd"
        ) and not _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "Meta"):
            return

        # Otherwise, raise if there is already a Meta kernel or a
        # CompositeImplicitAutograd registration.
        if _C._dispatch_has_kernel_for_dispatch_key(
            self._qualname, "CompositeImplicitAutograd"
        ):
            raise RuntimeError(
                f"impl_abstract(...): the operator {self._qualname} "
                f"already has an implementation for this device type via a "
                f"pre-existing registration to "
                f"DispatchKey::CompositeImplicitAutograd."
                f"CompositeImplicitAutograd operators do not need an abstract impl; "
                f"instead, the operator will decompose into its constituents and those "
                f"can have abstract impls defined on them."
            )

        if _C._dispatch_has_kernel_for_dispatch_key(self._qualname, "Meta"):
            raise RuntimeError(
                f"impl_abstract(...): the operator {self._qualname} "
                f"already has a DispatchKey::Meta implementation via a "
                f"pre-existing torch.library or TORCH_LIBRARY registration. "
                f"Please either remove that registration or don't call impl_abstract."
            )

    def _register_autograd_kernel(self):
        assert self._has_impl("backward")
        assert self._has_impl("save_for_backward")
        kernel = construct_autograd_kernel(
            self._schema,
            self._output_differentiability,
            self,
            get_op(self._qualname),
            self._get_impl("save_for_backward").func,
            self._get_impl("backward").func,
        )
        self._register_impl("autograd", kernel)

    def impl_save_for_backward(self, _stacklevel=2):
        r"""Register a function that tells us what to save for backward.

        Please see impl_backward for more details.
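
        Example (an illustrative sketch, reusing the ``numpy_sin`` CustomOp
        from the ``custom_op`` docstring above)::

            >>> @numpy_sin.impl_save_for_backward()
            >>> def numpy_sin_save_for_backward(inputs, output):
            >>>     # ``inputs`` exposes the op's arguments by name.
            >>>     return inputs.x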
        """

        def inner(f):
            self._check_can_register_backward()
            self._check_doesnt_have_library_autograd_impl()
            if not self._registered_autograd_kernel_indirection:
                self._register_autograd_kernel_indirection()
            self._register_impl("save_for_backward", f, stacklevel=_stacklevel)
            if self._has_impl("backward"):
                self._register_autograd_kernel()

        return inner

    def impl_backward(self, output_differentiability=None, _stacklevel=2):
        r"""Registers a backward formula.

        WARNING: if you're a user, please do not use this directly
        (instead use the torch._custom_ops APIs).
        Also please see the following for a detailed guide on custom ops.
        https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk

        In order for the CustomOp to work with autograd, you need to register
        a backward formula. There are two pieces to this:
        1. You must give us a function to specify what to save for backward.
           Call this the "save for backward" function.
        2. You must give us a function that computes gradients. Call this the
           "backward" function.

        Use `impl_save_for_backward` to define a "save for backward" function
        that specifies what gets saved for backward. The function should accept
        two arguments ``(inputs, output)`` and return the quantities to be saved
        for backward.

        During runtime, when you call the CustomOp, PyTorch will invoke the
        "save for backward" function with the inputs and output of the CustomOp.

        Use `impl_backward` to define the "backward" function. The backward
        function must accept ``(ctx, saved, *grads)``:
        - ``ctx`` is a context object where we may provide information
        - ``saved`` is exactly what gets returned from the "save for backward"
          function
        - ``grads`` is one or more gradients. The number of gradients matches
          the number of outputs of the CustomOp.

        The backward function must return a dict that maps the name of
        an input to the CustomOp to its corresponding gradient. All inputs that
        were declared to be Tensors in the CustomOp definition must be accounted
        for in the dict. The gradient may be a Tensor or None.
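
        Example (an illustrative sketch, continuing the ``numpy_sin`` example
        shown under ``impl_save_for_backward``)::

            >>> @numpy_sin.impl_backward()
            >>> def numpy_sin_backward(ctx, saved, grad_out):
            >>>     # ``saved`` is whatever the save-for-backward function returned.
            >>>     return {'x': grad_out * saved.cos()}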

        """
        if output_differentiability is not None:

            def yell():
                raise RuntimeError(
                    f"impl_backward(output_differentiability): expected "
                    f"output_differentiability to be a list of bools with "
                    f"length equal to the number of outputs of this CustomOp "
                    f"got: {output_differentiability}"
                )

            if not isinstance(output_differentiability, list):
                yell()
            for diff in output_differentiability:
                if not isinstance(diff, bool):
                    yell()
            if len(self._schema.returns) != len(output_differentiability):
                yell()

        def inner(f):
            self._check_can_register_backward()
            self._check_doesnt_have_library_autograd_impl()
            if not self._registered_autograd_kernel_indirection:
                self._register_autograd_kernel_indirection()
            self._register_impl("backward", f, stacklevel=_stacklevel)
            self._output_differentiability = output_differentiability
            if self._has_impl("save_for_backward"):
                self._register_autograd_kernel()

        return inner


@dataclasses.dataclass
class FuncAndLocation:
    func: typing.Callable
    location: str


def find_ophandle_or_throw(cpp_ns: str, operator_name: OperatorName):
    overload_name = (
        "" if operator_name.overload_name is None else operator_name.overload_name
    )
    return _C._dispatch_find_schema_or_throw(
        f"{cpp_ns}::{str(operator_name.name)}", overload_name
    )


def validate_namespace(ns: str) -> None:
    if "." in ns:
        raise ValueError(
            f'custom_op(..., ns="{ns}"): expected ns to not contain any . (and be a '
            f"valid variable name)"
        )
    if ns in RESERVED_NS:
        raise ValueError(
            f"custom_op(..., ns='{ns}'): '{ns}' is a reserved namespace, "
            f"please choose something else. "
        )


def validate_schema(schema: FunctionSchema) -> None:
    if not torch._library.utils.is_functional_schema(schema):
        raise ValueError(
            f"custom_op only supports functional operators "
            f"(ops that do not mutate any inputs, do not return "
            f"views of the inputs, and has at least one return). "
            f"Got the following non-functional schema: {schema}"
        )

    # For simplicity: don't allow arguments named 'self'.
    if schema.arguments.self_arg is not None:
        raise ValueError(
            f"custom_op does not support arguments named 'self'. Please "
            f"rename your argument. Got: {schema}"
        )


def parse_qualname(qualname: str) -> typing.Tuple[str, str]:
    names = qualname.split("::", 1)
    if len(names) != 2:
        raise ValueError(
            f"Expected there to be a namespace in {qualname}, i.e. The "
            f"operator name should look something like ns::foo"
        )
    if "." in names[1]:
        raise ValueError(
            f"The torch.custom_ops APIs do not handle overloads, "
            f"i.e. operator names with '.' in them. "
            f"Please name your operator something like ns::foo. "
            f"Got: {qualname}"
        )
    return names[0], names[1]


def validate_device_type(device_type: str) -> None:
    if device_type not in SUPPORTED_DEVICE_TYPE_TO_KEY:
        raise ValueError(
            f"CustomOp.impl(device_types=[{device_type}, ...]): we only support "
            f"device_type in {SUPPORTED_DEVICE_TYPE_TO_KEY.keys()}."
        )


def supported_param(param: inspect.Parameter) -> bool:
    return param.kind in (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )


def validate_function_matches_schema(
    schema: FunctionSchema, func: typing.Callable
) -> None:
    sig = inspect.signature(func)

    if not all(supported_param(p) for _, p in sig.parameters.items()):
        raise ValueError(
            f"custom_op(..., manual_schema)(func): positional-only args, "
            f"varargs, and kwargs are not supported. Please rewrite `func` "
            f"to not have them. Got `func` with signature: {sig}"
        )

    if (
        any(
            p.annotation is not inspect.Parameter.empty
            for _, p in sig.parameters.items()
        )
        or sig.return_annotation is not inspect.Signature.empty
    ):
        raise ValueError(
            f"custom_op(..., manual_schema)(func): When passing in a manual "
            f"schema, we expect `func` to have no type annotations to avoid "
            f"ambiguity. Got `func` with signature: {sig}"
        )

    positional = [
        (name, param)
        for name, param in sig.parameters.items()
        if param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
    ]
    kwargonly = [
        (name, param)
        for name, param in sig.parameters.items()
        if param.kind == inspect.Parameter.KEYWORD_ONLY
    ]

    def error():
        raise ValueError(
            f"custom_op(..., manual_schema)(func): When passing in a manual "
            f"schema, we expect `func`'s signature to match `manual_schema` "
            f"(aside from type annotations). "
            f"func's signature: {sig}, manual_schema: {schema}"
        )

    def error_default_args():
        raise ValueError(
            f"custom_op(..., manual_schema)(func): "
            f"neither func nor manual_schema should have default arguments. "
            f"Got func's signature: {sig}, manual_schema: {schema}"
        )

    def compare(sig_args, schema_args):
        if len(sig_args) != len(schema_args):
            error()
        for (name, param), arg in zip(sig_args, schema_args):
            if name != arg.name:
                error()
            if param.default is not inspect.Parameter.empty or arg.default is not None:
                error_default_args()

    compare(positional, schema.arguments.flat_positional)
    compare(kwargonly, schema.arguments.flat_kwarg_only)


def report_error_callback(custom_op: typing.Any, key: str) -> None:
    if key == "Undefined":
        raise NotImplementedError(
            f"{custom_op}: There were no Tensor inputs to this operator "
            f"(e.g. you passed an empty list of Tensors). If your operator is a "
            f"factory function (that is, it takes no Tensors and constructs "
            f"a new one), then please use CustomOp.impl_factory to register "
            f"an implementation for it"
        )
    if key == "Meta":
        raise NotImplementedError(
            f"{custom_op}: when running with device='Meta' tensors: there is no "
            f"abstract impl registered for this CustomOp. Please register one via "
            f"CustomOp.impl_abstract to get this CustomOp to work with Meta tensors"
        )
    if key in ("CPU", "CUDA"):
        device = key.lower()
        raise NotImplementedError(
            f"{custom_op}: when running with device='{device}' tensors: there is no "
            f"{device} impl registered for this CustomOp. Please register one via "
            f"CustomOp.impl(device_type='{device}')"
        )
    raise NotImplementedError(
        f"{custom_op}: No implementation for dispatch key {key}. It is likely "
        f"that we have not added this functionality yet, please either open an "
        f"issue or if you're feeling adventurous, use the low-level "
        f"torch.library API"
    )


def custom_op_from_existing(op):
    ns = op.namespace
    lib = torch.library.Library(ns, "FRAGMENT")
    name = op.name().split("::")[-1]
    schema_str = str(op._schema)
    # CustomOp expects the schema string without the namespace prefix.
    schema_str = schema_str.split("::")[-1]
    schema = FunctionSchema.parse(schema_str)
    return CustomOp(lib, ns, schema, name, op, _private_access=True)


def get_op(qualname):
    def error_not_found():
        raise ValueError(
            f"Could not find the operator {qualname}. Please make sure you have "
            f"already registered the operator and (if registered from C++) "
            f"loaded it via torch.ops.load_library."
        )

    ns, name = parse_qualname(qualname)
    if not hasattr(torch.ops, ns):
        error_not_found()
    opnamespace = getattr(torch.ops, ns)
    if not hasattr(opnamespace, name):
        error_not_found()
    packet = getattr(opnamespace, name)
    if not hasattr(packet, "default"):
        error_not_found()
    return packet.default


def _find_custom_op(qualname, also_check_torch_library=False):
    if qualname in global_registry:
        return global_registry[qualname]
    if not also_check_torch_library:
        raise RuntimeError(
            f'Could not find custom op "{qualname}". Did you register it via '
            f"the torch._custom_ops API?"
        )
    overload = get_op(qualname)
    result = custom_op_from_existing(overload)
    return result


def get_abstract_impl(qualname):
    if qualname not in torch._custom_op.impl.global_registry:
        return None
    custom_op = torch._custom_op.impl.global_registry[qualname]
    if custom_op is None:
        return None
    if not custom_op._has_impl("abstract"):
        return None
    return custom_op._get_impl("abstract").func


def _custom_op_with_schema(qualname, schema, needs_fixed_stride_order=True):
    ns, name = qualname.split("::")
    schema_str = f"{name}{schema}"
    function_schema = FunctionSchema.parse(schema_str)
    validate_schema(function_schema)
    tags = [torch._C.Tag.needs_fixed_stride_order] if needs_fixed_stride_order else []
    lib = library.Library(ns, "FRAGMENT")
    lib.define(schema_str, tags=tags)
    ophandle = find_ophandle_or_throw(ns, function_schema.name)
    result = CustomOp(lib, ns, function_schema, name, ophandle, _private_access=True)
    result._register_autograd_kernel_indirection()

    torch._C._dispatch_set_report_error_callback(
        ophandle, functools.partial(report_error_callback, weakref.proxy(result))
    )
    return get_op(qualname)