a
    8SicZ                     @   s  d dl Z d dlZd dlmZ d dlZd dlm  mZ d dl	Zd dlmZ d dl	m
Z
 d dl	mZ d dl	mZ d dlmZ dd	 Zed
d
ddd Zdd Zed
ddddSddZed
d
dddddTddZdd ZedejjjjdddZedejjjjdddZedejjjjdddZed ejjjjdd!dZ ed"ejjjjdd!dZ!ed#ejjjjdd!dZ"d$d% Z#e#d&ejjjjZ$e#d'ejjjjZ%e#d(ejjjjZ&d)d* Z'e'd+dd,Z(e'd-d.d,Z)e'd/d0d,Z*e'd1dd2Z+e'd3d.d2Z,e'd4d0d2Z-d5d6 Z.dUd7d8Z/d9d: Z0ed
d;d<d= Z1d>d? Z2ed
d
d
dddd
dd	d@dA Z3ed
d
d
dddVdDdEZ4dFdG Z5dHdI Z6dJdK Z7dLdM Z8ed
dNdNdNdOdP Z9G dQdR dRZ:dS )W    N)Sequence)_C)_patch_torch)symbolic_helper)symbolic_opset9)GLOBALSc                 G   s2   t |dkrt| ||S t| ||g|R  S d S Nr   )lenopset9true_divide_div_rounding_mode)gselfotherargs r   W/var/www/html/django/DPS/env/lib/python3.9/site-packages/torch/onnx/symbolic_opset10.pydiv   s    r   vsc                 C   s(   |dkrt | ||S t| |||S d S )Nfloor)_floor_divider
   r   )r   r   r   rounding_moder   r   r   r       s    r   c                 C   s   t |st |r.t| ||}| d|S | d||}| jdtjdtjdd}| d| d||| d||}| jd	||dd
}| d|| d| d||}| jdtjdtjdd}	| d||	}
| d||
|S d S )NFloorDivConstantr   dtypevalue_tXorLessModfmod_iAndNotEqual   SubWhere)r   _is_fpr
   r   optorchtensorint64)r   r   r   outr   zeronegativemod
fixup_maskonefixupr   r   r   r   (   s    " r   inonec                 C   s   t j| ||||dS )N)	decendingr0   )r   _sort_helper)r   r   dimr9   r0   r   r   r   sort<   s    r<   c              	   C   s   t j| ||||||dS )N)largestsortedr0   )r   _topk_helper)r   r   kr;   r=   r>   r0   r   r   r   topkA   s    rA   c              	      s<   t ddddddt dddddd fdd}|S )NTFr   isr7   c                    s   |s|}||d ||d}t |dhkrD||d< r| jd|fddi|\}}	| jd|ddd t D d	d t D d
\}
}tj| |dd t D ddd}t| |	|}	||	fS | jd|fddi|}|S d S )N   )kernel_shape_ipads_i	strides_iceil_mode_ir(   dilations_iMaxPooloutputsc                 S   s   g | ]}d qS r(   r   .0_r   r   r   
<listcomp>i       z2_max_pool.<locals>.symbolic_fn.<locals>.<listcomp>c                 S   s   g | ]}d qS rK   r   rL   r   r   r   rO   j   rP   )rJ   rD   rF   c                 S   s   g | ]}d | qS )rC   r   )rM   r7   r   r   r   rO   p   rP   r   )axesstartsends)setr,   ranger   _slice_helperr
   sub)r   inputkernel_sizestridepaddingdilation	ceil_modekwargsrindicesrN   flattened_indicesr   ndimsreturn_indicestuple_fnr   r   symbolic_fnI   s:    

z_max_pool.<locals>.symbolic_fn)r   quantized_args
parse_args)namere   rc   rd   rf   r   rb   r   	_max_poolH   s    /rj   
max_pool1dr(   F)rd   
max_pool2drC   
max_pool3d   max_pool1d_with_indicesTmax_pool2d_with_indicesmax_pool3d_with_indicesc              
      s^   t dddddddt dddddddd
tjtt tt tt ttd fdd	}|S )NTFr   rB   r7   r8   )rX   rY   rZ   r[   r]   count_include_padc           	   	      sr   |s|}t |||| }|rLtj| d|d| d dddd}dt| }| jd	||||d |d
}|S )NPad)r   r   rC   constant           )rE   mode_svalue_fopset_beforer   AveragePool)rD   rF   rE   rG   )r   _avgpool_helperr
   op_with_optional_float_castr	   r,   )	r   rX   rY   rZ   r[   r]   rr   divisor_overrideoutputri   re   r   r   rf      s2    
	z_avg_pool.<locals>.symbolic_fn)N)r   rg   rh   r   Valuer   int)ri   re   rf   r   r   r   	_avg_pool   s    	 $r   
