r"""Quantized convolution modules."""

from typing import Optional, List, TypeVar

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.intrinsic.qat as nniqat

from torch._ops import ops
from torch.nn.common_types import _size_1_t
from torch.nn.modules.utils import _single, _pair, _triple
from torch.nn.utils import fuse_conv_bn_weights

from .utils import _quantize_weight, WeightedQuantizedModule

__all__ = ['Conv1d', 'Conv2d', 'Conv3d', 'ConvTranspose1d', 'ConvTranspose2d', 'ConvTranspose3d']

_SUPPORTED_PADDING = {'zeros', 'reflect'}


def _reverse_repeat_padding(padding: List[int]) -> List[int]:
    _reversed_padding_repeated_twice: List[int] = []
    N = len(padding)
    for idx in range(N):
        for _ in range(2):
            _reversed_padding_repeated_twice.append(padding[N - idx - 1])
    return _reversed_padding_repeated_twice


class _ConvNd(WeightedQuantizedModule):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 padding_mode='zeros', device=None, dtype=None):
        # All subclasses share this signature; they must call _init instead.
        raise NotImplementedError

    def _init(self, in_channels, out_channels, kernel_size, stride,
              padding, dilation, transposed, output_padding,
              groups, bias, padding_mode='zeros',
              device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()

        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.transposed = transposed
        self.output_padding = output_padding
        self.groups = groups
        if padding_mode not in _SUPPORTED_PADDING:
            raise ValueError("'padding_mode' {} is not supported by quantized convolution".format(padding_mode))
        self.padding_mode = padding_mode
        if self.transposed:
            weight_shape = [in_channels, out_channels // self.groups]
        else:
            weight_shape = [out_channels, in_channels // self.groups]
        qweight = torch._empty_affine_quantized(
            weight_shape + list(kernel_size),
            scale=1, zero_point=0, dtype=torch.qint8,
            **{k: v for k, v in factory_kwargs.items() if k != 'dtype'})
        bias_float = (
            torch.zeros(out_channels, dtype=torch.float,
                        **{k: v for k, v in factory_kwargs.items() if k != 'dtype'})
            if bias else None)

        self.set_weight_bias(qweight, bias_float)
        self.scale = 1.0
        self.zero_point = 0

    def set_weight_bias(self, qweight, bias_float):
        raise NotImplementedError

    def bias(self):
        raise NotImplementedError

    def _weight_bias(self):
        raise NotImplementedError

    def extra_repr(self):
        s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}, scale={scale}, zero_point={zero_point}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0,) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias() is None:
            s += ', bias=False'
        return s.format(**self.__dict__)

    # ===== Serialization methods =====
    # Packed weights should not live outside the process in which they were
    # created; serialize the unpacked QTensor weight and bias instead.
    def _save_to_state_dict(self, destination, prefix, keep_vars):
        super()._save_to_state_dict(destination, prefix, keep_vars)
        (w, b) = self._weight_bias()
        destination[prefix + 'weight'] = w
        destination[prefix + 'bias'] = b
        destination[prefix + 'scale'] = torch.tensor(self.scale)
        destination[prefix + 'zero_point'] = torch.tensor(self.zero_point)

    @torch.jit.export
    def __getstate__(self):
        (w, b) = self._weight_bias()
        return (
            self.in_channels,
            self.out_channels,
            self.kernel_size,
            self.stride,
            self.padding,
            self.dilation,
            self.transposed,
            self.output_padding,
            self.groups,
            self.padding_mode,
            w,
            b,
            self.scale,
            self.zero_point,
            self.training,
        )

    # ===== Deserialization methods =====
    # Counterpart to the serialization methods: re-pack the QTensor weight for
    # the quantized backend ops.
    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        self.set_weight_bias(
            state_dict[prefix + 'weight'], state_dict[prefix + 'bias'])
        state_dict.pop(prefix + 'weight')
        state_dict.pop(prefix + 'bias')
        self.scale = float(state_dict[prefix + 'scale'])
        state_dict.pop(prefix + 'scale')
        self.zero_point = int(state_dict[prefix + 'zero_point'])
        state_dict.pop(prefix + 'zero_point')
        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, False, missing_keys,
            unexpected_keys, error_msgs)

    @torch.jit.export
    def __setstate__(self, state):
        self.in_channels = state[0]
        self.out_channels = state[1]
        self.kernel_size = state[2]
        self.stride = state[3]
        self.padding = state[4]
        self.dilation = state[5]
        self.transposed = state[6]
        self.output_padding = state[7]
        self.groups = state[8]
        self.padding_mode = state[9]
        self.set_weight_bias(state[10], state[11])
        self.scale = state[12]
        self.zero_point = state[13]
        self.training = state[14]

    def __deepcopy__(self, memo):
        new_instance = type(self).__new__(type(self))
        torch.nn.Module.__init__(new_instance)
        state = self.__getstate__()
        new_instance.__setstate__(state)
        return new_instance

    def __copy__(self):
        return self.__deepcopy__({})

    @classmethod
    def get_qconv(cls, mod, activation_post_process, weight_post_process=None):
        r"""Creates a qconv object and returns it."""
        if weight_post_process is None:
            weight_post_process = mod.qconfig.weight()
        weight_post_process(mod.weight)
        assert weight_post_process.dtype == torch.qint8, \
            'Weight observer must have a dtype of qint8'
        qweight = _quantize_weight(mod.weight.float(), weight_post_process)
        # the __init__ call used is the one from derived classes and not the one from _ConvNd
        qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,
                    mod.stride, mod.padding, mod.dilation, mod.groups,
                    mod.bias is not None, mod.padding_mode)
        qconv.set_weight_bias(qweight, mod.bias)
        if activation_post_process is None or activation_post_process.dtype == torch.float:
            return qconv  # dynamic quantization doesn't need scale/zero_point
        else:
            act_scale, act_zp = activation_post_process.calculate_qparams()
            qconv.scale = float(act_scale)
            qconv.zero_point = int(act_zp)
            return qconv

    @staticmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        if hasattr(mod, 'weight_fake_quant'):
            if type(mod) == cls._NNIQAT_CONV_BN_MODULE:
                mod.weight, mod.bias = fuse_conv_bn_weights(
                    mod.weight, mod.bias, mod.bn.running_mean, mod.bn.running_var,
                    mod.bn.eps, mod.bn.weight, mod.bn.bias)
            assert hasattr(mod, 'activation_post_process'), \
                'Input QAT module must have observer attached'
            weight_post_process = mod.weight_fake_quant
            activation_post_process = mod.activation_post_process
        else:
            assert type(mod) == cls._FLOAT_MODULE, \
                ' nnq.' + cls.__name__ + '.from_float only works for ' + \
                cls._FLOAT_MODULE.__name__ + ' but got:' + str(type(mod))
            assert hasattr(mod, 'qconfig'), \
                'Input float module must have qconfig defined.'
            activation_post_process = None if not hasattr(
                mod, 'activation_post_process') else mod.activation_post_process
            if type(mod) in [cls._NNI_CONV_RELU_MODULE, cls._NNI_CONV_ADD_MODULE, cls._NNI_CONV_ADD_RELU_MODULE]:
                mod = mod[0]
            weight_post_process = mod.qconfig.weight()
        return cls.get_qconv(mod, activation_post_process, weight_post_process)

    @classmethod
    def from_reference(cls, ref_qconv, output_scale, output_zero_point):
        r"""Create a (fbgemm/qnnpack) quantized module from a reference quantized module
        Args:
            ref_qconv (Module): a reference quantized module, either produced by torch.ao.quantization
                                utilities or provided by the user
            output_scale (float): scale for output Tensor
            output_zero_point (int): zero point for output Tensor
        """
        qconv = cls(
            ref_qconv.in_channels,
            ref_qconv.out_channels,
            ref_qconv.kernel_size,
            ref_qconv.stride,
            ref_qconv.padding,
            ref_qconv.dilation,
            ref_qconv.groups,
            ref_qconv.bias is not None,
            ref_qconv.padding_mode,
            device=ref_qconv.weight.device,
            dtype=ref_qconv.weight.dtype)
        qweight = ref_qconv.get_quantized_weight()
        qconv.set_weight_bias(qweight, ref_qconv.bias)
        qconv.scale = float(output_scale)
        qconv.zero_point = int(output_zero_point)
        return qconv


class Conv1d(_ConvNd):
    r"""Applies a 1D convolution over a quantized input signal composed of
    several quantized input planes.

    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.Conv1d`.

    .. note::
        Only `zeros` is supported for the :attr:`padding_mode` argument.

    .. note::
        Only `torch.quint8` is supported for the input data type.


    Attributes:
        weight (Tensor):     packed tensor derived from the learnable weight
                             parameter.
        scale (Tensor):      scalar for the output scale
        zero_point (Tensor): scalar for the output zero point

    See :class:`~torch.nn.Conv1d` for other attributes.

    Examples::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
        >>> m = nn.quantized.Conv1d(16, 33, 3, stride=2)
        >>> input = torch.randn(20, 16, 100)
        >>> # quantize input to quint8
        >>> # xdoctest: +SKIP
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0,
        ...                                     dtype=torch.quint8)
        >>> output = m(q_input)

    """

    _FLOAT_MODULE = nn.Conv1d
    _NNIQAT_CONV_BN_MODULE = nniqat.ConvBn1d
    _NNI_CONV_RELU_MODULE = nni.ConvReLU1d
    _NNI_CONV_ADD_MODULE: None = None
    _NNI_CONV_ADD_RELU_MODULE: None = None

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: _size_1_t,
                 stride: _size_1_t = 1,
                 padding: _size_1_t = 0,
                 dilation: _size_1_t = 1,
                 groups: int = 1,
                 bias: bool = True,
                 padding_mode: str = 'zeros',
                 device=None,
                 dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        kernel_size = _single(kernel_size)
        stride = _single(stride)
        padding = padding if isinstance(padding, str) else _single(padding)
        dilation = _single(dilation)

        # Subclasses of _ConvNd need to call _init rather than __init__. See
        # discussion on PR #49702
        super()._init(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _single(0), groups, bias, padding_mode, **factory_kwargs)

    def _get_name(self):
        return 'QuantizedConv1d'

    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
        if self.padding_mode == 'zeros':
            self._packed_params = torch.ops.quantized.conv1d_prepack(
                w, b, self.stride, self.padding, self.dilation, self.groups)
        else:
            self._packed_params = torch.ops.quantized.conv1d_prepack(
                w, b, self.stride, _pair(0), self.dilation, self.groups)

    def _weight_bias(self):
        w, b = torch.ops.quantized.conv1d_unpack(self._packed_params)
        return w, b

    def weight(self):
        return self._weight_bias()[0]

    def bias(self):
        return self._weight_bias()[1]

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 3:
            raise ValueError("Input shape must be `(N, C, L)`!")
        if self.padding_mode != 'zeros':
            # Padding in Conv1d is stored as (p, p); only (p,) is needed here.
            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding[:1])
            input = F.pad(input, _reversed_padding_repeated_twice,
                          mode=self.padding_mode)
        return ops.quantized.conv1d(input, self._packed_params, self.scale, self.zero_point)

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        r"""Creates a quantized module from a float module or qparams_dict.

        Args:
            mod (Module): a float module, either produced by torch.ao.quantization
              utilities or provided by the user
        """
        return _ConvNd.from_float(cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
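
# --- Editor's example (not part of the original torch source) ---------------
# A minimal sketch of the `padding_mode='reflect'` path in `Conv1d.forward`
# above: the quantized input is reflect-padded eagerly via `F.pad`, using the
# reversed, twice-repeated padding list built by `_reverse_repeat_padding`,
# and the packed kernel then runs with zero padding. The helper below is
# hypothetical and is never called at import time; it assumes a build with a
# quantized engine (e.g. fbgemm or qnnpack) available.
def _example_conv1d_reflect_padding():
    # `_reverse_repeat_padding` reverses the per-dimension padding and repeats
    # each entry twice, matching the last-dim-first layout `F.pad` expects:
    assert _reverse_repeat_padding([2]) == [2, 2]
    assert _reverse_repeat_padding([1, 3]) == [3, 3, 1, 1]
    q_input = torch.quantize_per_tensor(
        torch.randn(1, 4, 8), scale=0.1, zero_point=0, dtype=torch.quint8)
    m = Conv1d(4, 4, 3, padding=1, padding_mode='reflect')
    return m(q_input)  # reflect-pad first, then a zero-padding packed conv
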
class Conv2d(_ConvNd):
    r"""Applies a 2D convolution over a quantized input signal composed of
    several quantized input planes.

    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.Conv2d`.

    .. note::
        Only `zeros` is supported for the :attr:`padding_mode` argument.

    .. note::
        Only `torch.quint8` is supported for the input data type.


    Attributes:
        weight (Tensor):     packed tensor derived from the learnable weight
                             parameter.
        scale (Tensor):      scalar for the output scale
        zero_point (Tensor): scalar for the output zero point

    See :class:`~torch.nn.Conv2d` for other attributes.

    Examples::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
        >>> # With square kernels and equal stride
        >>> m = nn.quantized.Conv2d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = nn.quantized.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
        >>> # non-square kernels and unequal stride and with padding and dilation
        >>> m = nn.quantized.Conv2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2), dilation=(3, 1))
        >>> input = torch.randn(20, 16, 50, 100)
        >>> # quantize input to quint8
        >>> # xdoctest: +SKIP
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> output = m(q_input)

    """

    _FLOAT_MODULE = nn.Conv2d
    _NNIQAT_CONV_BN_MODULE = nniqat.ConvBn2d
    _NNI_CONV_RELU_MODULE = nni.ConvReLU2d
    _NNI_CONV_ADD_MODULE = nni.ConvAdd2d
    _NNI_CONV_ADD_RELU_MODULE = nni.ConvAddReLU2d

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 padding_mode='zeros', device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        # Subclasses of _ConvNd need to call _init rather than __init__. See
        # discussion on PR #49702
        super()._init(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _pair(0), groups, bias, padding_mode, **factory_kwargs)

    def _get_name(self):
        return 'QuantizedConv2d'

    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
        if self.padding_mode == 'zeros':
            self._packed_params = torch.ops.quantized.conv2d_prepack(
                w, b, self.stride, self.padding, self.dilation, self.groups)
        else:
            self._packed_params = torch.ops.quantized.conv2d_prepack(
                w, b, self.stride, _pair(0), self.dilation, self.groups)

    def _weight_bias(self):
        return self._packed_params.unpack()

    def weight(self):
        return self._weight_bias()[0]

    def bias(self):
        return self._weight_bias()[1]

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 4:
            raise ValueError("Input shape must be `(N, C, H, W)`!")
        if self.padding_mode != 'zeros':
            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
            input = F.pad(input, _reversed_padding_repeated_twice,
                          mode=self.padding_mode)
        return ops.quantized.conv2d(
            input, self._packed_params, self.scale, self.zero_point)

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        r"""Creates a quantized module from a float module or qparams_dict.

        Args:
            mod (Module): a float module, either produced by torch.ao.quantization
              utilities or provided by the user
        """
        return _ConvNd.from_float(cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
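
# --- Editor's example (not part of the original torch source) ---------------
# A sketch of the packed-weight round trip described in
# `_ConvNd._save_to_state_dict`: `set_weight_bias` packs a qint8 weight for
# the active backend, while `_weight_bias()` (and the `weight()`/`bias()`
# accessors) unpack it back into a regular quantized tensor, keeping state
# dicts backend-independent. Hypothetical helper, never called at import time.
def _example_conv2d_weight_roundtrip():
    m = Conv2d(8, 16, 3, stride=1, padding=1)
    wq = torch.quantize_per_tensor(
        torch.randn(16, 8, 3, 3), scale=0.02, zero_point=0, dtype=torch.qint8)
    m.set_weight_bias(wq, torch.zeros(16))
    w, b = m._weight_bias()  # unpacks to (quantized weight, float bias)
    assert w.shape == (16, 8, 3, 3) and w.dtype == torch.qint8
    return m.state_dict()    # contains 'weight', 'bias', 'scale', 'zero_point'
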
class Conv3d(_ConvNd):
    r"""Applies a 3D convolution over a quantized input signal composed of
    several quantized input planes.

    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.Conv3d`.

    .. note::
        Only `zeros` is supported for the :attr:`padding_mode` argument.

    .. note::
        Only `torch.quint8` is supported for the input data type.


    Attributes:
        weight (Tensor):     packed tensor derived from the learnable weight
                             parameter.
        scale (Tensor):      scalar for the output scale
        zero_point (Tensor): scalar for the output zero point

    See :class:`~torch.nn.Conv3d` for other attributes.

    Examples::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
        >>> # With square kernels and equal stride
        >>> m = nn.quantized.Conv3d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = nn.quantized.Conv3d(16, 33, (3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2))
        >>> # non-square kernels and unequal stride and with padding and dilation
        >>> m = nn.quantized.Conv3d(16, 33, (3, 5, 5), stride=(1, 2, 2), padding=(1, 2, 2), dilation=(1, 2, 2))
        >>> input = torch.randn(20, 16, 56, 56, 56)
        >>> # quantize input to quint8
        >>> # xdoctest: +SKIP
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> output = m(q_input)

    """

    _FLOAT_MODULE = nn.Conv3d
    _NNIQAT_CONV_BN_MODULE = nniqat.ConvBn3d
    _NNI_CONV_RELU_MODULE = nni.ConvReLU3d
    _NNI_CONV_ADD_MODULE: None = None
    _NNI_CONV_ADD_RELU_MODULE: None = None

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True,
                 padding_mode='zeros', device=None, dtype=None):
        assert padding_mode != 'reflect', "Conv3d does not support reflection padding"
        factory_kwargs = {'device': device, 'dtype': dtype}
        kernel_size = _triple(kernel_size)
        stride = _triple(stride)
        padding = _triple(padding)
        dilation = _triple(dilation)
        # Subclasses of _ConvNd need to call _init rather than __init__. See
        # discussion on PR #49702
        super()._init(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _triple(0), groups, bias, padding_mode, **factory_kwargs)

    def _get_name(self):
        return 'QuantizedConv3d'

    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
        if self.padding_mode == 'zeros':
            self._packed_params = torch.ops.quantized.conv3d_prepack(
                w, b, self.stride, self.padding, self.dilation, self.groups)
        else:
            self._packed_params = torch.ops.quantized.conv3d_prepack(
                w, b, self.stride, _triple(0), self.dilation, self.groups)

    def _weight_bias(self):
        return self._packed_params.unpack()

    def weight(self):
        return self._weight_bias()[0]

    def bias(self):
        return self._weight_bias()[1]

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 5:
            raise ValueError("Input shape must be `(N, C, D, H, W)`!")
        if self.padding_mode != 'zeros':
            _reversed_padding_repeated_twice = _reverse_repeat_padding(self.padding)
            input = F.pad(input, _reversed_padding_repeated_twice,
                          mode=self.padding_mode)
        return ops.quantized.conv3d(
            input, self._packed_params, self.scale, self.zero_point)

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        r"""Creates a quantized module from a float module or qparams_dict.

        Args:
            mod (Module): a float module, either produced by torch.ao.quantization
              utilities or provided by the user
        """
        return _ConvNd.from_float(cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant)
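
# --- Editor's example (not part of the original torch source) ---------------
# Unlike `Conv1d`/`Conv2d`, the 3D variant rejects reflection padding outright
# (the assert in `Conv3d.__init__` above), so only zero padding ever reaches
# the packed kernel. Hypothetical helper, never called at import time.
def _example_conv3d_padding_guard():
    try:
        Conv3d(4, 4, 3, padding=1, padding_mode='reflect')
    except AssertionError as e:
        return str(e)  # "Conv3d does not support reflection padding"
    raise RuntimeError("expected reflection padding to be rejected")
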
MOD = TypeVar('MOD', bound=nn.modules.conv._ConvNd)


class _ConvTransposeNd(_ConvNd):

    _FLOAT_MODULE = MOD

    def __init__(self, in_channels, out_channels, kernel_size, stride,
                 padding, dilation, transposed, output_padding,
                 groups, bias, padding_mode, device=None, dtype=None):
        if padding_mode != 'zeros':
            raise ValueError('Only "zeros" padding mode is supported for {}'.format(self.__class__.__name__))
        factory_kwargs = {'device': device, 'dtype': dtype}
        # Subclasses of _ConvNd need to call _init rather than __init__. See
        # discussion on PR #49702
        super()._init(
            in_channels, out_channels, kernel_size, stride,
            padding, dilation, transposed, output_padding,
            groups, bias, padding_mode, **factory_kwargs)

    def _input_padding(self, kernel_size: List[int], dilation: List[int],
                       padding: List[int]) -> List[int]:
        res = torch.jit.annotate(List[int], [])
        for kdx in range(len(kernel_size)):
            pad = (dilation[kdx] * (kernel_size[kdx] - 1) - padding[kdx])
            res.append(pad)
        return res

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        r"""Creates a quantized module from a float module or qparams_dict.
        Args:
            mod (Module): a float module, either produced by torch.ao.quantization
              utilities or provided by the user
        """
        # derived classes override cls._FLOAT_MODULE attribute
        msg = ' nnq.' + cls.__name__ + '.from_float only works for ' + cls._FLOAT_MODULE.__name__
        assert type(mod) == cls._FLOAT_MODULE, msg
        assert hasattr(mod, 'qconfig'), \
            'Input float module must have qconfig defined.'
        weight_post_process = mod.qconfig.weight()
        weight_post_process(mod.weight)
        assert weight_post_process.dtype == torch.qint8, \
            'Weight observer must have a dtype of qint8'
        qweight = _quantize_weight(mod.weight.float(), weight_post_process)
        # the __init__ call used is the one from derived classes and not the one from _ConvTransposeNd
        qconv = cls(mod.in_channels, mod.out_channels, mod.kernel_size,
                    mod.stride, mod.padding, mod.output_padding, mod.groups,
                    mod.bias is not None, mod.dilation, mod.padding_mode)
        qconv.set_weight_bias(qweight, mod.bias)
        if (not hasattr(mod, 'activation_post_process')
                or mod.activation_post_process.dtype == torch.float):
            return qconv  # dynamic quantization doesn't need scale/zero_point
        else:
            act_scale, act_zp = mod.activation_post_process.calculate_qparams()
            qconv.scale = float(act_scale)
            qconv.zero_point = int(act_zp)
            return qconv

    @staticmethod
    def from_reference(cls, ref_qconvt, output_scale, output_zero_point):
        r"""Create a (fbgemm/qnnpack) quantized module from a reference quantized module
        Args:
            ref_qconvt (Module): a reference quantized module, either produced by torch.ao.quantization
                                 utilities or provided by the user
            output_scale (float): scale for output Tensor
            output_zero_point (int): zero point for output Tensor
        """
        qconv = cls(
            ref_qconvt.in_channels,
            ref_qconvt.out_channels,
            ref_qconvt.kernel_size,
            ref_qconvt.stride,
            ref_qconvt.padding,
            ref_qconvt.output_padding,
            ref_qconvt.groups,
            ref_qconvt.bias is not None,
            ref_qconvt.dilation,
            ref_qconvt.padding_mode,
            device=ref_qconvt.weight.device,
            dtype=ref_qconvt.weight.dtype)
        qweight = ref_qconvt.get_quantized_weight()
        qconv.set_weight_bias(qweight, ref_qconvt.bias)
        qconv.scale = float(output_scale)
        qconv.zero_point = int(output_zero_point)
        return qconv


class ConvTranspose1d(_ConvTransposeNd):
    r"""Applies a 1D transposed convolution operator over an input image
    composed of several input planes.
    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.ConvTranspose1d`.

    .. note:: Currently only the QNNPACK engine is implemented.
        Please set `torch.backends.quantized.engine = 'qnnpack'`

    For special notes, please, see :class:`~torch.ao.nn.quantized.Conv1d`

    Attributes:
        weight (Tensor):     packed tensor derived from the learnable weight
                             parameter.
        scale (Tensor):      scalar for the output scale
        zero_point (Tensor): scalar for the output zero point
    See :class:`~torch.nn.ConvTranspose1d` for other attributes.

    Examples::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
        >>> torch.backends.quantized.engine = 'qnnpack'
        >>> from torch.ao.nn import quantized as nnq
        >>> # With square kernels and equal stride
        >>> m = nnq.ConvTranspose1d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = nnq.ConvTranspose1d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
        >>> input = torch.randn(20, 16, 50)
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> output = m(q_input)
        >>> # the exact output size can also be specified as an argument
        >>> input = torch.randn(1, 16, 12)
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> downsample = nnq.Conv1d(16, 16, 3, stride=2, padding=1)
        >>> upsample = nnq.ConvTranspose1d(16, 16, 3, stride=2, padding=1)
        >>> h = downsample(q_input)
        >>> h.size()
        torch.Size([1, 16, 6])
        >>> # xdoctest: +SKIP("FIXME: output_size is not a parameter")
        >>> output = upsample(h, output_size=input.size())
        >>> output.size()
        torch.Size([1, 16, 12])
    """

    _FLOAT_MODULE = nn.ConvTranspose1d

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, output_padding=0, groups=1, bias=True,
                 dilation=1, padding_mode='zeros', device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        kernel_size = _single(kernel_size)
        stride = _single(stride)
        padding = _single(padding)
        dilation = _single(dilation)
        output_padding = _single(output_padding)
        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            True, output_padding, groups, bias, padding_mode, **factory_kwargs)

    def _get_name(self):
        return 'QuantizedConvTranspose1d'

    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
        self._packed_params = torch.ops.quantized.conv_transpose1d_prepack(
            w, b, self.stride, self.padding, self.output_padding,
            self.dilation, self.groups)

    def _weight_bias(self):
        w, b = torch.ops.quantized.conv_transpose1d_unpack(self._packed_params)
        return w, b

    def weight(self):
        (w, _) = self._weight_bias()
        return w

    def bias(self):
        (_, b) = self._weight_bias()
        return b

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 3:
            raise ValueError("Input shape must be `(N, C, L)`!")
        return torch.ops.quantized.conv_transpose1d(
            input, self._packed_params, self.scale, self.zero_point)

    @classmethod
    def from_reference(cls, ref_qconvt, output_scale, output_zero_point):
        return _ConvTransposeNd.from_reference(cls, ref_qconvt, output_scale,
                                               output_zero_point)
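
# --- Editor's example (not part of the original torch source) ---------------
# `_ConvTransposeNd._input_padding` above computes, per spatial dimension,
# dilation * (kernel_size - 1) - padding: the implicit input padding seen by
# the equivalent forward convolution. It never touches `self`, so it can be
# sanity-checked without instantiating a module (which would require a
# backend that supports transposed-conv prepacking):
def _example_input_padding():
    # dilation=1, kernel=3, padding=1  ->  1 * (3 - 1) - 1 == 1
    assert _ConvTransposeNd._input_padding(None, [3], [1], [1]) == [1]
    # dilation=2, kernel=5, padding=0  ->  2 * (5 - 1) - 0 == 8
    assert _ConvTransposeNd._input_padding(None, [5], [2], [0]) == [8]
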
class ConvTranspose2d(_ConvTransposeNd):
    r"""Applies a 2D transposed convolution operator over an input image
    composed of several input planes.
    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.ConvTranspose2d`.

    For special notes, please, see :class:`~torch.ao.nn.quantized.Conv2d`

    Attributes:
        weight (Tensor):     packed tensor derived from the learnable weight
                             parameter.
        scale (Tensor):      scalar for the output scale
        zero_point (Tensor): scalar for the output zero point
    See :class:`~torch.nn.ConvTranspose2d` for other attributes.

    Examples::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
        >>> # QNNPACK or FBGEMM as backend
        >>> torch.backends.quantized.engine = 'qnnpack'
        >>> # With square kernels and equal stride
        >>> import torch.ao.nn.quantized as nnq
        >>> m = nnq.ConvTranspose2d(16, 33, 3, stride=2)
        >>> # non-square kernels and unequal stride and with padding
        >>> m = nnq.ConvTranspose2d(16, 33, (3, 5), stride=(2, 1), padding=(4, 2))
        >>> input = torch.randn(20, 16, 50, 100)
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> output = m(q_input)
        >>> # the exact output size can also be specified as an argument
        >>> input = torch.randn(1, 16, 12, 12)
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> downsample = nnq.Conv2d(16, 16, 3, stride=2, padding=1)
        >>> upsample = nnq.ConvTranspose2d(16, 16, 3, stride=2, padding=1)
        >>> h = downsample(q_input)
        >>> h.size()
        torch.Size([1, 16, 6, 6])
        >>> # xdoctest: +SKIP("FIXME: output_size is not a parameter")
        >>> output = upsample(h, output_size=input.size())
        >>> output.size()
        torch.Size([1, 16, 12, 12])
    """

    _FLOAT_MODULE = nn.ConvTranspose2d

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, output_padding=0, groups=1, bias=True,
                 dilation=1, padding_mode='zeros', device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        output_padding = _pair(output_padding)
        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            True, output_padding, groups, bias, padding_mode, **factory_kwargs)

    def _get_name(self):
        return 'QuantizedConvTranspose2d'

    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
        self._packed_params = torch.ops.quantized.conv_transpose2d_prepack(
            w, b, self.stride, self.padding, self.output_padding,
            self.dilation, self.groups)

    def _weight_bias(self):
        w, b = torch.ops.quantized.conv2d_unpack(self._packed_params)
        return w, b

    def weight(self):
        (w, _) = self._weight_bias()
        return w

    def bias(self):
        (_, b) = self._weight_bias()
        return b

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 4:
            raise ValueError("Input shape must be `(N, C, H, W)`!")
        return ops.quantized.conv_transpose2d(
            input, self._packed_params, self.scale, self.zero_point)

    @classmethod
    def from_reference(cls, ref_qconvt, output_scale, output_zero_point):
        return _ConvTransposeNd.from_reference(cls, ref_qconvt, output_scale,
                                               output_zero_point)
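
# --- Editor's example (not part of the original torch source) ---------------
# A sketch of the `from_reference` lowering path for the transposed modules:
# a reference quantized module (float weight plus explicit weight qparams,
# from torch.ao.nn.quantized.reference) is converted into the packed backend
# module, with the output scale/zero_point supplied by activation observation.
# The qparams below are placeholder values, and the exact reference-module
# API used here is an assumption of this sketch.
def _example_convtranspose2d_from_reference():
    import torch.ao.nn.quantized.reference as nnqr
    float_m = nn.ConvTranspose2d(8, 16, 3, stride=2, padding=1)
    weight_qparams = {
        'qscheme': torch.per_tensor_affine,
        'dtype': torch.qint8,
        'scale': 0.02,
        'zero_point': 0,
    }
    ref_m = nnqr.ConvTranspose2d.from_float(float_m, weight_qparams)
    return ConvTranspose2d.from_reference(
        ref_m, output_scale=0.1, output_zero_point=0)
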
class ConvTranspose3d(_ConvTransposeNd):
    r"""Applies a 3D transposed convolution operator over an input image
    composed of several input planes.
    For details on input arguments, parameters, and implementation see
    :class:`~torch.nn.ConvTranspose3d`.

    .. note:: Currently only the FBGEMM engine is implemented.
        Please set `torch.backends.quantized.engine = 'fbgemm'`

    For special notes, please, see :class:`~torch.ao.nn.quantized.Conv3d`

    Attributes:
        weight (Tensor):     packed tensor derived from the learnable weight
                             parameter.
        scale (Tensor):      scalar for the output scale
        zero_point (Tensor): scalar for the output zero point
    See :class:`~torch.nn.ConvTranspose3d` for other attributes.

    Examples::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_QENGINE)
        >>> torch.backends.quantized.engine = 'fbgemm'
        >>> from torch.ao.nn import quantized as nnq
        >>> # With cubic kernels and equal stride
        >>> m = nnq.ConvTranspose3d(16, 33, 3, stride=2)
        >>> # non-cubic kernels and unequal stride and with padding
        >>> m = nnq.ConvTranspose3d(16, 33, (3, 3, 5), stride=(2, 1, 1), padding=(4, 2, 2))
        >>> input = torch.randn(20, 16, 50, 100, 100)
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> output = m(q_input)
        >>> # the exact output size can also be specified as an argument
        >>> input = torch.randn(1, 16, 12, 12, 12)
        >>> q_input = torch.quantize_per_tensor(input, scale=1.0, zero_point=0, dtype=torch.quint8)
        >>> downsample = nnq.Conv3d(16, 16, 3, stride=2, padding=1)
        >>> upsample = nnq.ConvTranspose3d(16, 16, 3, stride=2, padding=1)
        >>> h = downsample(q_input)
        >>> h.size()
        torch.Size([1, 16, 6, 6, 6])
        >>> # xdoctest: +SKIP("FIXME: output_size is not a parameter")
        >>> output = upsample(h, output_size=input.size())
        >>> output.size()
        torch.Size([1, 16, 12, 12, 12])
    """

    _FLOAT_MODULE = nn.ConvTranspose3d

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, output_padding=0, groups=1, bias=True,
                 dilation=1, padding_mode='zeros', device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        kernel_size = _triple(kernel_size)
        stride = _triple(stride)
        padding = _triple(padding)
        dilation = _triple(dilation)
        output_padding = _triple(output_padding)
        super().__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            True, output_padding, groups, bias, padding_mode, **factory_kwargs)

    def _get_name(self):
        return 'QuantizedConvTranspose3d'

    def set_weight_bias(self, w: torch.Tensor, b: Optional[torch.Tensor]) -> None:
        self._packed_params = torch.ops.quantized.conv_transpose3d_prepack(
            w, b, self.stride, self.padding, self.output_padding,
            self.dilation, self.groups)

    def _weight_bias(self):
        w, b = torch.ops.quantized.conv3d_unpack(self._packed_params)
        return w, b

    def weight(self):
        (w, _) = self._weight_bias()
        return w

    def bias(self):
        (_, b) = self._weight_bias()
        return b

    def forward(self, input):
        # Temporarily using len(shape) instead of ndim due to JIT issue
        # https://github.com/pytorch/pytorch/issues/23890
        if len(input.shape) != 5:
            raise ValueError("Input shape must be `(N, C, T, H, W)`!")
        return ops.quantized.conv_transpose3d(
            input, self._packed_params, self.scale, self.zero_point)

    @classmethod
    def from_reference(cls, ref_qconvt, output_scale, output_zero_point):
        return _ConvTransposeNd.from_reference(cls, ref_qconvt, output_scale,
                                               output_zero_point)
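
# --- Editor's example (not part of the original torch source) ---------------
# In practice these classes are rarely constructed by hand; eager-mode
# post-training quantization creates them via `Conv2d.from_float` when
# `torch.ao.quantization.convert` swaps out the float modules. A minimal
# sketch (assumes the default qengine, e.g. fbgemm on x86):
def _example_eager_mode_flow():
    from torch.ao.quantization import get_default_qconfig, prepare, convert

    class _SmallNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.quant = torch.ao.quantization.QuantStub()
            self.conv = torch.nn.Conv2d(3, 8, 3)
            self.dequant = torch.ao.quantization.DeQuantStub()

        def forward(self, x):
            return self.dequant(self.conv(self.quant(x)))

    m = _SmallNet().eval()
    m.qconfig = get_default_qconfig('fbgemm')
    prepared = prepare(m)
    prepared(torch.randn(1, 3, 16, 16))  # calibrate the observers
    quantized = convert(prepared)        # self.conv is now QuantizedConv2d
    assert quantized.conv._get_name() == 'QuantizedConv2d'
    return quantized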