U
    T?h                     @   s  d dl Z d dlZd dlZd dlZd dlZd dlZd dlZd dlmZ d dl	m
Z
mZmZmZ d dlZd dlZd dlZd dlmZ d dlmZ d dlmZ d dlmZ d dlmZ d d	lmZ d d
lmZ d dlm Z m!Z!m"Z"m#Z# e $e%Z&dddddgZ'ej(dej)dej*diZ+G dd de"Z,G dd de#Z-G dd de"Z.G dd de!Z/G dd de!Z0e/ddfe0dd fe.d!dfd"Z1G d#d$ d$Z2G d%d& d&Z3dS )'    N)Path)DictListTupleUnion)	Precision)float_to_float16_max_diff)FusionOptions)IOBindingHelper)	OnnxModel)optimize_model)torch_onnx_export)
GPT2ConfigGPT2LMHeadModel	GPT2ModelTFGPT2ModelZ
distilgpt2gpt2zgpt2-mediumz
gpt2-largezgpt2-xlMb@?g?g      @c                       s,   e Zd ZdZ fddZ fddZ  ZS )GPT2ModelNoPastState2Here we wrap a class to disable past state output.c                    s   t  | d S Nsuper__init__selfconfig	__class__ b/var/www/html/venv/lib/python3.8/site-packages/onnxruntime/transformers/models/gpt2/gpt2_helper.pyr   +   s    zGPT2ModelNoPastState.__init__c                    s   t  j|dddS )NF)	use_cachereturn_dict)r   forwardr   	input_idsr   r   r    r#   .   s    zGPT2ModelNoPastState.forward__name__
__module____qualname____doc__r   r#   __classcell__r   r   r   r    r   (   s   r   c                       s,   e Zd ZdZ fddZ fddZ  ZS )TFGPT2ModelNoPastStater   c                    s   d|_ t | d S )NF)r!   r   r   r   r   r   r    r   5   s    zTFGPT2ModelNoPastState.__init__c                    s   t  j|ddS )NF)r!   )r   callr$   r   r   r    r#   9   s    zTFGPT2ModelNoPastState.forwardr&   r   r   r   r    r,   2   s   r,   c                       s8   e Zd ZdZ fddZedd Z fddZ  ZS )MyGPT2ModelzMHere we wrap a class for Onnx model conversion for GPT2Model with past state.c                    s   t  | d S r   r   r   r   r   r    r   @   s    zMyGPT2Model.__init__c                 C   s   t | d d ttfrt| d |kr:t| d d dks>tg }t|D ]@}|tj| d | d 	d| d | d 	dfdd qJ| d t|fS | S )N   r      )dim)

isinstancetuplelistlenAssertionErrorrangeappendtorchcatZ	unsqueeze)result	num_layerZpresentir   r   r    post_processC   s    (*zMyGPT2Model.post_processc                    s&   t  j||||dd}t|| jjS NF)position_idsattention_maskpast_key_valuesr"   r   r#   r.   r>   r   n_layerr   r%   r@   rA   pastr;   r   r   r    r#   V   s    zMyGPT2Model.forward)	r'   r(   r)   r*   r   staticmethodr>   r#   r+   r   r   r   r    r.   =   s
   
r.   c                       s,   e Zd ZdZ fddZ fddZ  ZS )MyGPT2LMHeadModelzSHere we wrap a class for Onnx model conversion for GPT2LMHeadModel with past state.c                    s   t  | d S r   r   r   r   r   r    r   d   s    zMyGPT2LMHeadModel.__init__c                    s&   t  j||||dd}t|| jjS r?   rC   rE   r   r   r    r#   g   s    zMyGPT2LMHeadModel.forwardr&   r   r   r   r    rH   a   s   rH   c                       s,   e Zd ZdZ fddZ fddZ  ZS )MyGPT2LMHeadModel_NoPaddinga  Here we wrap a class for Onnx model conversion for GPT2LMHeadModel with past state and no padding.
    When you always use batch_size=1 in inference, there is no padding in inputs. In such case, position_ids
    and attention_mask need no be in inputs.
    c                    s   t  | d S r   r   r   r   r   r    r   y   s    z$MyGPT2LMHeadModel_NoPadding.__init__c                    s"   t  j||dd}t|| jjS )NF)rB   r"   rC   )r   r%   rF   r;   r   r   r    r#   |   s    z#MyGPT2LMHeadModel_NoPadding.forwardr&   r   r   r   r    rI   s   s   rI   logitsTF
