a
    Sic                     @   s&  d Z ddlZddlZddlZddlZddlZddlZej	e
d ejddszddlmZ ddlmZ ddlmZ n ddlmZ ddlmZ d	d
 ZG dd deZeddddZG dd deZedejG dd dejZejfddZedG dd deZ G dd de Z!dS )zPython TF-Lite interpreter.    Ntflite_runtimeinterpreter)&_pywrap_tensorflow_interpreter_wrapper)metrics)	tf_export)metrics_portablec                  O   s   ~ ~dd S )Nc                 S   s   | S N )xr	   r	   ^/var/www/html/django/DPS/env/lib/python3.9/site-packages/tensorflow/lite/python/interpreter.py<lambda>&       z_tf_export.<locals>.<lambda>r	   )r
   kwargsr	   r	   r   
_tf_export$   s    r   c                   @   s*   e Zd ZdZd	ddZdd Zdd ZdS )
Delegatea  Python wrapper class to manage TfLiteDelegate objects.

  The shared library is expected to have two functions:
    TfLiteDelegate* tflite_plugin_create_delegate(
        char**, char**, size_t, void (*report_error)(const char *))
    void tflite_plugin_destroy_delegate(TfLiteDelegate*)

  The first one creates a delegate object. It may return NULL to indicate an
  error (with a suitable error message reported by calling report_error()).
  The second one destroys delegate object and must be called for every
  created delegate object. Passing NULL as argument value is allowed, i.e.

    tflite_plugin_destroy_delegate(tflite_plugin_create_delegate(...))

  always works.
  Nc                 C   s  t  dkrtdtj|| _ttjttjtj	t
dtjg| jj_tj| jj_|pbi }tjt|  }tjt|  }t| D ]0\}\}}t|d||< t|d||< qG dd dt}| }	t
dtj|	j}
| j||t||
| _| jdu rt|	jdS )a  Loads delegate from the shared library.

    Args:
      library: Shared library name.
      options: Dictionary of options that are required to load the delegate. All
        keys and values in the dictionary should be serializable. Consult the
        documentation of the specific delegate for required and legal options.
        (default None)

    Raises:
      RuntimeError: This is raised if the Python implementation is not CPython.
    CPythonz_Delegates are currently only supported into CPythondue to missing immediate reference counting.Nutf-8c                   @   s   e Zd Zdd Zdd ZdS )z.Delegate.__init__.<locals>.ErrorMessageCapturec                 S   s
   d| _ d S )N )messageselfr	   r	   r   __init__d   s    z7Delegate.__init__.<locals>.ErrorMessageCapture.__init__c                 S   s&   |  j t|tr|n|d7  _ d S )Nr   )r   
isinstancestrdecode)r   r
   r	   r	   r   reportg   s    z5Delegate.__init__.<locals>.ErrorMessageCapture.reportN)__name__
__module____qualname__r   r   r	   r	   r	   r   ErrorMessageCaptureb   s   r   )platformpython_implementationRuntimeErrorctypespydllLoadLibrary_libraryPOINTERc_char_pc_int	CFUNCTYPEZtflite_plugin_create_delegateargtypesc_void_prestypelen	enumerateitemsr   encodeobjectr   _delegate_ptr
ValueErrorr   )r   libraryoptionsZoptions_keysZoptions_valuesidxkeyvaluer   captureZerror_capturer_cbr	   r	   r   r   >   s,    

zDelegate.__init__c                 C   s0   | j d ur,tjg| j j_| j | j d | _ d S r   )r&   r#   r,   Ztflite_plugin_destroy_delegater+   r3   r   r	   r	   r   __del__r   s    
zDelegate.__del__c                 C   s   | j S )zReturns the native TfLiteDelegate pointer.

    It is not safe to copy this pointer because it needs to be freed.

    Returns:
      TfLiteDelegate *
    )r3   r   r	   r	   r   _get_native_delegate_pointerz   s    z%Delegate._get_native_delegate_pointer)N)r   r   r   __doc__r   r;   r<   r	   r	   r	   r   r   ,   s   
