U
    Mâhµ  ã                !   @   s"  U d Z ddlZddlZddlZddlZddlZddlZddlmZ ddl	m
Z
 ddl	mZ ddlmZmZmZmZ ddlmZmZ dd	lmZ G d
d„ deƒZG dd„ dƒZeeef ZdZi Zeed< dddœZddddddddddddddd d!d"d#d$d%d&d'd'd(d)d*d+d,d-d.ddd/g!ZG d0d„ deƒZ d1d„ Z!G d2d„ dƒZ"G d3d„ dƒZ#eee$d4œd5d„Z%d6d7„ Z&dieeeee$e$ee d:œd;d„Z'eeeeeeef e$e$e#e$dd<œ
d=d„Z(d>d„ Z)d?d„ Z*e +d@¡Z,dAd„ Z-dBd„ Z.dCd„ Z/dDd„ Z0e +dE¡Z1dFd„ Z2e +dG¡Z3dHd „ Z4e +dI¡Z5dJd!„ Z6djdKd"„Z7dLd#„ Z8dMd$„ Z9dNd%„ Z:dOd&„ Z;dPd'„ Z<G dQdR„ dRƒZ=G dSd(„ d(ƒZ>e>ƒ Z?i Z@e>ƒ ZAi ZBeeeCf edT< i ZDe
D ]ªZEeFeEeƒs¬tG‚eE H¡ D ]Š\ZIZJeJd ZKeJdd… ZLejMeLkreA NeI¡ ejOeLkr
eB PeIdU¡r
eKeDeI< neKeBeI< ejQeLkr´ejOeLkr´e? NeI¡ eKe@eI< q´q˜e +e? R¡ ¡ZSe +dVeA R¡ › dW¡ZTe +dX¡ZUe +dY¡ZVe +dZ¡ZWe +d[¡ZXeeeeeeef e$e$e#e$ed<œ
d\d)„ZYdkd]d*„ZZd^d+„ Z[d_d,„ Z\e +d`¡Z]dad-„ Z^dbd.„ Z_dlee$eeeeeee$ee$e$e$e$ee# edgœdhd/„Z`dS )ma   The Python Hipify script.
##
# Copyright (c) 2015-2016 Advanced Micro Devices, Inc. All rights reserved.
#               2017-2018 Advanced Micro Devices, Inc. and
#                         Facebook Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
é    Né   )Ú	constants)ÚCUDA_TO_HIP_MAPPINGS)ÚMATH_TRANSPILATIONS)ÚDictÚListÚIteratorÚOptional)ÚMappingÚIterable)ÚEnumc                   @   s   e Zd ZdZdZdS )ÚCurrentStater   é   N)Ú__name__Ú
__module__Ú__qualname__ÚINITIALIZEDÚDONE© r   r   úR/var/www/html/venv/lib/python3.8/site-packages/torch/utils/hipify/hipify_python.pyr   *   s   r   c                   @   s   e Zd Zdd„ Zdd„ ZdS )ÚHipifyResultc                 C   s   || _ || _d| _d S )NÚ ©Úcurrent_stateÚhipified_pathÚstatus)Úselfr   r   r   r   r   Ú__init__/   s    zHipifyResult.__init__c                 C   s   d| j › d| j› d| j› S )NzHipifyResult:: current_state: z, hipified_path : z
, status: r   ©r   r   r   r   Ú__str__4   s    zHipifyResult.__str__N)r   r   r   r   r   r   r   r   r   r   .   s   r   z;// !!! This is a file automatically generated by hipify!!!
ÚHIPIFY_FINAL_RESULTZscalar_t)ZDtypeÚTÚ
InputErrorÚopenfÚbcolorsÚGeneratedFileCleanerÚmatch_extensionsÚmatched_files_iterÚpreprocess_file_and_save_resultÚcompute_statsÚadd_dim3ÚprocessKernelLaunchesÚfind_closure_groupÚfind_bracket_groupÚfind_parentheses_groupÚreplace_math_functionsÚhip_header_magicÚreplace_extern_sharedÚget_hip_file_pathÚis_out_of_placeÚis_pytorch_fileÚis_cusparse_fileÚis_special_fileÚis_caffe2_gpu_fileÚTrieÚpreprocessorÚfile_specific_replacementÚfile_add_headerÚfix_static_global_kernelsÚextract_argumentsÚstr2boolÚhipifyc                       s$   e Zd Z‡ fdd„Zdd„ Z‡  ZS )r"   c                    s   t ƒ  |¡ || _d S ©N)Úsuperr   Úmessage)r   rB   ©Ú	__class__r   r   r   K   s    zInputError.__init__c                 C   s   d| j › S )NzInput error: )rB   r   r   r   r   r   O   s    zInputError.__str__)r   r   r   r   r   Ú__classcell__r   r   rC   r   r"   H   s   c                 C   s   t | |ddS )NÚignore)Úerrors)Úopen)ÚfilenameÚmoder   r   r   r#   S   s    c                   @   s,   e Zd ZdZdZdZdZdZdZdZ	dZ
d	S )
r$   z[95mz[94mz[92mz[93mz[91mz[0mz[1mz[4mN)r   r   r   ÚHEADERZOKBLUEÚOKGREENÚWARNINGZFAILÚENDCZBOLDZ	UNDERLINEr   r   r   r   r$   X   s   c                   @   s<   e Zd ZdZddd„Zdd„ Zdd„ Zdd	d
„Zdd„ ZdS )r%   z+Context Manager to clean up generated filesFc                 C   s   || _ tƒ | _g | _d S r@   )Úkeep_intermediatesÚsetÚfiles_to_cleanÚdirs_to_clean)r   rO   r   r   r   r   m   s    zGeneratedFileCleaner.__init__c                 C   s   | S r@   r   r   r   r   r   Ú	__enter__r   s    zGeneratedFileCleaner.__enter__c                 O   s0   t j |¡s | j t j |¡¡ t|f|ž|ŽS r@   )ÚosÚpathÚexistsrQ   ÚaddÚabspathrH   )r   ÚfnÚargsÚkwargsr   r   r   rH   u   s    zGeneratedFileCleaner.openc                 C   sx   t j |¡\}}|s$t j |¡\}}|rF|rFt j |¡sF| j|dd t j |¡rV|stt  |¡ | j t j 	|¡¡ d S )NT)Úexist_ok)
rT   rU   ÚsplitrV   ÚmakedirsÚisdirÚmkdirrR   ÚappendrX   )r   Údnr\   ÚparentÚnr   r   r   r^   z   s    
zGeneratedFileCleaner.makedirsc                 C   s@   | j s<| jD ]}t |¡ q| jd d d… D ]}t |¡ q,d S )Néÿÿÿÿ)rO   rQ   rT   ÚunlinkrR   Úrmdir)r   ÚtypeÚvalueÚ	tracebackÚfÚdr   r   r   Ú__exit__„   s
    