last_state)r   ZGPT2LMHeadModel_NoPaddingr   c                   @   s8   e Zd Zdd ZedddZedddZdd	 Zd
S )
Gpt2Inputsc                 C   s   || _ || _|| _|| _d S r   )r%   r@   rA   rF   )r   r%   r@   rA   rF   r   r   r    r      s    zGpt2Inputs.__init__)returnc                 C   s0   dd | j | j| jfD }| jr,|| j |S )Nc                 S   s   g | ]}|d k	r|qS r   r   .0vr   r   r    
<listcomp>   s      z&Gpt2Inputs.to_list.<locals>.<listcomp>)r%   r@   rA   rF   extend)r   
input_listr   r   r    to_list   s    zGpt2Inputs.to_listc                 C   s"   t dd | j| j| j| jfD S )Nc                 s   s   | ]}|d k	r|V  qd S r   r   rN   r   r   r    	<genexpr>   s      z&Gpt2Inputs.to_tuple.<locals>.<genexpr>)r3   r%   r@   rA   rF   )r   r   r   r    to_tuple   s    zGpt2Inputs.to_tuplec                 C   sT   d }| j d k	r2| j jtjkr,| j jtjdn| j }dd | jD }t| j| j	||S )Ndtypec                 S   s   g | ]}|j tjd qS )rW   )tor9   float32rO   pr   r   r    rQ      s     z&Gpt2Inputs.to_fp32.<locals>.<listcomp>)
rA   rX   r9   float16rY   rZ   rF   rL   r%   r@   )r   rA   rF   r   r   r    to_fp32   s    
zGpt2Inputs.to_fp32N)	r'   r(   r)   r   r   rT   r   rV   r^   r   r   r   r    rL      s   rL   c                   @   s^  e Zd ZdZedddejejejdfeeeeeeeeje	e	e	ej
ej
ej
e	edddZedIeeeeeeeee f dd	d
Zedd ZedJddZedKddZedLddZedMddZeddddejejejfee	e	e	e	ej
ej
ej
dddZedNddZedddd d!gfeee d"d#d$ZedOeed%d&d'ZedPeed%d(d)Zed*d+ ZedQd,d-ZedReeeejf eeee f ee	e	d.d/d0Z ed1d2 Z!ed3d4 Z"edd5d5d6d7ddddejejejdddfd8d9Z#edd:ddddejejejd;d7d<fd=d>Z$edSd?d@Z%eddddAdBdCdDgfedEdFdGZ&dHS )T
Gpt2HelperzEA helper class for Gpt2 model conversion, inference and verification.FT)
batch_sizepast_sequence_lengthsequence_lengthnum_attention_headshidden_sizer<   
vocab_sizedevicer]   has_position_idshas_attention_maskinput_ids_dtypeposition_ids_dtypeattention_mask_dtypeleft_side_paddingrM   c                    s$  |r
t jnt jd| ||t|| g fddt|D }t jd|d | |f| d}d}|
r|| }t j| |g| d}|dkrt| D ]>}td|d }|rd||d|f< qd|||| df< qd}|	r| 	d	d }|
|dk d |dd|df |}t||||S )
zCreate random inputs for GPT2 model.
        Returns torch tensors of input_ids, position_ids, attention_mask and a list of past state tensors.
        r0   c                    s$   g | ]}t j d d d qS )rX   rf   g       @      ?)r9   Zrand)rO   _rf   Z