4r   zlite.experimental.load_delegatec              
   C   sL   zt | |}W n8 tyF } z td| t|W Y d}~n
d}~0 0 |S )ac  Returns loaded Delegate object.

  Example usage:

  ```
  import tensorflow as tf

  try:
    delegate = tf.lite.experimental.load_delegate('delegate.so')
  except ValueError:
    // Fallback to CPU

  if delegate:
    interpreter = tf.lite.Interpreter(
        model_path='model.tflite',
        experimental_delegates=[delegate])
  else:
    interpreter = tf.lite.Interpreter(model_path='model.tflite')
  ```

  This is typically used to leverage EdgeTPU for running TensorFlow Lite models.
  For more information see: https://coral.ai/docs/edgetpu/tflite-python/

  Args:
    library: Name of shared library containing the
      [TfLiteDelegate](https://www.tensorflow.org/lite/performance/delegates).
    options: Dictionary of options that are required to load the delegate. All
      keys and values in the dictionary should be convertible to str. Consult
      the documentation of the specific delegate for required and legal options.
      (default None)

  Returns:
    Delegate object.

  Raises:
    ValueError: Delegate failed to load.
    RuntimeError: If delegate loading is used on unsupported platform.
  z"Failed to load delegate from {}
{}N)r   r4   formatr   )r5   r6   delegateer	   r	   r   load_delegate   s    (rA   c                   @   s2   e Zd ZdZdddZdd Zdd Zd	d
 ZdS )SignatureRunneraJ  SignatureRunner class for running TFLite models using SignatureDef.

  This class should be instantiated through TFLite Interpreter only using
  get_signature_runner method on Interpreter.
  Example,
  signature = interpreter.get_signature_runner("my_signature")
  result = signature(input_1=my_input_1, input_2=my_input_2)
  print(result["my_output"])
  print(result["my_second_output"])
  All names used are this specific SignatureDef names.

  Notes:
    No other function on this object or on the interpreter provided should be
    called while this object call has not finished.
  Nc                 C   s~   |st d|st d|| _|j| _|| _| }||vrDt d|| | _| jd  | _| jd | _| j	| j| _
dS )zConstructor.

    Args:
      interpreter: Interpreter object that is already initialized with the
        requested model.
      signature_key: SignatureDef key to be used.
    zNone interpreter provided.zNone signature_key provided.zInvalid signature_key provided.outputsinputsN)r4   _interpreter_interpreter_wrapperZ_signature_key_get_full_signature_listZ_signature_defr0   _outputs_inputsZGetSubgraphIndexFromSignature_subgraph_index)r   r   signature_keyZsignature_defsr	   r	   r   r      s"    
zSignatureRunner.__init__c                 K   s   t |t | jkr,tdt | jt |f | D ]F\}}|| jvrRtd| | j| j| tj|jtj	dd| j
 q4| j| j
 | D ] \}}| j| j| || j
 q| j| j
 i }| jD ]\}}| j|| j
||< q|S )aq  Runs the SignatureDef given the provided inputs in arguments.

    Args:
      **kwargs: key,value for inputs to the model. Key is the SignatureDef input
        name. Value is numpy array with the value.

    Returns:
      dictionary of the results from the model invoke.
      Key in the dictionary is SignatureDef output name.
      Value is the result Tensor.
    zXInvalid number of inputs provided for running a SignatureDef, expected %s vs provided %sz(Invalid Input name (%s) for SignatureDefdtypeF)r.   rI   r4   r0   rF   ResizeInputTensornparrayshapeint32rJ   AllocateTensors	SetTensorInvokerH   	GetTensor)r   r   
input_namer9   resultoutput_nameoutput_indexr	   r	   r   __call__   s6    