zGeneratedFileCleaner.__exit__N)F)F)	r   r   r   Ú__doc__r   rS   rH   r^   rm   r   r   r   r   r%   k   s   


)rI   Ú
extensionsÚreturnc                    s   t ‡ fdd„|D ƒƒS )z<Helper method to see if filename ends with certain extensionc                 3   s   | ]}ˆ   |¡V  qd S r@   ©Úendswith)Ú.0Úe©rI   r   r   Ú	<genexpr>Ž   s     z#match_extensions.<locals>.<genexpr>©Úany)rI   ro   r   ru   r   r&   Œ   s    c                    s   t ‡ fdd„|D ƒƒS )Nc                 3   s   | ]}t   ˆ |¡V  qd S r@   )Úfnmatch)rs   Úpattern©Úfilepathr   r   rv   ’   s     z_fnmatch.<locals>.<genexpr>rw   )r|   Úpatternsr   r{   r   Ú_fnmatch‘   s    r~   r   F)Ú	root_pathÚincludesÚignoresro   Úout_of_place_onlyÚis_pytorch_extensionrp   c                 c   sò   t |ƒ}tj| ddD ]Ö\}}}	tj || ¡}
|
dkrvd|krH| d¡ d|krZ| d¡ d|krv| d¡ | d¡ |	D ]p}tj ||¡}tj |
|¡}t||ƒrzt||ƒszt	||ƒsÀ||krz|sät
|ƒsÖt|ƒsÖqz|rät|ƒsäqz|V  qzqd S )NT)ÚtopdownÚ.z.gitÚbuildZthird_partyzthird_party/nvfuser)rP   rT   ÚwalkrU   ÚrelpathÚremovera   Újoinr~   r&   r4   r7   r3   )r   r€   r   ro   r‚   rƒ   Úexact_matchesZabs_dirpathÚdirsÚ	filenamesZrel_dirpathrI   r|   Úrel_filepathr   r   r   r'   •   s8    



ÿþýý)
Úoutput_directoryr|   Ú	all_filesÚheader_include_dirsÚstatsÚhip_clang_launchrƒ   Ú	clean_ctxÚshow_progressrp   c	              
   C   st   t j t j | |¡¡}	ttj|	d}