float_typeZ
past_shaper   r    rQ      s     z/Gpt2Helper.get_dummy_inputs.<locals>.<listcomp>r   r/   )lowhighsizerX   rf   Nrm   )r9   r]   rZ   intr7   randintZonesrandomlongZcumsumZmasked_fill_rY   rL   )r`   ra   rb   rc   rd   r<   re   rf   r]   rg   rh   ri   rj   rk   rl   rF   r%   rA   Ztotal_sequence_lengthr=   Zpadding_lengthr@   r   rp   r    get_dummy_inputs   sF    
zGpt2Helper.get_dummy_inputsr   )r`   ra   rb   r   model_classrM   c                 C   s~   |j }|j}|j}|j}t| d }	| ||	dkr4|n|g}
d| ||| t|| g}|	|
i}t|D ]}||dt| < qd|S )zAReturns a dictionary with output name as key, and shape as value.r/   rJ   r0   present_)rc   rd   Znum_hidden_layersre   MODEL_CLASSESru   r7   str)r`   ra   rb   r   rz   rc   rd   r<   re   Zoutput_nameZlast_state_shapeZpresent_state_shapeoutput_shapesr=   r   r   r    get_output_shapes   s&    	
zGpt2Helper.get_output_shapesc                 C   sZ   |D ]P}|| kst | | }t|| | krtjt|| |j|jd| |< qd S )Nrm   )r6   numpyprodZnelementr9   emptyrX   rf   )output_buffersr~   keybufferr   r   r    auto_increase_buffer_size  s    z$Gpt2Helper.auto_increase_buffer_sizec                 C   sD   |r
t jnt j}i }|  D ]"\}}t jt|||d||< q|S )zpReturns a dictionary of output name as key, and 1D tensor as value. The tensor has enough space for given shape.rm   )r9   r]   rZ   itemsr   r   r   )r~   rf   
is_float16Z	data_typer   nameshaper   r   r    get_output_buffers  s
    zGpt2Helper.get_output_buffersc                 C   sL   | d    }t||d  }|r>t|t|d  S t|S dS )zGReturns the maximum difference between PyTorch and OnnxRuntime outputs.r   ư>N)cpur   absamax)torch_outputsort_outputsrelativeZexpected_outputsdiffr   r   r    diff_outputs&  s
    zGpt2Helper.diff_outputsMbP?c           
   	   K   s   t j|d | d    ||d}td|  |}t|d }t|D ]R}t j|d|  | d |    ||d}td| d| d|  |o|}qJ|st| |}	t	d|	d	 |S )
zReturns True if torch and ORT outputs are close for given thresholds, and False otherwise.
        Note: need kwargs since Gpt2BeamSearchHelper.compare_outputs has an extra parameter model_class
        r   )rtolatolz9PyTorch and OnnxRuntime output 0 (last_state) are close: r/   zPyTorch and OnnxRuntime layer z state (present_z) are close:z@PyTorch and OnnxRuntime results are not all close: max_abs_diff=.5f)
r   allcloser   loggerdebugr5   r7   r_   r   info)
r   r   r   r   kwargsis_closeis_all_closeZ
num_layerslayermax_abs_diffr   r   r    compare_outputs0  s"    "

zGpt2Helper.compare_outputsr   c                 C   s  d}d}g }g }t t|D ]}|| }|dkr:| d n| d |d    }	tj||	|dd}
|tt|	|  |o|
}t|		 rt
d| d t|		 rt
d| d t|	 rt
d	| d t|	 rt
d	| d t||	 }t| |j}|d
|| dd| d|| ddt|	| d |dkrttj|dd|j}ttj|	dd|	j}t||}q|t|}|t||||fS )a  Compare outputs from PyTorch and OnnxRuntime

        Args:
            torch_outputs (Tuple[Torch.Tensor]): PyTorch model output
            ort_outputs (List[numpy.ndarray]): OnnxRuntime output
            atol (float, optional): Absolute tollerance. Defaults to 1e-06.

        Returns:
            is_all_close(bool): whether all elements are close.
            max_abs_diff(float): maximum absolute difference.
            messages(str): a list of debug message for each output
        TFr   r/   )r   r   zPyTorch output z has nanz has infzORT output zdiff=z.9fz index=z ort=z torch=N)Zaxis)r7   r5   r   r   r   r8   r   r   isnananyr   r   isinffabsZunravel_indexZargmaxr   floatZarray_equalindexmax)r   r   r   r   is_top1_matchedZ	max_diffsmessagesr=   Z
ort_outputZtorch_outputr   r   idxZort_max_indexZtorch_max_indexmax_diff_output_indexr   r   r    compare_outputs_v2K  sD    (0zGpt2Helper.compare_outputs_v2)onnx_model_pathverboseuse_external_data_formatrg   rh   ri   rj   rk   c
                 C   s  | j }
|
j}tjddd|
j|
j||
j|d|||||	d}| }t	  | | }W 5 Q R X dd t
|D }dd t
|D }|d jd |
jks|d jd |
jkst|d jd |
jkrd	nd
f|}dddd|d dddi}|D ]}ddd||< q|D ]}ddd||< qdg}|r@ddd|d< |d |r^ddd|d< |d || t|dkrt|d |ksttd|jj d|jd j d|d j d|d d j  t|jjddd |r^t j}tj|d}t|jjddd t| t||d|||ddd|d tj|dd} t j!| |ddd W 5 Q R X n"t| t||d|||ddd|d dS ) z1Export GPT-2 model with past state to ONNX model.r/   F)r`   ra   rb   rc   rd   r<   re   rf   r]   rg   rh   ri   rj   rk   c                 S   s   g | ]}d | qS )past_r   rO   r=   r   r   r    rQ     s     z*Gpt2Helper.export_onnx.<locals>.<listcomp>c                 S   s   g | ]}d | qS )r{   r   r   r   r   r    rQ     s     r   r0   rJ   rK   r%   r`   Zseq_len)r   r/   Zpast_seq_len)r/      Ztotal_seq_lenr@   rA   zShapes: input_ids=z past=z output=z	 present=T)parentsexist_okz	gpt2.onnx   )
argsfZexport_paramsinput_namesoutput_namesdynamic_axesZopset_versionZdo_constant_foldingr   r   )Zload_external_data)Zsave_as_external_dataZall_tensors_to_one_fileN)"r   rD   r_   ry   rc   rd   re   rT   r9   no_gradr7   r   r6   r8   rR   r5   r   r   r%   rF   r   parentmkdirtempfileTemporaryDirectoryospathjoinr   r3   onnxZ
load_modelr   save)modelrf   r   r   r   rg   rh   ri   rj   rk   r   r<   dummy_inputsrS   outputsZ
past_namesZpresent_namesr   r   r   r   Ztmp_dir_nameZtemp_onnx_model_pathr   r   r    export_onnx  s    
,"  


$6
zGpt2Helper.export_onnxr   c              	   K   sf   t d}	t| d||d|	dd}
|rV|r2t|
 n$d|krBd|d< |
jf ddi| |
|| |
S )zHOptimize ONNX model with an option to convert it to use mixed precision.r   r   F)
model_typeZ	num_headsrd   	opt_leveloptimization_optionsZuse_gpukeep_io_typesuse_symbolic_shape_inferT)r	   r   r_   auto_mixed_precisionconvert_float_to_float16Zsave_model_to_file)r   Zoptimized_model_pathr   rc   rd   r   r   stager   r   mr   r   r    optimize_onnx  s$    
zGpt2Helper.optimize_onnxAddZLayerNormalizationZSkipLayerNormalizationZFastGeluZEmbedLayerNormalization)
onnx_modelop_block_listc                 C   sT  dd |   D }t|}||}td| d|  |  jd j}d}|  }||ksbt	|| }d}	|j
dkr|}	td	|j  d}
|jD ]}| |}
|
dk	r qqt|
}td
|j d|  |dk }ntd|j
 d|j  g }g }|s|	dk	r|g}|	jg}||||d}td|  | jf ddi| |S )a?  Convert GPT-2 model to mixed precision.
           It detects whether original model has fp16 weights, and set parameters for float16 conversion automatically.
        Args:
            onnx_model (OnnxModel): optimized ONNX model
            op_block_list (List[str], optional): operators to compute in fp32. Defaults to ["Add", "LayerNormalization",
                                                 "SkipLayerNormalization", "FastGelu", "EmbedLayerNormalization"]
        Returns:
            parameters(dict): a dictionary of parameters used in float16 conversion
        c                 S   s   h | ]
}|j qS r   )op_type)rO   noder   r   r    	<setcomp>1  s     z2Gpt2Helper.auto_mixed_precision.<locals>.<setcomp>z	fp32 op: z
 fp16 op: r   FNZMatMulz#Found last MatMul node for logits: z3max diff of converting weights in last MatMul node : r   z-Failed to find MatMul node for logits. Found z	 of node )r   r   node_block_listZforce_fp16_initializersz!auto_mixed_precision parameters: r   T)Znodesset
differencer   r   graphoutputr   output_name_to_noder6   r   inputZget_initializerr   r   warningr   )r   r   Zop_full_setZfp32_op_setZfp16_op_setZlogits_output_nameZis_weight_fp16_precisionr   r   Zlast_matmul_nodeZinitializerr   Zmax_diffr   r   
parametersr   r   r    r     sF    




zGpt2Helper.auto_mixed_precision)inputs
total_runsc           	   	   C   s   t d |  }t  | | }W 5 Q R X |dkr>|S g }t 6 t|D ]&}t }| | }|t |  qTW 5 Q R X t	|d t
| }t dt|d ||fS )zfRun inference of PyTorch model, and returns average latency in ms when total_runs > 0 besides outputs.zstart pytorch_inferencer     zPyTorch inference time = {} ms.2f)r   r   r^   rT   r9   r   r7   timer8   sumr5   format)	r   r   r   rS   r   latencyro   startaverage_latencyr   r   r    pytorch_inferencec  s    


zGpt2Helper.pytorch_inferencec                 C   s"  t d dt|j  i}|jdk	r\t|jD ]$\}}t|  |d| < q6|jdk	r~t|j  |d< |j	dk	rt|j	  |d< | 
d|}|dkr|S g }t|D ]*}t }	| 
d|}|t |	  qt|d t| }
t d	t|
d
 ||
fS )zcRun inference of ONNX model, and returns average latency in ms when total_runs > 0 besides outputs.zstart onnxruntime_inferencer%   Nr   rA   r@   r   r   z"OnnxRuntime Inference time = {} msr   )r   r   r   Zascontiguousarrayr%   r   rF   	enumeraterA   r@   runr7   r   r8   r   r5   r   )ort_sessionr   r   Z
ort_inputsr=   Zpast_ir   r   ro   r   r   r   r   r    onnxruntime_inference}  s(    



z Gpt2Helper.onnxruntime_inferencec              	   C   s   t | ||||||S )z)Returnas IO binding object for a session.)r
   prepare_io_binding)r   r%   r@   rA   rF   r   r~   r   r   r    r     s    zGpt2Helper.prepare_io_bindingc                 C   s   t | |||S )z3Copy results to cpu. Returns a list of numpy array.)r
   "get_outputs_from_io_binding_buffer)r   r   r~   return_numpyr   r   r    r     s       z-Gpt2Helper.get_outputs_from_io_binding_buffer)r   r   r~   r   r   include_copy_output_latencyc              	   C   s   t d t| |j|j|j|j||}| | t	| |||}|dkrN|S g }	t
|D ]<}
t }| | |rt	| |||}
|	t |  qZt|	d t|	 }t d| ||fS )zUInference with IO binding. Returns outputs, and optional latency when total_runs > 0.z*start onnxruntime_inference_with_binded_ior   r   z4OnnxRuntime with IO binding inference time = %.2f ms)r   r   r_   r   r%   r@   rA   rF   Zrun_with_iobindingr   r7   r   r8   r   r5   )r   r   r   r~   r   r   r   Z
io_bindingr   r   ro   r   r   r   r   r    $onnxruntime_inference_with_binded_io  sD    

   
   z/Gpt2Helper.onnxruntime_inference_with_binded_ioc              	   C   s|   t d|  dd}t|| W 5 Q R X td|  d t d|  dd}t|| W 5 Q R X td|  d d S )NZort_outputs_.picklewbz$ORT output are saved to ort_outputs_Ztorch_outputs_z(Torch output are saved to torch_outputs_openpickledumpr   r   )r=   r   r   r   r   r   r    save_outputs  s    zGpt2Helper.save_outputsc              	   C   s@   t d|  dd}t|| W 5 Q R X td|  d d S )NZdummy_inputs_r   r   z!inputs are saved to dummy_inputs_r   )r=   r   r   r   r   r   r   r    save_inputs  s    zGpt2Helper.save_inputsr   i'  r/   c           ,         s0  |j }td| d d| d| d|	 d| d d}d	}d
}d}|rjt|||||	}t|||}d}d}g  dg| }| }t|D ]}t| }t	d|}|dkrdn
t	d|}t	d|} t
d|  d| d tj| |||j|j|j|j|||
||||dd}!t||!}"|r:t| |!}#n"t| ||||	}$t| |!||$}#tj|"|#|d\}%}&}'}(})t|&s |& |%r|d7 }|)r|d7 }||  d7  < |r&|%s&td| d|  d| d| d|& 
 t|(D ]0\}}*td| d|  | j d|*  q|rt|&sB|&d| krt||! t||#|" q rx fdddD }+ndd dD }+|d  | |+d!< fd"d#|D |+d$< |d  | |+d%< |t  d  | |+d&< td'| d(| d)|t   d*|  |d+| kr,td,t|d | d-d. |+S )/zKGenerate random inputs and compare the results of PyTorch and Onnx Runtime.zRunning parity test (atol=z, test_cases=z, runs=z, use_io_binding=z, model_class=z, is_float16=z) ...      r0   Nr   r/   z#Running parity test for batch_size=z past_sequence_length=z...T)ri   rj   rk   rl   )r   z
test_case=z batch_size=z sequence_length=z	 MaxDiff=	z: Name=z, d   c                    s$   i | ]}d | t  |dqS )max_diff_percentile_r   )r   Z
percentiler[   )max_abs_diff_listr   r    
<dictcomp>p  s     z*Gpt2Helper.test_parity.<locals>.<dictcomp>)2   Z   _   c   c                 S   s   i | ]}d | dqS )r  nanr   r[   r   r   r    r  t  s      rn   Ztop1_match_ratec                    s   g | ]}|d    qS )rn   r   )rO   x)test_cases_per_runr   r    rQ   w  s     z*Gpt2Helper.test_parity.<locals>.<listcomp>Ztop1_match_rate_per_runZdiff_pass_rateZnan_ratezParity Test Cases=z	; Passed=z; Nan=z; Top1_Matched=gffffff?zParity is good: passed rate=z.0f%)r   r   r   r_   r   r   r7   ru   rw   rv   r   ry   rc   rd   rD   re   r   r   r   r   r   r   r8   r   get_outputsr   r   r   r5   ),r   r   rf   r   r   r   r  r   use_io_bindingrz   rg   rh   ri   rj   rk   r   r   Zenable_pickle_outputr   Zmax_batch_sizeZmax_past_seq_lenZmax_seq_lenr   Zmax_output_shapesZpassed_test_casesZtop1_matched_casesZtop1_matched_cases_per_runZtotal_test_casesr=   Zrun_idrb   ra   r`   r   r   r   r~   r   r   r   r   r   messager;   r   )r  r  r    test_parity  s    (    
   

 *
" zGpt2Helper.test_parityr  r      c                 C   s   |j }d}|r.t|||||}t|||}tj||||j|j|j|j|||||	|
|d}|rtt	| ||\}}nt
| ||||\}}|S )zCGenerate random inputs and measure average latency of Onnx Runtime.N)ri   rj   rk   )r   r_   r   r   ry   rc   rd   rD   re   r   r   )r   r   rf   r   r   r  rz   rg   rh   ri   rj   rk   r`   rb   ra   r   r   r~   r   ro   r   r   r   r    test_performance  sJ            zGpt2Helper.test_performancec                 C   s:   t jddd|j|j|j|j|d||d }tj	| |S )zJIT trace for TorchScript.r/   F)r`   ra   rb   rc   rd   r<   re   rf   r]   rg   rh   )
r_   ry   rc   rd   rD   re   rT   r9   Zjittrace)r   r   rf   rg   rh   rS   r   r   r    torchscript  s    
zGpt2Helper.torchscriptrawfp32fp16int8)rz   c                 C   s  |}t j|r t|jd }n|dd  |dkrB|d| 7 }|rN|d7 }|rdddd	d
}d
D ]}t j| |||  }	t j|	rf||krzt	|	 t
d|	  W n: tk
r }
 zt