zSignatureRunner.__call__c                 C   s,   i }| j  D ]\}}| j|||< q|S )a   Gets input tensor details.

    Returns:
      A dictionary from input name to tensor details where each item is a
      dictionary with details about an input tensor. Each dictionary contains
      the following fields that describe the tensor:

      + `name`: The tensor name.
      + `index`: The tensor index in the interpreter.
      + `shape`: The shape of the tensor.
      + `shape_signature`: Same as `shape` for models with known/fixed shapes.
        If any dimension sizes are unknown, they are indicated with `-1`.
      + `dtype`: The numpy data type (such as `np.int32` or `np.uint8`).
      + `quantization`: Deprecated, use `quantization_parameters`. This field
        only works for per-tensor quantization, whereas
        `quantization_parameters` works in all cases.
      + `quantization_parameters`: A dictionary of parameters used to quantize
        the tensor:
        ~ `scales`: List of scales (one if per-tensor quantization).
        ~ `zero_points`: List of zero_points (one if per-tensor quantization).
        ~ `quantized_dimension`: Specifies the dimension of per-axis
        quantization, in the case of multiple scales/zero_points.
      + `sparsity_parameters`: A dictionary of parameters used to encode a
        sparse tensor. This is empty if the tensor is dense.
    )rI   r0   rE   _get_tensor_details)r   rX   rW   tensor_indexr	   r	   r   get_input_details  s    z!SignatureRunner.get_input_detailsc                 C   s(   i }| j D ]\}}| j|||< q
|S )a  Gets output tensor details.

    Returns:
      A dictionary from input name to tensor details where each item is a
      dictionary with details about an output tensor. The dictionary contains
      the same fields as described for `get_input_details()`.
    )rH   rE   r\   )r   rX   rY   r]   r	   r	   r   get_output_details'  s    z"SignatureRunner.get_output_details)NN)r   r   r   r=   r   r[   r^   r_   r	   r	   r	   r   rB      s
   