avg_pool1d
avg_pool2d
avg_pool3dc                    s"   t ddd fdd}|S )NTFc                    s`   t | |\}}t  t |}|r6t dS |d u rNt | || }| jd||dS )Nzalign_corners == TrueResizerw   )r   _get_interpolate_attributes_interpolate_warning_maybe_get_scalar_unimplemented_interpolate_size_to_scalesr,   )r   rX   output_sizer   scalesalign_cornersr;   interpolate_moderi   r   r   rf      s    

z!_interpolate.<locals>.symbolic_fn)r   rg   )ri   r;   r   rf   r   r   r   _interpolate   s    r   upsample_nearest1dnearestupsample_nearest2d   upsample_nearest3d   upsample_linear1dlinearupsample_bilinear2dupsample_trilinear3dc           	      C   s*   t | |||||\}}| jd|||dS )Nr   r   )r    _interpolate_get_scales_and_moder,   )	r   rX   sizescale_factormoder   recompute_scale_factor	antialiasr   r   r   r   __interpolate   s    r   c                 C   s`  |rTt | |dg}t | |dg}t|trB| jdt|d}t | |dg}nt|t|kshJ t|t|ks|J |d u st|t|ksJ t|dkr|d dkr|d dkr|d u st|dkr|d dkr|S | jdt|d}| jdt|d}| jdt|d}|d u r8| d||||S | jdt|d}| d|||||S )Nr   r   r   r(       Slice)r   _unsqueeze_helper
isinstancer   r,   r-   r.   r	   )r   rX   rQ   rR   rS   stepsdynamic_slicer   r   r   _slice   s:    






r   c              	   G   sv  t |dkr|\}}}}n$t |dkr6|\}}}d}ntd|  dko\|  dk}|  dko||  dk}|  dk}	|  dk}
t|d}|s|	rt|ts|s|
rt|ts|  dkrd	}|r| j	d
t
dd}|r\| j	d
t
dd}nB|r$dn
t|dg}|r<dn
t|dg}t|dg}d}tj| |||||g|dS )Nr   rn   r   zUnknown aten::slice signaturezprim::ConstantNoneTypezonnx::Constantr7   Tr   r   r   F)rQ   rR   rS   r   r   )r	   NotImplementedErrornodekindtyper   
_parse_argr   r   r,   r-   r.   rV   )r   r   r   r;   startendstepis_start_noneis_end_noneis_start_onnx_constis_end_onnx_constr   r   r   r   slice   s\    
r   rB   c              	   C   s4   t j| ||dgt| dgt| dgt| dS )Nl )rQ   rR   rS   r   )r   rV   r	   )r   rX   dimsr   r   r   flip.  s    r   c                 C   s   | j d||ddS )Nr"   r(   r#   )r,   )r   rX   r   r   r   r   fmod:  s    r   c
                 C   s  |rt jrtdS |	d ur,|	dkr,tdtd t|d}
|
d ur|r^|
d }|}n8|
}|| jdt	
tjgdg}| jdg|R d	di}g }t|D ]8}t| t| |t	
dt	
|dg}t| t| |t	
dt	
|d dg}| jdt	
dgd}| d
||||}| d||}t|sn| d
||||}t| |dg}| d||}|dkrtj| |dgdd}n4|dkr| jd|dgdd}n| jd|dgdd}t| |dg}|| q| jdg|R d	di}|d d d fS tdS d S )Nz7embedding_bag with scale_grad_by_freq for training moder   zembedding_bag with padding_idxzExport of embedding_bag with dynamic input/offsets shape is not supported in opset 10. Please use opset 11 or higher to export model for dynamic input shape.'r(   r   r   Concataxis_ir   GatherMul)axes_i
keepdims_i
ReduceMean	ReduceMaxziembedding_bag with unknown shape of offsets for opset 10 is not supported. please use opset 11 or higher.)r   training_moder   _onnx_unsupportedRuntimeErrorwarningswarn_get_tensor_dim_sizer,   r-   r.   sysmaxsizerU   r   r
   select_is_none_reducesum_helperappend)r   embedding_matrixr`   offsetsscale_grad_by_freqr   sparseper_sample_weightsinclude_last_offsetpadding_idxZoffsets_dim_0Z
offset_lenZoffsets_extendedlist_r7   start_end_axes_indices_row
embeddingsper_sample_weights_rowr   r   r   r   embedding_bag>  st    





r      c              	   C   s   ||fdkrt dddd ||fdvr8td||t |}|d u rZt dddd | j}|d	kr| jd
|tj	j
d}n| jd
|tj	jd}| d| d|||||S )N)r   r   fake_quantize_per_tensor_affine
      z=Quantize range (0, 127) not supported, requires opset 13 Clip))r      )r   r   zSFor (quant_min, quant_max), ONNX allows only (0, 255) and (-128, 127). Got ({}, {})z Non-constant scale not supportedr   Castto_iDequantizeLinearQuantizeLinear)r    _onnx_opset_unsupported_detailedr   formatr   floatdatar,   _C_onnxTensorProtoDataTypeUINT8INT8)r   inputsscale