d|	 d|
j  W 5 d}
~
X Y nX qft
d| d|	  qft jt j| ||d t jt j| |d |d t jt j| |d |d t jt j| |d	 |d d
S t j| |d t j| |d t j| |d t j| |d d
S )z=Build a  path name for given model based on given attributes.rt   /r   ro   Z_past Z_fp32Z_fp16Z_int8)r  r  r  r  zRemoved the existed directory: zFailed to remove the directory r   NzDirectory for z
 existed: z.onnxz
_fp32.onnxz
_fp16.onnxz
_int8.onnx)r   r   isdirr   partssplitr   existsshutilrmtreer   r   OSErrorstrerror)
output_dirZmodel_name_or_pathrz   Zhas_pastZ
new_folderZremove_existingZ
model_namesuffixr   Znew_direr   r   r    get_onnx_paths  sN    

,zGpt2Helper.get_onnx_pathsN)r   )F)F)r   r   )r   )FFr   )r   )r   )T)r   TF)TT)'r'   r(   r)   r*   rG   r9   Zint32ru   rf   boolrX   rL   ry   r   r}   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   ZTensorr   r   r   r  r  r  r)  r   r   r   r    r_      s&  
@ "
		5w   #E
   2
	
 6
r_   )4loggingr   r   rw   r"  r   r   pathlibr   typingr   r   r   r   r   r   r9   Zbenchmark_helperr   r]   r   Zfusion_optionsr	   Zio_binding_helperr
   r   r   Z	optimizerr   Ztorch_onnx_export_helperr   Ztransformersr   r   r   r   	getLoggerr'   r   ZPRETRAINED_GPT2_MODELSZFLOAT32ZFLOAT16ZINT8ZDEFAULT_TOLERANCEr   r,   r.   rH   rI   r|   rL   r_   r   r   r   r    <module>   sN   
   
$