|
t|	< t| ||||||||ƒ	}|rhd|j	krht
|	d|j|j	dd |t|	< d S )N)r   r   Zignoredz->T)Úflush)rT   rU   rX   rŠ   r   r   r   r    r9   r   Úprintr   )r   r|   r   r‘   r’   r“   rƒ   r”   r•   Úfin_pathÚhipify_resultÚresultr   r   r   r(   Á   s$    
   ÿ   þc                 C   sP   dd„ | d D ƒ}t dt|ƒd›ƒ t d |¡ƒ t dt| d ƒd›ƒ d S )	Nc                 S   s   h | ]\}}|’qS r   r   )rs   Z	cuda_callZ	_filepathr   r   r   Ú	<setcomp>Û   s     z compute_stats.<locals>.<setcomp>Úunsupported_callsz1Total number of unsupported CUDA function calls: rl   ú, z+
Total number of replaced kernel launches: Úkernel_launches)r—   ÚlenrŠ   )r’   rœ   r   r   r   r)   Ú   s    c                 C   s¦  d}d}|   dd¡  dd¡} dd„ tdƒD ƒ}d|| d< t| ƒD ]Š\}}|d	krV qÎ|d
krh|d	7 }n|dkrx|d	8 }|dks|t| ƒd	 krB|dkrB||dk || d< |d	7 }|dk rB|d	 || d< qB| |d d |d d d	 … }| |d	 d |d	 d … }| |d d |d d …   dd¡ d¡}	| |d	 d |d	 d …   dd¡ d¡}
d|	› d}d|
› d}|  |	|¡}|  |
|¡}|  || || ¡}|S )zBadds dim3() to the second and third arguments in the kernel launchr   ú<<<r   ú>>>c                 S   s   g | ]}i ‘qS r   r   )rs   Ú_r   r   r   Ú
<listcomp>ì   s     zadd_dim3.<locals>.<listcomp>r   Ústartr   ú(ú)ú,ÚendÚ
ú zdim3()ÚreplaceÚrangeÚ	enumeraterŸ   Ústrip)Úkernel_stringÚcuda_kernelÚcountÚclosureZarg_locsÚindÚcZfirst_arg_rawZsecond_arg_rawZfirst_arg_cleanZsecond_arg_cleanZfirst_arg_dim3Zsecond_arg_dim3Zfirst_arg_raw_dim3Zsecond_arg_raw_dim3r   r   r   r*   ç   s6    
  **z([ ]+)(detail?)::[ ]+\\\n[ ]+c                    sV  t  dd„ ˆ ¡‰ ‡ fdd„}dd„ }dd„ }t||ˆ ƒƒƒ}ˆ }|D ]
}||ƒ}ˆ  d	|d
 ¡}	ˆ |d d |	d … }
ˆ |d |d
 … }|d d
 dkr¢dnd}ˆ |d d || d
 d … }t||
ƒ}ttd|d  dd	¡ dd¡ƒƒ}d|dd…  ddd|  d ¡ dd¡ dd¡ |d	| d ¡ }| |
|¡}|d  |¡ qD|S )zK Replace the CUDA style Kernel launches with the HIP style kernel launches.c                 S   s   |   d¡› |   d¡› dS )Nr   r   z::©Úgroup©Zinpr   r   r   Ú<lambda>  ó    z'processKernelLaunches.<locals>.<lambda>c           
         s„  | d | d dœdddœdddœdœ}ddi}d}d}d	}d
}|}t |d d d ddƒD ]"}ˆ | }	|||fkrà|	dkr¤||kr”|}||d d< |d  d7  < |	dkrà|d  d8  < |d dkrà||krà||d d< |}||krZˆ |  ¡ sˆ | dkrP||kr|}||d d< |dkr~d|d d< |d |d |d g  S qZ||krZ||d d< |d |d |d g  S qZd S )Nr¤   r¨   ©r¤   r¨   re   )Úkernel_launchÚkernel_nameÚtemplatez<>r   r   r   é   r»   ú>r½   ú<>   ú#r¢   r¦   r¥   ú:r¼   )r¬   Úisalnum)
Z	in_kernelÚposr±   ZSTARTZAT_TEMPLATEZAFTER_TEMPLATEZAT_KERNEL_NAMEr   ÚiÚchar©Ústringr   r   Úgrab_method_and_template  sD    ý