zero_point	quant_min	quant_maxr   r   r   r     s>    

r   c                 C   s   |  dt| |dS )NIsInfF)r,   r
   Z_cast_Doubler   rX   r   r   r   isinf  s    r   c                 C   s8   ddl m}m} t| |}t| |}|| || ||S )Nr   )__not___or_)torch.onnx.symbolic_opset9r   r   r   r
   isnan)r   rX   r   r   Zinf_nodeZnan_noder   r   r   isfinite  s    
r   c                 C   sH   t |dd}| jd|t j| d}| jd|tjjd}t | |||S )Nr7   r   r   r   )r   
_get_constr,   scalar_type_to_onnxr   r   FLOATquantize_helper)r   rX   r   r   r   r   r   r   quantize_per_tensor  s    r  c                 C   s   t | |d S r   r   dequantize_helperr   r   r   r   
dequantize  s    r  fc                 C   s0  t |s|S t j|   }|d u r,d}t| |}| d|| jdtj	|g|dd|}t
|}|d u rv|j}t| t| |t| || jdtdgd}	| d|	| jdtj	|g|dd|}
|d u r|j}t| t| |
t| |
| jdtdgd}| d|| jdtj	|g|dd|
S )Nru   r*   r   r   r   r   )r   r+   pytorch_name_to_typer   
scalarTyper
   r   r,   r-   r.   finfomaxlogical_andr   gt
LongTensorminlt)r   rX   nanposinfneginfinput_dtypeZnan_condZ
nan_resultr  Zposinf_condZnan_posinf_resultZneginf_condr   r   r   
nan_to_num  sR    
	
r  c                   @   s   e Zd ZdZdZedd Zedd Zedd Zed	d
 Z	edd Z
edd Zedd ZeeddddejeejejejdddZdS )	Quantizedz
    https://github.com/pytorch/pytorch/wiki/PyTorch-ONNX-exporter#quantized-model-export

    Support starts from opset 10 because `DequantizeLinear` and `QuantizeLinear` were introduced in opset version 10.
    	quantizedc                 C   sl   t | |\}}}}t | |\}	}
}}t | |||
}t | |\}}}}t| ||	|}t | |||S N)r   r  requantize_bias_helperr
   r   r  )r   q_inputq_weightbiasop_scaleop_zero_pointrX   input_scalerN   weightweight_scaleq_biasr   r   r   r   r     s    zQuantized.linearc                 C   sF   t | |\}}}}t | |\}}}}t| ||}t | |||S r  )r   r  r
   addr  r   xyr  r   rN   r   r   r   r   r%    s    zQuantized.addc                 C   sR   t | |\}}}}t | |\}}}}t| ||}t| |}t | |||S r  )r   r  r
   r%  relur  r&  r   r   r   add_relu&  s
    zQuantized.add_reluc                 C   sF   t | |\}}}}t | |\}}}}t| ||}t | |||S r  )r   r  r
   mulr  r&  r   r   r   r+  0  s    zQuantized.mulc                 C   s0   t | |\}}}}t| |}t | |||S r  )r   r  r
   	hardswishr  )r   r'  r  r   rN   r   r   r   r   r,  9  s    zQuantized.hardswishc
              
   C   s   t | |\}
}}}t | |\}}}}t | |||}t | |\}}}}t| |
||||||}t| |}t | |||	S r  )r   r  r  r
   conv2dr)  r  r   r  r  r  rZ   r[   r\   groupsr  r   rX   r!  rN   r"  r#  r$  r   r   r   r   conv2d_reluA  s    zQuantized.conv2d_reluc
              
   C   st   t | |\}
}}}t | |\}}}}t | |||}t | |\}}}}t| |
||||||}t | |||	S r  )r   r  r  r
   r-  r  r.  r   r   r   r-  \  s    zQuantized.conv2dr   r7   )q_inputsr;   r  r   returnc                    sD   t |} fdd|D } jdg|R d|i}t  |||S )Nc                    s   g | ]}t  |d  qS rz   r  )rM   rX   r   r   r   rO     s   z!Quantized.cat.<locals>.<listcomp>r   r   )r   _unpack_listr,   r  )r   r1  r;   r  r   Zunpacked_inputsZdequantizedconcatenatedr   r3  r   catv  s    	

zQuantized.catN)__name__
__module____qualname____doc__domainstaticmethodr   r%  r*  r+  r,  r0  r-  r   rh   r   r   r   r6  r   r   r   r   r    s0   


	



r  )N)N)NF)r   r   );r   r   typingr   r-   Ztorch._C._onnxr   _onnxr   
torch.onnxr   r   r   r
   torch.onnx._globalsr   r   rh   r   r   r<   rA   rj   nnmodulesutils_singlerk   _pairrl   _triplerm   ro   rp   rq   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r  r  r  r  r   r   r   r   <module>   s   

5*	
/


T &	
3