(rB   z lite.experimental.OpResolverTypec                   @   s    e Zd ZdZdZdZdZdZdS )OpResolverTypea	  Different types of op resolvers for Tensorflow Lite.

  * `AUTO`: Indicates the op resolver that is chosen by default in TfLite
     Python, which is the "BUILTIN" as described below.
  * `BUILTIN`: Indicates the op resolver for built-in ops with optimized kernel
    implementation.
  * `BUILTIN_REF`: Indicates the op resolver for built-in ops with reference
    kernel implementation. It's generally used for testing and debugging.
  * `BUILTIN_WITHOUT_DEFAULT_DELEGATES`: Indicates the op resolver for
    built-in ops with optimized kernel implementation, but it will disable
    the application of default TfLite delegates (like the XNNPACK delegate) to
    the model graph. Generally this should not be used unless there are issues
    with the default configuration.
  r            N)r   r   r   r=   AUTOBUILTINBUILTIN_REF!BUILTIN_WITHOUT_DEFAULT_DELEGATESr	   r	   r	   r   r`   5  s
   r`   c                 C   s$   t jdt jdt jdt jdi| dS )z-Get a integer identifier for the op resolver.ra   rb   rc   N)r`   rd   re   rf   rg   get)Zop_resolver_typer	   r	   r   _get_op_resolver_idT  s    ri   zlite.Interpreterc                   @   s   e Zd ZdZddddejdfddZdd Zdd	 Zd
d Z	dd Z
dd Zdd Zdd Zdd Zdd Zdd Zd/ddZdd Zdd Zd d! Zd0d"d#Zd1d%d&Zd'd( Zd)d* Zd+d, Zd-d. ZdS )2Interpretera6  Interpreter interface for running TensorFlow Lite models.

  Models obtained from `TfLiteConverter` can be run in Python with
  `Interpreter`.

  As an example, lets generate a simple Keras model and convert it to TFLite
  (`TfLiteConverter` also supports other input formats with `from_saved_model`
  and `from_concrete_function`)

  >>> x = np.array([[1.], [2.]])
  >>> y = np.array([[2.], [4.]])
  >>> model = tf.keras.models.Sequential([
  ...           tf.keras.layers.Dropout(0.2),
  ...           tf.keras.layers.Dense(units=1, input_shape=[1])
  ...         ])
  >>> model.compile(optimizer='sgd', loss='mean_squared_error')
  >>> model.fit(x, y, epochs=1)
  >>> converter = tf.lite.TFLiteConverter.from_keras_model(model)
  >>> tflite_model = converter.convert()

  `tflite_model` can be saved to a file and loaded later, or directly into the
  `Interpreter`. Since TensorFlow Lite pre-plans tensor allocations to optimize
  inference, the user needs to call `allocate_tensors()` before any inference.

  >>> interpreter = tf.lite.Interpreter(model_content=tflite_model)
  >>> interpreter.allocate_tensors()  # Needed before execution!

  Sample execution:

  >>> output = interpreter.get_output_details()[0]  # Model has single output.
  >>> input = interpreter.get_input_details()[0]  # Model has single input.
  >>> input_data = tf.constant(1., shape=[1, 1])
  >>> interpreter.set_tensor(input['index'], input_data)
  >>> interpreter.invoke()
  >>> interpreter.get_tensor(output['index']).shape
  (1, 1)

  Use `get_signature_runner()` for a more user-friendly inference API.
  NFc                 C   s  t | dsg | _|}|r2|tjks,|tjkr2tj}t|}|du rPtd||r|sdd | jD }	dd | jD }
t	
|||	|
|| _| jstd|n^|r|sdd | jD }	d	d | jD }
|| _t	|||	|
|| _n|s|std
ntd|dur<t|tstd|dk r0td| j| g | _|rl|| _| jD ]}| j|  qT|  | _t | _| j  dS )a  Constructor.

    Args:
      model_path: Path to TF-Lite Flatbuffer file.
      model_content: Content of model.
      experimental_delegates: Experimental. Subject to change. List of
        [TfLiteDelegate](https://www.tensorflow.org/lite/performance/delegates)
          objects returned by lite.load_delegate().
      num_threads: Sets the number of threads used by the interpreter and
        available to CPU kernels. If not set, the interpreter will use an
        implementation-dependent default number of threads. Currently, only a
        subset of kernels, such as conv, support multi-threading. num_threads
        should be >= -1. Setting num_threads to 0 has the effect to disable
        multithreading, which is equivalent to setting num_threads to 1. If set
        to the value -1, the number of threads used will be
        implementation-defined and platform-dependent.
      experimental_op_resolver_type: The op resolver used by the interpreter. It
        must be an instance of OpResolverType. By default, we use the built-in
        op resolver which corresponds to tflite::ops::builtin::BuiltinOpResolver
        in C++.
      experimental_preserve_all_tensors: If true, then intermediate tensors used
        during computation are preserved for inspection, and if the passed op
        resolver type is AUTO or BUILTIN, the type will be changed to
        BUILTIN_WITHOUT_DEFAULT_DELEGATES so that no Tensorflow Lite default
        delegates are applied. If false, getting intermediate tensors could
        result in undefined values or None, especially when the graph is
        successfully modified by the Tensorflow Lite default delegate.

    Raises:
      ValueError: If the interpreter was unable to create.
    _custom_op_registerersNz+Unrecognized passed in op resolver type: {}c                 S   s   g | ]}t |tr|qS r	   r   r   .0r
   r	   r	   r   
<listcomp>  s   z(Interpreter.__init__.<locals>.<listcomp>c                 S   s   g | ]}t |ts|qS r	   rl   rm   r	   r	   r   ro     s   zFailed to open {}c                 S   s   g | ]}t |tr|qS r	   rl   rm   r	   r	   r   ro     s   c                 S   s   g | ]}t |ts|qS r	   rl   rm   r	   r	   r   ro     s   z2`model_path` or `model_content` must be specified.z3Can't both provide `model_path` and `model_content`z!type of num_threads should be intra   znum_threads should >= 1)hasattrrk   r`   rd   re   rg   ri   r4   r>   rF   ZCreateWrapperFromFilerE   Z_model_contentZCreateWrapperFromBufferr   intZSetNumThreads
_delegatesZModifyGraphWithDelegater<   get_signature_list_signature_defsr   TFLiteMetrics_metrics%increase_counter_interpreter_creation)r   
model_pathZmodel_contentZexperimental_delegatesnum_threadsZexperimental_op_resolver_typeZ!experimental_preserve_all_tensorsZactual_resolver_typeZop_resolver_idcustom_op_registerers_by_namecustom_op_registerers_by_funcr?   r	   r	   r   r     sx    &