z7processKernelLaunches.<locals>.grab_method_and_templatec                 S   sd   d}g }|   d|¡dkr`|   d|¡}|   d|¡d }|dkrDtdƒ‚| ||| ||… dœ¡ q|S )zKFinds the starting and ending points for all kernel launches in the string.r   r    re   r¡   r¾   zno kernel end found)r¤   r¨   r¶   )Úfindr"   ra   )rÈ   Z
kernel_endZkernel_positionsZkernel_startr   r   r   Úfind_kernel_boundsM  s    
ÿ
z1processKernelLaunches.<locals>.find_kernel_boundsc                 S   sâ   d}d}d}| D ]Ì}|dkrf|dkr2|dkr2d}q¾|dkrH|dkrHd}q¾|dkr¾|dkr¾|dkr¾d}nX|dkr„|d	ks~|d
kr¾d}n:|dkr¢|dkr¾|dkr¾d}n|dkr¾|dkr¾|dkr¾d}|}|dkrÔ||7 }q|d7 }q|S )Nr   ú/z//Ú*z/*ú"ú\ú'úr©   Úxr   )rÈ   Z
in_commentZprev_cZ
new_stringr´   r   r   r   Úmask_commentse  s2    

z,processKernelLaunches.<locals>.mask_commentsr¥   r¨   r   r¤   r   re   r¶   r    r¡   r¦   zhipLaunchKernelGGL(z, 0é   r   rž   )	ÚRE_KERNEL_LAUNCHÚsubÚlistrÊ   r*   rŸ   r=   r«   ra   )rÈ   r’   rÉ   rË   rÓ   Zget_kernel_positionsÚoutput_stringZkernelÚparamsZparenthesisr°   r¯   Zend_param_indexZkernel_name_with_templateZcuda_kernel_dim3Znum_klpZ
hip_kernelr   rÇ   r   r+     s>    ;!
 
" ÿ ÿ þ 
þc                 C   sŽ   d}d}|}d\}}|t | ƒk rŠ| | |d krP|dkrFd}d}|}q€|d7 }n0| | |d kr€|r€|d8 }|dkr€|}||fS |d7 }qdS )aÊ  Generalization for finding a balancing closure group

         if group = ["(", ")"], then finds the first balanced parentheses.
         if group = ["{", "}"], then finds the first balanced bracket.

    Given an input string, a starting position in the input string, and the group type,
    find_closure_group returns the positions of group[0] and group[1] as a tuple.

    Example:
        >>> find_closure_group("(hi)", 0, ["(", ")"])
        (0, 3)
    Fr   )re   re   Tr   )NN)rŸ   )Úinput_stringr¤   r¶   Zinside_parenthesisZparensrÄ   Zp_startZp_endr   r   r   r,   §  s$    

c                 C   s   t | |ddgdS )z%Finds the first balanced parantheses.Ú{Ú}rµ   ©r,   ©rÚ   r¤   r   r   r   r-   Í  s    c                 C   s   t | |ddgdS )z!Finds the first balanced bracket.r¥   r¦   rµ   rÝ   rÞ   r   r   r   r.   Ò  s    z\bassert[ ]*\(c                 C   s.   | }t D ] }| |› dt | › d¡}q|S )a‹  FIXME: Temporarily replace std:: invocations of math functions
        with non-std:: versions to prevent linker errors NOTE: This
        can lead to correctness issues when running tests, since the
        correct version of the math function (exp/expf) might not get
        called.  Plan is to remove this function once HIP supports
        std:: math function calls inside device code

    r¥   )r   r«   )rÚ   rØ   Úfuncr   r   r   r/   Ú  s    	z:?:?\b(__syncthreads)\b(\w*\()c                    sh   | ‰ ddg}t ‡ fdd„|D ƒƒr&ˆ S dˆ k}|dˆ k7 }|dˆ k7 }|t ˆ ¡dk	7 }|rdd	|  ‰ ˆ S )
a  If the file makes kernel builtin calls and does not include the cuda_runtime.h header,
    then automatically add an #include to match the "magic" includes provided by NVCC.
    TODO:
        Update logic to ignore cases where the cuda_runtime.h is included by another file.
    zhip/hip_runtime.hzhip/hip_runtime_api.hc                 3   s(   | ] }t  d |› d|› dˆ ¡V  qdS )z#include ("z"|<z>)N)ÚreÚsearch)rs   Úext©rØ   r   r   rv   ù  s     z#hip_header_magic.<locals>.<genexpr>ZhipLaunchKernelGGLÚ
__global__Z
__shared__Nz#include "hip/hip_runtime.h"
)rx   ÚRE_SYNCTHREADSrá   )rÚ   ÚheadersZhasDeviceLogicr   rã   r   r0   í  s    zGextern\s+([\w\(\)]+)?\s*__shared__\s+([\w:<>\s]+)\s+(\w+)\s*\[\s*\]\s*;c                 C   s   | }t  dd„ |¡}|S )a  Match extern __shared__ type foo[]; syntax and use HIP_DYNAMIC_SHARED() MACRO instead.
       https://github.com/ROCm-Developer-Tools/HIP/blob/master/docs/markdown/hip_kernel_language.md#__shared__
    Example:
        "extern __shared__ char smemChar[];" => "HIP_DYNAMIC_SHARED( char, smemChar)"
        "extern __shared__ unsigned char smem[];" => "HIP_DYNAMIC_SHARED( unsigned char, my_smem)"
    c                 S   s.   d|   d¡pd› d|   d¡› d|   d¡› dS )	NzHIP_DYNAMIC_SHARED(r   r   rª   r   r   r¾   r¦   rµ   r·   r   r   r   r¸     r¹   z'replace_extern_shared.<locals>.<lambda>)ÚRE_EXTERN_SHAREDrÖ   )rÚ   rØ   r   r   r   r1     s     ÿc                 C   sð   t j | ¡rt‚|s t| ƒs | S t j | ¡\}}t j |¡\}}|dkrLd}|}|}| dd¡}| dd¡}| dd¡}| dd¡}| dd¡}|d	kr¤| dd¡}|s¾||kr¾t j |d¡}|rÞ||krÞ|| |krÞ|d
 }t j ||| ¡S )z3
    Returns the new name of the hipified file
    ú.cuú.hipÚcudaZhipÚCUDAÚHIPÚTHCÚTHHzcaffe2/coreZ_hip)	rT   rU   ÚisabsÚAssertionErrorr3   r]   Úsplitextr«   rŠ   )rŽ   rƒ   ÚdirpathrI   Úrootrâ   Úorig_filenameZorig_dirpathr   r   r   r2     s*    $c                 C   s>   t j | ¡rt‚|  d¡rdS |  d¡r,dS |  d¡r:dS dS )Nútorch/Fúthird_party/nvfuser/útools/autograd/templates/T©rT   rU   rï   rð   Ú
startswith©rŽ   r   r   r   r3   c  s    


c                 C   sZ   t j | ¡rt‚|  d¡r,|  d¡r(dS dS |  d¡r:dS |  d¡rHdS |  d¡rVdS dS )Nzaten/zaten/src/ATen/core/FTrõ   rö   r÷   rø   rú   r   r   r   r4   o  s    




c                 C   s   t | ƒrd|  ¡ kS dS )NÚsparseF©r4   Úlowerrú   r   r   r   r5   ~  s    c                 C   s<   t | ƒr8d|  ¡ krdS d|  ¡ kr8d|  ¡ kr4dS dS dS )Nrû   TZlinalgZbatchlinearalgebralibblasFrü   rú   r   r   r   r6   „  s    c                 C   sR   t j | ¡rt‚|  d¡rdS t j | ¡}t j |¡\}}d|ksJ|dkoPd|kS )Nzc10/cudaTZgpu©rè   ú.cuhZcudnn)rT   rU   rï   rð   rù   Úbasenamerñ   )rŽ   rI   r¢   râ   r   r   r   r7   Ž  s    
c                   @   s   e Zd ZdZdd„ ZdS )ÚTrieNodezA Trie node whose children are represented as a directory of char: TrieNode.
       A special char '' represents end of word
    c                 C   s
   i | _ d S r@   )Úchildrenr   r   r   r   r   ›  s    zTrieNode.__init__N)r   r   r   rn   r   r   r   r   r   r  –  s   r  c                   @   sP   e Zd ZdZdd„ Zdd„ Zdd„ Zdd	„ Zd
d„ Zdd„ Z	dd„ Z
dd„ ZdS )r8   z£Creates a Trie out of a list of words. The trie can be exported to a Regex pattern.
    The corresponding Regex should match much faster than a simple Regex union.c                 C   s   t ƒ | _dS )z,Initialize the trie with an empty root node.N)r  ró   r   r   r   r   r   ¢  s    zTrie.__init__c                 C   s8   | j }|D ]}|j |tƒ ¡ |j| }q
d|jd< dS )zAdd a word to the Trie. Tr   N)ró   r  Ú
setdefaultr  ©r   ÚwordÚnoderÆ   r   r   r   rW   ¦  s
    zTrie.addc                 C   s   | j S )zReturn the root node of Trie. )ró   r   r   r   r   Údump¯  s    z	Trie.dumpc                 C   s
   t  |¡S )z Escape a char for regex. )rà   Úescape)r   rÆ   r   r   r   Úquote³  s    z
Trie.quotec                 C   s6   | j }|D ] }||jkr$|j| }q
 dS q
d|jkS )zZSearch whether word is present in the Trie.
        Returns True if yes, else return FalseFr   )ró   r  r  r   r   r   rá   ·  s    
zTrie.searchc           
   	   C   sF  |}d|j kr$t|j  ¡ ƒdkr$dS g }g }d}t|j  ¡ ƒD ]h}t|j | tƒr¢z(|  |j | ¡}| |  |¡| ¡ W q¦ t	k
rž   | |  |¡¡ Y q¦X q>d}q>t|ƒdk }t|ƒdkröt|ƒdkrÞ| |d ¡ n| dd 
|¡ d ¡ t|ƒdkr|d }	ndd 
|¡ d	 }	|rB|r6|	d
7 }	nd|	› d}	|	S )z0Convert a Trie into a regular expression patternr   r   Nr   ú[ú]z(?:ú|r¦   ú?z)?)r  rŸ   ÚkeysÚsortedÚ
isinstancer  Ú_patternra   r	  Ú	ExceptionrŠ   )
r   ró   r  ZaltÚccÚqrÆ   ÚrecurseZcconlyrš   r   r   r   r  Ä  s6    

zTrie._patternc                 C   s   |   | j¡S ©z#Export the Trie to a regex pattern.©r  ró   r   r   r   r   rz   ë  s    zTrie.patternc                 C   s   |   | j¡S r  r  r   r   r   r   Úexport_to_regexï  s    zTrie.export_to_regexN)r   r   r   rn   r   rW   r  r	  rá   r  rz   r  r   r   r   r   r8   ž  s   	'ÚPYTORCH_MAPr   z(?<=\W)(z)(?=\W)z#include "([^"]+)"z#include <([^>]+)>z"#define THC_GENERIC_FILE "([^"]+)"z\.cu\bc	                    s|  t j t j ˆ|¡¡‰tˆ }	|ˆ kr>d|	_d|	_tj|	_	|	S t j 
|ˆ¡}
tˆddH}| ¡ tkrŠd|	_d|	_tj|	_	|	W  5 Q R £ S | d¡ | ¡ }W 5 Q R X |}t j t j ˆt|
ˆƒ¡¡}t j t j |¡¡sìˆ t j |¡¡ dd„ ‰‡fd	d
„}ˆrt ˆ|¡}nDt|
ƒr,t ||¡}n,t|
ƒrDt ˆ|¡}ndd„ }t ||¡}d'‡ ‡‡‡‡‡‡‡‡	f	dd„	}t |ddƒ|¡}t |ddƒ|¡}t |dƒ|¡}| d¡rÚ| dd¡}| dd¡}t d|¡}ˆsêt|ˆ	ƒ}| d¡rd|krt |ƒ}t!|ƒ}ˆrR||krRt j ˆ¡t j |¡krRˆ|	_d|	_tj|	_	|	S ˆ|krpt"ˆdƒrpt| }d}t j |¡r¦t|dd}| ¡ |k}W 5 Q R X |r`z@ˆj|ddd}| #|¡ W 5 Q R X ||	_d|	_tj|	_	|	W S  t$k
r\ } zPt%t&j'› d |› d!|j(› d"ˆ› d#t&j)› 	t*j+d$ ˆ|	_d%|	_tj|	_	|	 W Y ¢S d}~X Y nX n||	_d&|	_tj|	_	|	S dS )(z< Executes the CUDA -> HIP conversion on the specified file. Nz[ignored, not to be hipified]zutf-8)Úencodingz#[ignored, input is hipified output]r   c                 S   s   t |  d¡ S ©Nr   )r  r¶   ©Úmr   r   r   Úpt_replK  s    zpreprocessor.<locals>.pt_replc                    s   t  |  d¡ˆ | ƒ¡S r  )ÚPYTORCH_SPECIAL_MAPÚgetr¶   r  )r  r   r   Úpt_special_replN  s    z%preprocessor.<locals>.pt_special_replc                 S   s   t |  d¡ S r  )Ú
CAFFE2_MAPr¶   r  r   r   r   Úc2_repl[  s    zpreprocessor.<locals>.c2_replTc                    s$   ‡‡‡‡‡‡ ‡‡‡	‡
‡fdd„}|S )Nc              
      sà  |   d¡}tj |¡\}‰ | d¡s8| d¡rN| d¡sNˆ t|   d¡ˆƒ¡S ˆrÖt‡ fdd„ˆD ƒƒrÖd }d }ˆr®tj ˆ¡}tj 	tj 