zInterpreter.__init__c                 C   s   d | _ d | _d S r   )rE   rr   r   r	   r	   r   r;     s    zInterpreter.__del__c                 C   s   |    | j S r   )_ensure_saferE   rS   r   r	   r	   r   allocate_tensors  s    zInterpreter.allocate_tensorsc                 C   s   t | jdkS )zReturns true if there exist no numpy array buffers.

    This means it is safe to run tflite calls that may destroy internally
    allocated memory. This works, because in the wrapper.cc we have made
    the numpy base be the self._interpreter.
    rb   )sysgetrefcountrE   r   r	   r	   r   _safe_to_run  s    
zInterpreter._safe_to_runc                 C   s   |   stddS )aB  Makes sure no numpy arrays pointing to internal buffers are active.

    This should be called from any function that will call a function on
    _interpreter that may reallocate memory e.g. invoke(), ...

    Raises:
      RuntimeError: If there exist numpy objects pointing to internal memory
        then we throw.
    zThere is at least 1 reference to internal data
      in the interpreter in the form of a numpy array or slice. Be sure to
      only hold the function returned from tensor() if you are using raw
      data access.N)r   r"   r   r	   r	   r   r|     s    
zInterpreter._ensure_safec                 C   s>   t |}| j|}| j|}| j|}||||d}|S )a"  Gets a dictionary with arrays of ids for tensors involved with an op.

    Args:
      op_index: Operation/node index of node to query.

    Returns:
      a dictionary containing the index, op name, and arrays with lists of the
      indices for the inputs and outputs of the op/node.
    )indexop_namerD   rC   )rq   rE   ZNodeNameZ
NodeInputsZNodeOutputs)r   Zop_indexr   	op_inputs
op_outputsdetailsr	   r	   r   _get_op_details   s    
zInterpreter._get_op_detailsc           
   
   C   s   t |}| j|}| j|}| j|}| j|}| j|}| j|}| j|}|sht	d|||||||d |d |d d|d}	|	S )a  Gets tensor details.

    Args:
      tensor_index: Tensor index of tensor to query.

    Returns:
      A dictionary containing the following fields of the tensor:
        'name': The tensor name.
        'index': The tensor index in the interpreter.
        'shape': The shape of the tensor.
        'quantization': Deprecated, use 'quantization_parameters'. This field
            only works for per-tensor quantization, whereas
            'quantization_parameters' works in all cases.
        'quantization_parameters': The parameters used to quantize the tensor:
          'scales': List of scales (one if per-tensor quantization)
          'zero_points': List of zero_points (one if per-tensor quantization)
          'quantized_dimension': Specifies the dimension of per-axis
              quantization, in the case of multiple scales/zero_points.

    Raises:
      ValueError: If tensor_index is invalid.
    zCould not get tensor detailsr   ra   rb   )scaleszero_pointsZquantized_dimension)namer   rQ   Zshape_signaturerM   quantizationZquantization_parametersZsparsity_parameters)
rq   rE   Z
TensorNameZ
TensorSizeZTensorSizeSignature
TensorTypeZTensorQuantizationZTensorQuantizationParametersZTensorSparsityParametersr4   )
r   r]   tensor_nametensor_sizeZtensor_size_signaturetensor_typeZtensor_quantizationZtensor_quantization_paramsZtensor_sparsity_paramsr   r	   r	   r   r\   8  s6    zInterpreter._get_tensor_detailsc                    s    fddt  j D S )zGets op details for every node.

    Returns:
      A list of dictionaries containing arrays with lists of tensor ids for
      tensors involved in the op.
    c                    s   g | ]}  |qS r	   )r   )rn   r7   r   r	   r   ro   v  s   z0Interpreter._get_ops_details.<locals>.<listcomp>)rangerE   ZNumNodesr   r	   r   r   _get_ops_detailso  s    
zInterpreter._get_ops_detailsc              	   C   sD   g }t | j D ],}z|| | W q ty<   Y q0 q|S )a#  Gets tensor details for every tensor with valid tensor details.

    Tensors where required information about the tensor is not found are not
    added to the list. This includes temporary tensors without a name.

    Returns:
      A list of dictionaries containing tensor information.
    )r   rE   Z
NumTensorsappendr\   r4   )r   Ztensor_detailsr7   r	   r	   r   get_tensor_detailsz  s    	zInterpreter.get_tensor_detailsc                    s    fdd j  D S )a  Gets model input tensor details.

    Returns:
      A list in which each item is a dictionary with details about
      an input tensor. Each dictionary contains the following fields
      that describe the tensor:

      + `name`: The tensor name.
      + `index`: The tensor index in the interpreter.
      + `shape`: The shape of the tensor.
      + `shape_signature`: Same as `shape` for models with known/fixed shapes.
        If any dimension sizes are unknown, they are indicated with `-1`.
      + `dtype`: The numpy data type (such as `np.int32` or `np.uint8`).
      + `quantization`: Deprecated, use `quantization_parameters`. This field
        only works for per-tensor quantization, whereas
        `quantization_parameters` works in all cases.
      + `quantization_parameters`: A dictionary of parameters used to quantize
        the tensor:
        ~ `scales`: List of scales (one if per-tensor quantization).
        ~ `zero_points`: List of zero_points (one if per-tensor quantization).
        ~ `quantized_dimension`: Specifies the dimension of per-axis
        quantization, in the case of multiple scales/zero_points.
      + `sparsity_parameters`: A dictionary of parameters used to encode a
        sparse tensor. This is empty if the tensor is dense.
    c                    s   g | ]}  |qS r	   r\   rn   ir   r	   r   ro     s   z1Interpreter.get_input_details.<locals>.<listcomp>)rE   ZInputIndicesr   r	   r   r   r^     s    
zInterpreter.get_input_detailsc                 C   s   | j || dS )a  Sets the value of the input tensor.

    Note this copies data in `value`.

    If you want to avoid copying, you can use the `tensor()` function to get a
    numpy buffer pointing to the input buffer in the tflite interpreter.

    Args:
      tensor_index: Tensor index of tensor to set. This value can be gotten from
        the 'index' field in get_input_details.
      value: Value of tensor to set.

    Raises:
      ValueError: If the interpreter could not set the tensor.
    N)rE   rT   )r   r]   r9   r	   r	   r   
set_tensor  s    zInterpreter.set_tensorc                 C   s,   |    tj|tjd}| j||| dS )a  Resizes an input tensor.

    Args:
      input_index: Tensor index of input to set. This value can be gotten from
        the 'index' field in get_input_details.
      tensor_size: The tensor_shape to resize the input to.
      strict: Only unknown dimensions can be resized when `strict` is True.
        Unknown dimensions are indicated as `-1` in the `shape_signature`
        attribute of a given tensor. (default False)

    Raises:
      ValueError: If the interpreter could not resize the input tensor.

    Usage:
    ```
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.resize_tensor_input(0, [num_test_images, 224, 224, 3])
    interpreter.allocate_tensors()
    interpreter.set_tensor(0, test_images)
    interpreter.invoke()
    ```
    rL   N)r|   rO   rP   rR   rE   rN   )r   input_indexr   strictr	   r	   r   resize_tensor_input  s    zInterpreter.resize_tensor_inputc                    s    fdd j  D S )zGets model output tensor details.

    Returns:
      A list in which each item is a dictionary with details about
      an output tensor. The dictionary contains the same fields as
      described for `get_input_details()`.
    c                    s   g | ]}  |qS r	   r   r   r   r	   r   ro     s   z2Interpreter.get_output_details.<locals>.<listcomp>)rE   ZOutputIndicesr   r	   r   r   r_     s    