||¡¡}tj |¡r®|}|}|d krøˆD ]<}tj 
ˆ|¡}tj 	tj 
||¡¡}tj |¡rº|}|}qº|d kr|   d¡S |tkr0tˆ|ˆˆˆ
ˆˆˆˆ	ƒ	 nz|tkrªt| }|jtjkrªtj |ˆ¡}	tj 	tj 
ˆt|	ˆƒ¡¡}
|
|_|t|< ˆ tj |
d k	r |
n||¡¡S t| j}ˆ tj |d k	rÌ|n||¡¡S |   d¡S )Nr   )z	ATen/cudazATen/native/cudazATen/native/nested/cudazATen/native/quantized/cudazATen/native/sparse/cudazATen/native/transformers/cudazTHC/rí   ZTHCPc                 3   s   | ]}|  ˆ ¡V  qd S r@   rq   )rs   Úsru   r   r   rv   p  s     z>preprocessor.<locals>.mk_repl.<locals>.repl.<locals>.<genexpr>r   )r¶   rT   rU   r]   rù   Úformatr2   rx   ÚdirnamerX   rŠ   rV   r    r(   r   r   r   rˆ   r   )r  rk   rò   Z
header_dirZheader_filepathZheader_dir_to_checkZheader_path_to_checkÚheader_include_dirZheader_resultZheader_rel_pathZheader_fout_pathZhipified_header_filepath)r   r”   r˜   r‘   r“   Úinclude_current_dirrƒ   r   r•   r’   Útemplru   r   Úrepla  sr    
ÿøø



     ý