zInterpreter.get_output_detailsc                 C   sH   | j  }| D ]0\}}t|d  |d< t|d  |d< q|S )a  Gets list of SignatureDefs in the model.

    Example,
    ```
    signatures = interpreter.get_signature_list()
    print(signatures)

    # {
    #   'add': {'inputs': ['x', 'y'], 'outputs': ['output_0']}
    # }

    Then using the names in the signature list you can get a callable from
    get_signature_runner().
    ```

    Returns:
      A list of SignatureDef details in a dictionary structure.
      It is keyed on the SignatureDef method name, and the value holds
      dictionary of inputs and outputs.
    rD   rC   )rE   GetSignatureDefsr0   listkeys)r   Zfull_signature_defs_signature_defr	   r	   r   rs     s
    
zInterpreter.get_signature_listc                 C   s
   | j  S )a  Gets list of SignatureDefs in the model.

    Example,
    ```
    signatures = interpreter._get_full_signature_list()
    print(signatures)

    # {
    #   'add': {'inputs': {'x': 1, 'y': 0}, 'outputs': {'output_0': 4}}
    # }

    Then using the names in the signature list you can get a callable from
    get_signature_runner().
    ```

    Returns:
      A list of SignatureDef details in a dictionary structure.
      It is keyed on the SignatureDef method name, and the value holds
      dictionary of inputs and outputs.
    )rE   r   r   r	   r	   r   rG     s    z$Interpreter._get_full_signature_listc                 C   sF   |du r:t | jdkr,tdt | jntt| j}t| |dS )a  Gets callable for inference of specific SignatureDef.

    Example usage,
    ```
    interpreter = tf.lite.Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()
    fn = interpreter.get_signature_runner('div_with_remainder')
    output = fn(x=np.array([3]), y=np.array([2]))
    print(output)
    # {
    #   'quotient': array([1.], dtype=float32)
    #   'remainder': array([1.], dtype=float32)
    # }
    ```

    None can be passed for signature_key if the model has a single Signature
    only.

    All names used are this specific SignatureDef names.


    Args:
      signature_key: Signature key for the SignatureDef, it can be None if and
        only if the model has a single SignatureDef. Default value is None.

    Returns:
      This returns a callable that can run inference for SignatureDef defined
      by argument 'signature_key'.
      The callable will take key arguments corresponding to the arguments of the
      SignatureDef, that should have numpy values.
      The callable will returns dictionary that maps from output names to numpy
      values of the computed results.

    Raises:
      ValueError: If passed signature_key is invalid.
    Nra   zwSignatureDef signature_key is None and model has {0} Signatures. None is only allowed when the model has 1 SignatureDef)r   rK   )r.   rt   r4   r>   nextiterrB   )r   rK   r	   r	   r   get_signature_runner  s    %z Interpreter.get_signature_runnerr   c                 C   s   | j ||S )a  Gets the value of the output tensor (get a copy).

    If you wish to avoid the copy, use `tensor()`. This function cannot be used
    to read intermediate results.

    Args:
      tensor_index: Tensor index of tensor to get. This value can be gotten from
        the 'index' field in get_output_details.
      subgraph_index: Index of the subgraph to fetch the tensor. Default value
        is 0, which means to fetch from the primary subgraph.

    Returns:
      a numpy array.
    )rE   rV   )r   r]   subgraph_indexr	   r	   r   