ÿ ÿ
 ÿz+preprocessor.<locals>.mk_repl.<locals>.replr   )r)  r(  r*  )	r   r”   r˜   r‘   r“   rƒ   r   r•   r’   )r(  r)  r   Úmk_repl`  s     :zpreprocessor.<locals>.mk_replz#include "{0}"z#include <{0}>Fz#define THC_GENERIC_FILE "{0}"zCMakeLists.txtrë   rì   rí   rî   ré   rþ   Z	PowKernelz[skipped, no changes])rè   rÿ   ú.cú.ccú.cppú.hú.hppÚwz[ok]zFailed to save z with "z", leaving z unchanged.©Úfilez[skipped, no permissions]z[skipped, already hipified])T),rT   rU   rX   rŠ   r    r   r   r   r   r   rˆ   rH   ÚreadlineÚHIPIFY_C_BREADCRUMBÚseekÚreadr2   rV   r&  r^   ÚRE_PYTORCH_PREPROCESSORrÖ   r6   r4   ÚRE_CAFFE2_PREPROCESSORÚRE_QUOTE_HEADERÚRE_ANGLE_HEADERÚRE_THC_GENERIC_FILErr   r«   ÚRE_CU_SUFFIXr+   r/   r0   r&   ÚwriteÚPermissionErrorr—   r$   rM   ÚstrerrorrN   ÚsysÚstderr)r   r|   r   r‘   r’   r“   rƒ   r”   r•   r™   rŽ   ZfinZoutput_sourceZorig_output_sourceZ	fout_pathr!  r#  r+  Zdo_writeZfout_oldZfoutrt   r   )
r   r”   r˜   r‘   r“   rƒ   r   r  r•   r’   r   r9   %  sž    


<
ÿþý&ÿc              	      st   t | dƒ`}| ¡ }|r>t dt |¡› d‡ fdd„|¡}n| |ˆ ¡}| d¡ | |¡ | ¡  W 5 Q R X d S )Núr+z\b(z)\bc                    s   ˆ S r@   r   )rÒ   ©Úreplace_stringr   r   r¸   á  r¹   z+file_specific_replacement.<locals>.<lambda>r   )	r#   r7  rà   rÖ   r  r«   r6  r>  Útruncate)r|   Zsearch_stringrE  Ústrictrk   Úcontentsr   rD  r   r:   Ý  s    &