get_tensorE  s    zInterpreter.get_tensorc                    s    fddS )al  Returns function that gives a numpy view of the current tensor buffer.

    This allows reading and writing to this tensors w/o copies. This more
    closely mirrors the C++ Interpreter class interface's tensor() member, hence
    the name. Be careful to not hold these output references through calls
    to `allocate_tensors()` and `invoke()`. This function cannot be used to read
    intermediate results.

    Usage:

    ```
    interpreter.allocate_tensors()
    input = interpreter.tensor(interpreter.get_input_details()[0]["index"])
    output = interpreter.tensor(interpreter.get_output_details()[0]["index"])
    for i in range(10):
      input().fill(3.)
      interpreter.invoke()
      print("inference %s" % output())
    ```

    Notice how this function avoids making a numpy array directly. This is
    because it is important to not hold actual numpy views to the data longer
    than necessary. If you do, then the interpreter can no longer be invoked,
    because it is possible the interpreter would resize and invalidate the
    referenced tensors. The NumPy API doesn't allow any mutability of the
    the underlying buffers.

    WRONG:

    ```
    input = interpreter.tensor(interpreter.get_input_details()[0]["index"])()
    output = interpreter.tensor(interpreter.get_output_details()[0]["index"])()
    interpreter.allocate_tensors()  # This will throw RuntimeError
    for i in range(10):
      input.fill(3.)
      interpreter.invoke()  # this will throw RuntimeError since input,output
    ```

    Args:
      tensor_index: Tensor index of tensor to get. This value can be gotten from
        the 'index' field in get_output_details.

    Returns:
      A function that can return a new numpy array pointing to the internal
      TFLite tensor state at any point. It is safe to hold the function forever,
      but it is not safe to hold the numpy array forever.
    c                      s    j  j S r   )rE   tensorr	   r   r]   r	   r   r     r   z$Interpreter.tensor.<locals>.<lambda>r	   r   r	   r   r   r   V  s    0zInterpreter.tensorc                 C   s   |    | j  dS )a  Invoke the interpreter.

    Be sure to set the input sizes, allocate tensors and fill values before
    calling this. Also, note that this function releases the GIL so heavy
    computation can be done in the background while the Python interpreter
    continues. No other function on this object should be called while the
    invoke() call has not finished.

    Raises:
      ValueError: When the underlying interpreter fails raise ValueError.
    N)r|   rE   rU   r   r	   r	   r   invoke  s    zInterpreter.invokec                 C   s
   | j  S r   )rE   ZResetVariableTensorsr   r	   r	   r   reset_all_variables  s    zInterpreter.reset_all_variablesc                 C   s
   | j  S )a}  Returns a pointer to the underlying tflite::Interpreter instance.

    This allows extending tflite.Interpreter's functionality in a custom C++
    function. Consider how that may work in a custom pybind wrapper:

      m.def("SomeNewFeature", ([](py::object handle) {
        auto* interpreter =
          reinterpret_cast<tflite::Interpreter*>(handle.cast<intptr_t>());
        ...
      }))

    and corresponding Python call:

      SomeNewFeature(interpreter.native_handle())

    Note: This approach is fragile. Users must guarantee the C++ extension build
    is consistent with the tflite.Interpreter's underlying C++ build.
    )rE   r   r   r	   r	   r   _native_handle  s    zInterpreter._native_handle)F)N)r   )r   r   r   r=   r`   rd   r   r;   r}   r   r|   r   r\   r   r   r^   r   r   r_   rs   rG   r   r   r   r   r   r   r	   r	   r	   r   rj   b  s8   )
j	7

/
2rj   c                       s"   e Zd ZdZd fdd	Z  ZS )InterpreterWithCustomOpsar  Interpreter interface for TensorFlow Lite Models that accepts custom ops.

  The interface provided by this class is experimental and therefore not exposed
  as part of the public API.

  Wraps the tf.lite.Interpreter class and adds the ability to load custom ops
  by providing the names of functions that take a pointer to a BuiltinOpResolver
  and add a custom op.
  Nc                    s$   |pg | _ tt| jf i | dS )a  Constructor.

    Args:
      custom_op_registerers: List of str (symbol names) or functions that take a
        pointer to a MutableOpResolver and register a custom op. When passing
        functions, use a pybind function that takes a uintptr_t that can be
        recast as a pointer to a MutableOpResolver.
      **kwargs: Additional arguments passed to Interpreter.

    Raises:
      ValueError: If the interpreter was unable to create.
    N)rk   superr   r   )r   Zcustom_op_registerersr   	__class__r	   r   r     s    
z!InterpreterWithCustomOps.__init__)N)r   r   r   r=   r   __classcell__r	   r	   r   r   r     s   
r   )N)"r=   r#   enumosr    r~   numpyrO   pathsplitext__file__endswithjoinZ*tensorflow.lite.python.interpreter_wrapperr   rF   tensorflow.lite.python.metricsr    tensorflow.python.util.tf_exportr   r   r   r   r2   r   rA   rB   uniqueEnumr`   rd   ri   rj   r   r	   r	   r	   r   <module>   s>   Y/     R