c              	   C   sr   t | dƒ^}| ¡ }|d dkr8|d dkr8d|› d}d|› d| }| d¡ | |¡ | ¡  W 5 Q R X d S )	NrC  r   rÀ   re   r¿   rÎ   z	#include z 
)r#   r7  r6  r>  rF  )r|   Úheaderrk   rH  r   r   r   r;   é  s    

c                 C   s   |   dd¡} | S )z<Static global kernels in HIP results in a compilation error.z __global__ staticrä   ©r«   )Zin_txtr   r   r   r<   ô  s    z#include .*\nc                 C   s6  g }dddœ}| }|d }|t |ƒk r2|| dkrF|d  d7  < nt|| dkrd|d  d8  < nV|| dkr‚|d  d7  < n8|| dkrº||d  dkrº|d dkrº|d  d8  < |d dkræ|d dkræ| ||d	œ¡ q2|d dkr(|d dkr(|| d
kr(| ||d	œ¡ |d }|d7 }q|S )ad   Return the list of arguments in the upcoming function parameter closure.
        Example:
        string (input): '(blocks, threads, 0, THCState_getCurrentStream(state))'
        arguments (output):
            '[{'start': 1, 'end': 7},
            {'start': 8, 'end': 16},
            {'start': 17, 'end': 19},
            {'start': 20, 'end': 53}]'
    r   )rÀ   r¥   r   r¥   r¦   rÀ   r¿   ú-rº   r§   )rŸ   ra   )r¤   rÈ   Ú	argumentsZclosuresÚcurrent_positionZargument_start_posr   r   r   r=   ý  s.    þ(*
c                 C   s.   |   ¡ dkrdS |   ¡ dkr dS t d¡‚dS )zArgumentParser doesn't support type=bool. Thus, this helper method will convert
    from possible string types to True / False.)ÚyesÚtrueÚtÚyÚ1T)ÚnoÚfalserk   rd   Ú0FzBoolean value expected.N)rý   ÚargparseÚArgumentTypeError)Úvr   r   r   r>   +  s
    ©rè   rÿ   r,  r-  r.  r/  z.inr0  ©rÿ   r/  r0  ©rÍ   T)Úproject_directoryÚshow_detailedro   Úheader_extensionsr   r‘   r€   Úextra_filesr‚   r   r•   r“   rƒ   Úhipify_extra_files_onlyr”   rp   c                    sò  ˆdkrt  ¡ ‰t j ˆ¡s.tdƒ t d¡ ˆ sDˆ d¡ ˆd ‰ ˆˆ krt‡ ‡fdd„|D ƒ}‡ ‡fdd„|	D ƒ}	t j ˆ ¡sŒt 	ˆˆ ¡ t
tˆ ||	|||d	ƒ}t|ƒ}|D ]0}t j |¡sÎt j ˆ |¡}||kr°| |¡ q°d
dlm} |D ]†}t j |¡r||ƒ}n|t j ˆ |¡ƒ}| d¡D ]L}| ¡ r*tt|ƒ|ƒr*tt|ƒ|	ƒs*t|j|ƒr*| t|ƒ¡ q*qò|d krŽtdd}g g dœ}|s¢|n|D ]}tˆ ||||||||
ƒ	 q¦ttjd tj tjd |rît|ƒ tS )Nr   z,The project folder specified does not exist.r   rÌ   Z_amdc                    s   g | ]}|  ˆˆ ¡‘qS r   rJ  )rs   Úinclude©r   r\  r   r   r£   U  s     zhipify.<locals>.<listcomp>c                    s   g | ]}|  ˆˆ ¡‘qS r   rJ  )rs   rF   rb  r   r   r£   V  s     )r€   r   ro   r‚   rƒ   r   )ÚPathrÍ   T)rO   )rœ   rž   z-Successfully preprocessed all matching files.r2  ) rT   ÚgetcwdrU   rV   r—   rA  ÚexitÚrstripÚshutilÚcopytreer×   r'   rP   rï   rŠ   ra   Úpathlibrc  ÚrglobÚis_filer~   Ústrr&   Únamer%   r(   r$   rL   rN   rB  r)   r    )r\  r]  ro   r^  r   r‘   r€   r_  r‚   r   r•   r“   rƒ   r`  r”   r   Zall_files_setrk   rc  r'  Zheader_include_dir_pathrU   r’   r|   r   rb  r   r?   6  sn    

 ý
ÿþý
ü



    ÿ)r   r   r   FF)F)F)FrY  rZ  r   r   r[  r   Fr   TFFFN)arn   rV  ry   rà   rg  rA  rT   r   r   Zcuda_to_hip_mappingsr   r   Útypingr   r   r   r	   Úcollections.abcr
   r   Úenumr   r   r   rl  ZHipifyFinalResultr5  r    Ú__annotations__ZPYTORCH_TEMPLATE_MAPÚ__all__r  r"   r#   r$   r%   Úboolr&   r~   r'   r(   r)   r*   ÚcompilerÕ   r+   r,   r-   r.   Z	RE_ASSERTr/   rå   r0   rç   r1   r2   r3   r4   r5   r6   r7   r  r8   ZCAFFE2_TRIEr"  ZPYTORCH_TRIEr  Úobjectr  Úmappingr  rð   ÚitemsÚsrcri   ÚdstZ	meta_dataZ
API_CAFFE2rW   ZAPI_SPECIALr   ZAPI_PYTORCHr  r9  r8  r:  r;  r<  r=  r9   r:   r;   r<   Z
RE_INCLUDEr=   r>   r?   r   r   r   r   Ú<module>   s\  	
                      û!     ú ú-
 ÷#
 &



H
U








 ÷ 9

.              ñð