a
    Sic(C                    @   s  d Z ddlZddlZddlZddlZddlZddlZddlZddl	Z
ddlm  mZ ddlmZ ddlmZ ddlmZ ddlmZ ddlmZ ddlmZ dd	lmZ dd
lmZ ddlmZ dd ZG dd deej dZ!G dd de!Z"dd Z#dd Z$dd Z%G dd de!Z&dZ'da(dd Z)G dd de!Z*G d d! d!e!Z+dd#d$Z,dd&d'Z-dd(d)Z.d*d+ Z/dd-d.Z0d/d0 Z1d1d2 Z2d3d4 Z3dd5d6Z4d7d8 Z5dd:d;Z6d<d= Z7dd>d?Z8d@dA Z9dBdC Z:ddDdEZ;ddFdGZ<ddHdIZ=dJdK Z>ddLdMZ?ddOdPZ@ddQdRZAdSdT ZBddUdVZCdWdX ZDddYdZZEd[d\ ZFd]d^ ZGd_d` ZHddadbZIdcdd ZJdedf ZKdgdh ZLdidj ZMdkdl ZNdmdn ZOdodp ZPdqdr ZQdsdt ZRddvdwZSG dxdy dyZTdzd{ ZUd|d} ZVd~d ZWdddZXG dd dZYdS )zTraining-related utilities.    N)backend)	callbacks)losses)metrics)
data_utils)generic_utils)losses_utils)
tf_inspect)
tf_loggingc                 C   s"   t | tjjtjjjtjjjjfS )zCReturns true if 'tensor' is a CompositeTensor or a CT Value object.)	
isinstancetf__internal__CompositeTensorcompatv1SparseTensorValueraggedRaggedTensorValuetensor r   Z/var/www/html/django/DPS/env/lib/python3.9/site-packages/keras/engine/training_utils_v1.pyis_composite_or_composite_value)   s    
r   c                   @   sF   e Zd ZdZdddZejdd ZejdddZejd	d
 Z	dS )
Aggregatora  Abstract base class used to aggregate batch-level outputs of a loop.

    Attributes:
      use_steps: Whether the loop is using `step` or `batch_size`.
      num_samples: Total number of samples: `batch_size * num_batches`.
      steps: Total number of steps.
      batch_size: Batch size. It is used for validation checks between inputs
        and outputs.
      results: What to return at the end of the aggregation loop.
    Nc                 C   s"   || _ || _|| _|| _g | _d S N)	use_stepsnum_samplessteps
batch_sizeresults)selfr   r   r   r   r   r   r   __init__C   s
    zAggregator.__init__c                 C   s   t ddS )zCreates the initial results from the first batch outputs.

        Args:
          batch_outs: A list of batch-level outputs.
        "Must be implemented in subclasses.NNotImplementedErrorr    
batch_outsr   r   r   createL   s    zAggregator.createc                 C   s   t ddS )aO  Aggregates batch-level results into total results.

        Args:
          batch_outs: A list of batch-level outputs.
          batch_start: The start index of this batch. Always `None` if
            `use_steps` is `True`.
          batch_end: The end index of this batch. Always `None` if `use_steps`
            is `True`.
        r"   Nr#   r    r&   batch_start	batch_endr   r   r   	aggregateU   s    zAggregator.aggregatec                 C   s   t ddS )z*Prepares the total results to be returned.r"   Nr#   r    r   r   r   finalizeb   s    zAggregator.finalize)NNN)NN)
__name__
__module____qualname____doc__r!   abcabstractmethodr'   r+   r-   r   r   r   r   r   7   s    
	
r   )	metaclassc                       s<   e Zd ZdZd fdd	Zdd ZdddZd	d
 Z  ZS )MetricsAggregatora?  Aggregator that calculates loss and metrics info.

    Attributes:
      use_steps: Whether the loop is using `step` or `batch_size`.
      num_samples: Total number of samples: `batch_size*num_batches`.
      steps: Total number of steps, ie number of times to iterate over a dataset
        to cover all samples.
    Nc                    s   t  j|||d d d S )Nr   r   r   r   )superr!   )r    r   r   r   	__class__r   r   r!   r   s    zMetricsAggregator.__init__c                 C   s   dgt | | _d S )N        )lenr   r%   r   r   r   r'   z   s    zMetricsAggregator.createc                 C   sV   | j r| jd  |d 7  < n| jd  |d ||  7  < |dd  | jdd < d S )Nr      )r   r   r(   r   r   r   r+   }   s    zMetricsAggregator.aggregatec                 C   s,   | j std| j d  | jp"| j  < d S )NzEmpty training data.r   )r   
ValueErrorr   r   r,   r   r   r   r-      s    zMetricsAggregator.finalize)NN)NN	r.   r/   r0   r1   r!   r'   r+   r-   __classcell__r   r   r8   r   r5   h   s
   	
	r5   c                 C   s  t | jt |jkr,td| || j|jf | jdd |jdd krptd| || jdd |jdd f | jd }| jd }| j}|jD ]4}|d  |7  < t||d }tj||gdd}qtj| j|jfdd}t	| j}|d |d< t
|}tjjj|||dS )z#Append sparse tensor value objects.zlUnable to concatenate %s and %s. The inner dense shapes do not have the same number of dimensions (%s vs %s)r<   Nz`Unable to concatenate %s and %s. The inner dense shapes do not match inner dimensions (%s vs %s)r   axis)indicesvaluesdense_shape)r;   rD   RuntimeErrorrB   maxnpappendconcatenaterC   listtupler   r   r   r   )target	to_appendbase_dim0_valuemax_dim0_valuenew_indicesindex
new_valuesnew_dense_shaper   r   r   _append_sparse_tensor_value   s<    



rT   c                 C   s   t | jt |jkr$td| |f | jdd |jdd krPtd| |f |jdd | jd  }t| j|}t| jtj	j
jjrt| j|j}ntj| j|jfdd}tj	j
j||S )z#Append ragged tensor value objects.Unable to concatenate %s and %sr<   Nr   r@   )r;   shaperE   
row_splitsrG   rH   r   rC   r   r   r   r   r   _append_ragged_tensor_valuerI   )rL   rM   adjusted_row_splitsnew_row_splitsrR   r   r   r   rY      s     

rY   c                 C   s   t | t |ur(tdt | t |f t| tjrJtjjj| |gddS t| tjrhtj	| |gddS t| tjjj
rt| |S t| tjjjjrt| |S tdt |  dS )a  Helper function to append composite tensors to each other in the 0 axis.

    In order to support batching within a fit/evaluate/predict call, we need
    to be able to aggregate within a CompositeTensor. Unfortunately, the CT
    API currently does not make this easy - especially in V1 mode, where we're
    working with CompositeTensor Value objects that have no connection with the
    CompositeTensors that created them.

    Args:
      target: CompositeTensor or CompositeTensor value object that will be
        appended to.
      to_append: CompositeTensor or CompositeTensor value object to append to.
        'target'.

    Returns:
      A CompositeTensor or CompositeTensor value object.

    Raises:
      RuntimeError: if concatenation is not possible.
    rU   r   )	sp_inputsrA   r@   z/Attempted to concatenate unsupported object %s.N)typerE   r   r   SparseTensorr   r   sparse_concatRaggedTensorconcatr   rT   r   r   rY   )rL   rM   r   r   r   _append_composite_tensor   s    


rb   c                       s:   e Zd ZdZ fddZdd ZdddZd	d
 Z  ZS )ConcatAggregatorzCombine tensor-likes which cannot be merged on the fly.

    This class expects to aggregate a single tensor-like rather than a nested
    structure of tensor-likes.
    c                    s   d | _ t jdd d |d d S )NTr6   )	compositer7   r!   )r    r   r8   r   r   r!     s    zConcatAggregator.__init__c                 C   s   t || _d S r   )r   rd   )r    batch_elementr   r   r   r'     s    zConcatAggregator.createNc                 C   sJ   | j r:| j |jd k r:td|j| j f|jdd   | j| d S )Nr   uMismatch between expected batch size and model output batch size. Output shape = {}, expected output shape = shape {}r<   )r   rW   r=   formatr   rH   )r    re   r)   r*   r   r   r   r+     s    zConcatAggregator.aggregatec                 C   sh   t | jdkr| jd | _nH| jrR| jd }| jdd  D ]}t||}q:|| _ntj| jdd| _d S )Nr<   r   r@   )r;   r   rd   rb   rG   rI   )r    r   rr   r   r   r-   -  s    
zConcatAggregator.finalize)NNr>   r   r   r8   r   rc     s
   
rc      c                   C   s$   t du r tjta tt j t S )zShared threadpool for copying arrays.

    Pool instantiation takes ~ 2ms, so a singleton pool is used rather than
    creating a pool per SliceAggregator.

    Returns:
      The global copy threadpool.
    N)
_COPY_POOLmultiprocessingpool
ThreadPool_COPY_THREADSatexitregistercloser   r   r   r   get_copy_poolA  s    
rr   c                       sH   e Zd ZdZdZdZ fddZdd Zdd	 Zd
d Z	dd Z
  ZS )SliceAggregatora  Combine arrays where the final size is known.

    This class expects to aggregate a single tensor-like rather than a nested
    structure of tensor-likes.

    NumPy copies are an operation that threads handle quite well because all of
    the heavy lifting is in c and does not need the GIL. Moreover, we can
    perform lock-free writes to the same buffer in multiple threads because the
    nature of result aggregation guarantees that either the indices are disjoint
    or the aggregator will throw an exception in finalize. Moreover, because
    aggregation is performed on the slowest varying dimension, assignments for a
    given batch will write to contiguous blocks of memory, further minimizing
    contention.

    There is, however, some scheduling and context switching overhead which will
    offset the gains from pipelining the slice assignment. Below a given
    threshold it is faster to simply assign in the main thread rather than
    enqueue the assignment in a side thread. The exact threshold will vary from
    system to system, but the time is not very sensitive to the exact transition
    so a value of 2 ** 14 was chosen which should be reasonable on most systems.
    i @  i,  c                    s,   g | _ t | _g | _t jd|d |d d S )NFr6   )_async_copiesrr   _pool_errorsr7   r!   )r    r   r   r8   r   r   r!   k  s    zSliceAggregator.__init__c                 C   s0   | j f|jdd   }|j}tj||d| _d S )Nr<   )rW   dtype)r   rW   rw   rG   emptyr   )r    re   rW   rw   r   r   r   r'   v  s    zSliceAggregator.createc                 C   s   | j r| j d || | jkrN| j|jd krDtd|j| jj|| _d S t|j}|| jk rt|| j||< n.t	
 }| jj| j||||fd | j| d S )Nr   rf   )args)rv   r   rW   r=   rg   r   rG   prod_BINARY_SIZE_THRESHOLD	threadingEventru   apply_async_slice_assignrt   rH   )r    re   r)   r*   num_elementsis_finishedr   r   r   r+   ~  s(    



zSliceAggregator.aggregatec              
   C   s`   zPz|| j ||< W n0 tyD } z| j| W Y d}~n
d}~0 0 W |  n
|  0 dS )z,Legacy utility method to slice input arrays.N)r   	Exceptionrv   rH   set)r    re   r)   r*   r   er   r   r   r     s
    $zSliceAggregator._slice_assignc                 C   sT   t   }| jD ]0}td| jt   |  g}||stdq| jrP| jd d S )Nr:   z'Timed out waiting for copy to complete.r   )timert   rF   _MAX_COPY_SECONDSwaitr=   rv   )r    
start_timer   timeoutr   r   r   r-     s    


zSliceAggregator.finalize)r.   r/   r0   r1   r{   r   r!   r'   r+   r   r-   r?   r   r   r8   r   rs   Q  s    rs   c                   @   s.   e Zd ZdZdZdd Zd	ddZdd ZdS )
OutputsAggregatorz%Aggregator that concatenates outputs.Nc                 C   s   t jjdd || _t jj| j|}|D ]r}t|rL| jt	| j
 nBt|tjr| j| jrnt	| j
nt| j| j
 ntd|| jd | q,d S )Nc                 S   s
   t |  S r   )r   xr   r   r   <lambda>      z*OutputsAggregator.create.<locals>.<lambda>z-Attempted to aggregate unsupported object {}.rV   )r   r   nestget_traverse_shallow_structure
_structureflatten_up_tor   r   rH   rc   r   r   rG   ndarrayr   rs   r   rE   rg   r'   )r    r&   re   r   r   r   r'     s*    zOutputsAggregator.createc                 C   s:   t jj| j|}t|| jD ]\}}|||| qd S r   )r   r   r   r   r   zipr   r+   )r    r&   r)   r*   re   resultr   r   r   r+     s
    zOutputsAggregator.aggregatec                 C   s>   | j D ]}|  qdd | j D | _ tj| j| j | _ d S )Nc                 S   s   g | ]
}|j qS r   )r   .0ir   r   r   
<listcomp>  r   z.OutputsAggregator.finalize.<locals>.<listcomp>)r   r-   r   r   pack_sequence_asr   )r    r   r   r   r   r-     s    

zOutputsAggregator.finalize)NN)r.   r/   r0   r1   r   r'   r+   r-   r   r   r   r   r     s
   %
r   Tc                 C   s4   |r"t | dd}|r&|dd }nd}tj||dS )zGet Progbar.metrics_namesNr<   )stateful_metrics)getattrcbksProgbarLogger)model
count_modeinclude_metricsstateful_metric_namesr   r   r   get_progbar  s    r   r   c                 C   sT   |dur |dur t d| d t| ||r0dS t| d drPt| d jd S dS )a!  Determine the number of samples provided for training and evaluation.

    The number of samples is not defined when running with `steps`,
    in which case the number of samples is set to `None`.

    Args:
        ins: List of tensors to be fed to the Keras function.
        batch_size: Integer batch size or `None` if not defined.
        steps: Total number of steps (batches of samples) before declaring
          `_predict_loop` finished. Ignored with the default value of `None`.
        steps_name: The public API's parameter name for `steps`.

    Raises:
        ValueError: when `steps` is `None` and the attribute `ins.shape`
        does not exist. Also raises ValueError when `steps` is not `None`
        and `batch_size` is not `None` because they are mutually
        exclusive.

    Returns:
        When steps is `None`, returns the number of samples to be
        processed based on the size of the first dimension of the
        first input numpy array. When steps is not `None` and
        `batch_size` is `None`, returns `None`.
    NzIf z' is set, the `batch_size` must be None.r   rW   )r=   check_steps_argumenthasattrintrW   )insr   r   
steps_namer   r   r   check_num_samples  s    
r   c                 C   s   | du rdS t | r| S t| tr0td| | jdurt| jdkr|du s\t|dkrt| rztj	j
j| dd} nt| d} | S )zCExpand data of shape (x,) to (x, 1), unless len(expected_shape)==1.Nz7Expected an array data type but received an integer: {}r<   r@   )r   r   r   r=   rg   rW   r;   r   	is_tensorr   r   expand_dimsrG   )r   expected_shaper   r   r   standardize_single_array$  s(    


r   c                 C   s    t | tjjjr| jS | jS dS )z1Returns the shape of the passed composite tensor.N)r   r   r   r   r   rD   rW   r   r   r   r   get_composite_shape=  s    r    c              
      s  zt  }W n ty"   d}Y n0 |sL|rHt tsHtd| d  g S  du rjdd tt |D S t trz fdd|D  W nB ty } z*td|jd  d	 t| W Y d}~n
d}~0 0 nt t	t
frDt d t	t
frd
d  D  n>t |dkr4t d ttfr4t g ndd  D  n jjdkrX jn   g |durdd t |D  ndd  D  t  t |kr rt d drtd| d tt | d d t| d tt   d t dd  d nt |dkrNtd| d tt | d t dd  n^t  dkrt d dstd| d t dd  d nt |dkrt g |r
tt |D ]H}|| durt | r | j}|sqt
| }	n2t | r,t | }t
| }	n
 | j}	|| }
t |	t |
krtd| d ||  d tt |
 d t|	 |s|	dd }	|
dd }
t|	|
D ]X\}}||kr|dur|durtd| d ||  d t|
 d  t|	 qq S )!ac  Normalizes inputs and targets provided by users.

    Users may pass data as a list of arrays, dictionary of arrays,
    or as a single array. We normalize this to an ordered list of
    arrays (same order as `names`), while checking that the provided
    arrays have shapes that match the network's expectations.

    Args:
        data: User-provided input data (polymorphic).
        names: List of expected array names.
        shapes: Optional list of expected array shapes.
        check_batch_axis: Boolean; whether to check that the batch axis of the
          arrays matches the expected value found in `shapes`.
        exception_prefix: String prefix used for exception formatting.

    Returns:
        List of standardized input arrays (one array per model input).

    Raises:
        ValueError: in case of improperly formatted user-provided data.
    NzError when checking model z: expected no data, but got:c                 S   s   g | ]}d qS r   r   r   _r   r   r   r   m  r   z*standardize_input_data.<locals>.<listcomp>c                    s.   g | ]&} | j jd kr" | jn | qS 	DataFramer9   r.   rC   r   r   datar   r   r   q  s   zNo data provided for "r   z". Need data for each key in: c                 S   s   g | ]}t |qS r   )rG   asarray)r   dr   r   r   r   ~  r   r<   c                 S   s"   g | ]}|j jd kr|jn|qS r   r   r   r   r   r   r     s   r   c                 S   s   g | ]\}}t ||qS r   r   )r   r   rW   r   r   r   r     s   c                 S   s   g | ]}t |qS r   r   r   r   r   r   r     r   rW   zr: the list of Numpy arrays that you are passing to your model is not the size the model expected. Expected to see z array(s), zfor inputs z' but instead got the following list of z	 arrays:    z...zQ: you are passing a list as input to your model, but the model expects a list of z0 Numpy arrays instead. The list you passed was: zE: data should be a Numpy array, or list/dict of Numpy arrays. Found: zError when checking z: expected z	 to have z& dimensions, but got array with shape z to have shape z but got array with shape )r;   	TypeErrorr   dictr=   rangeKeyErrorry   strrJ   rK   floatr   rG   r   r9   r.   rC   r   r   r   r   rW   as_listr   r   )r   namesshapescheck_batch_axisexception_prefixdata_lenr   r   tensorshape
data_shaperW   dimref_dimr   r   r   standardize_input_dataF  s2   




"

	

	




r   c                 C   sD  | du s"t | ttfr0t| dkr0dd |D S t|dkrt | ttfrZt| dkrZ| S t | tr~|d | v r~| |d  gS | gS t | ttfrt| t|krtd| d tt|  d tt| d	 | d
 | S t | tjj	r t
|| | g }|D ]}|| | q|S td| d | d t|  dS )a  Maps `sample_weight` or `class_weight` to model outputs.

    Args:
        x_weight: User-provided `sample_weight` or `class_weight` argument.
        output_names: List of output names (strings) in the model.
        weight_type: A string used purely for exception printing.

    Returns:
        A list of `sample_weight` or `class_weight` where there are exactly
            one element per model output.

    Raises:
        ValueError: In case of invalid user-provided argument.
    Nr   c                 S   s   g | ]}d qS r   r   r   r   r   r   r     r   z7standardize_sample_or_class_weights.<locals>.<listcomp>r<   z
Provided `z` was a list of z elements, but the model has z" outputs. You should provide one `z`array per model output.z$The model has multiple outputs, so `z/` should be either a list or a dict. Provided `z` type not understood: )r   rJ   rK   r;   r   r=   r   collectionsr2   Mappingr   check_for_unexpected_keysrH   getr   )x_weightoutput_namesweight_type	x_weightsnamer   r   r   #standardize_sample_or_class_weights  sf    



r   c                 C   s   t | |dS )Nclass_weightr   )r   r   r   r   r   standardize_class_weights%  s    r   c                 C   s   t | |dS )Nsample_weightr   )r   r   r   r   r   standardize_sample_weights+  s    r   c                    sH  dd   fdd}|| }||}||}t |dkrRtdtdd | D  t |dkrxtd	td
d |D  |r|rt|d t|d krtdtt|d  d tt|d  d t |dkrtdtdd |D  |rD|rDt|d t|d krDtdtt|d  d tt|d  d dS )a  Does user input validation for numpy arrays.

    Args:
        inputs: list of Numpy arrays of inputs.
        targets: list of Numpy arrays of targets.
        weights: list of Numpy arrays of sample weights.

    Raises:
        ValueError: in case of incorrectly formatted data.
    c                 S   s   t | pt| S r   )r   r   r   r   r   r   r   is_tensor_or_composite_tensor=  s    z:check_array_lengths.<locals>.is_tensor_or_composite_tensorc                    s&   | d u ri S t  fdd| D S d S )Nc                    s&   g | ]}|d ur |s|j d qS )Nr   rW   r   yr   r   r   r   G  s   z?check_array_lengths.<locals>.set_of_lengths.<locals>.<listcomp>)r   r   r   r   r   set_of_lengths@  s    
z+check_array_lengths.<locals>.set_of_lengthsr<   zOAll input arrays (x) should have the same number of samples. Got array shapes: c                 S   s   g | ]
}|j qS r   r   r   r   r   r   r   U  r   z'check_array_lengths.<locals>.<listcomp>zPAll target arrays (y) should have the same number of samples. Got array shapes: c                 S   s   g | ]
}|j qS r   r   r   r   r   r   r   [  r   r   zLInput arrays should have the same number of samples as target arrays. Found z input samples and z target samples.zSAll sample_weight arrays should have the same number of samples. Got array shapes: c                 S   s   g | ]
}|j qS r   r   )r   wr   r   r   r   h  r   zRSample_weight arrays should have the same number of samples as target arrays. Got N)r;   r=   r   rJ   )inputstargetsweightsr   set_xset_yset_wr   r   r   check_array_lengths1  s^     &r   c                 C   s4  t jt jt jh}t jt jt jf}t| ||D ] \}}}|du s,|du s,t	|rTq,t 
|r|jd dkrtdt|j d t|t j}t||s|r,|j|v r,t|jdd |dd D ]j\}	}
|
dur|	|
kr|j}|du r|r|jnt|}|j}tdt|j d t| d | d	 qq,dS )
a  Does validation on the compatibility of targets and loss functions.

    This helps prevent users from using loss functions incorrectly. This check
    is purely for UX purposes.

    Args:
        targets: list of Numpy arrays of targets.
        loss_fns: list of loss functions.
        output_shapes: list of shapes of model outputs.

    Raises:
        ValueError: if a loss function or target array
            is incompatible with an output.
    NrV   r<   z(You are passing a target array of shape a   while using as loss `categorical_crossentropy`. `categorical_crossentropy` expects targets to be binary matrices (1s and 0s) of shape (samples, classes). If your targets are integer classes, you can convert them to the expected format via:
```
from keras.utils import to_categorical
y_binary = to_categorical(y_int)
```

Alternatively, you can use the loss function `sparse_categorical_crossentropy` instead, which does expect integer targets.zA target array with shape z# was passed for an output of shape z while using as loss `zB`. This loss expects targets to have the same shape as the output.)r   mean_squared_errorbinary_crossentropycategorical_crossentropyMeanSquaredErrorBinaryCrossentropyCategoricalCrossentropyr   r   r   is_categorical_crossentropyrW   r=   r   r   LossFunctionWrapperfnr   r]   r.   )r   loss_fnsoutput_shapeskey_loss_fnskey_loss_classesr   lossrW   is_loss_wrapper
target_dimout_dim	loss_name	loss_typer   r   r   #check_loss_and_target_compatibilityu  s^    

$
r   Fc                 C   s  | sdd |D S t | trtdd | D }|rrt| t|krbtdtt| d t|  dd | D }n6t|dkrg }|D ]}|d	d | D  qn| g}nXt | tjj	rt
d
| | g }|D ] }	t
| |	g }
||
 qntdt|  g }t|D ]x\}} t }| D ]V}t||}t||| || d}||_t |tjsntj||d}d|_|||< q"|| q|S )aq  Maps metric names and functions to model outputs.

    Args:
        metrics: a list or a list of lists or a dict of metric functions.
        output_names: a list of the names (strings) of model outputs.
        output_shapes: a list of the shapes (strings) of model outputs.
        loss_fns: a list of the loss functions corresponding to the model
          outputs.
        from_serialized: whether the model the metrics are being sourced from is
          being initialized from a serialized format.
        is_weighted: Boolean indicating whether the given metrics are weighted.

    Returns:
        A list (one entry per model output) of dicts.
        For instance, if the model has 2 outputs, and for the first output
        we want to compute "binary_accuracy" and "binary_crossentropy",
        and just "binary_accuracy" for the second output,
        the list would look like: `[{
            'acc': binary_accuracy(),
            'ce': binary_crossentropy(),
          }, {
            'acc': binary_accuracy(),
          }]`

    Raises:
        TypeError: if an incorrect type is passed for the `metrics` argument.
    c                 S   s   g | ]}i qS r   r   r   r   r   r   r     r   z2collect_per_output_metric_info.<locals>.<listcomp>c                 s   s   | ]}t |tV  qd S r   )r   rJ   r   mr   r   r   	<genexpr>  r   z1collect_per_output_metric_info.<locals>.<genexpr>zdWhen passing a list of lists as `metrics`, it should have one entry per model output. The model has z! outputs, but you passed metrics=c                 S   s   g | ]}t |qS r   )r   to_listr   r   r   r   r     r   r<   c                 S   s   g | ]}t |qS r   )metrics_moduleclone_metricr   r   r   r   r     r   r   zQType of `metrics` argument not understood. Expected a list or dictionary, found: )output_shapeloss_fn)r   F)r   rJ   anyr;   r=   r   rH   r   r2   r   r   r   r  r   r   	enumerateOrderedDictget_metric_nameget_metric_function_from_serializedr  MetricMeanMetricWrapper)r   r   r   r   from_serializedis_weightedany_sub_listnested_metricsr   r   output_metricsper_output_metricsr   metrics_dictmetricmetric_name	metric_fnr   r   r   collect_per_output_metric_info  sj    #

	
r  c                 C   s^   t t| | }| || d }| d||  } | ||f} tj|  |  } t| |S )a5  Shuffles an array in a batch-wise fashion.

    Useful for shuffling HDF5 arrays
    (where one cannot access arbitrary indices).

    Args:
        index_array: array of indices to be shuffled.
        batch_size: integer.

    Returns:
        The `index_array` array, shuffled in a batch-wise fashion.
    N)r   r;   reshaperG   randomshuffleflattenrH   )index_arrayr   batch_count
last_batchr   r   r   batch_shuffle#  s    r!  c                    s:  t |tr|d }|dur|dkr|dkr:tdt| tjdk r^tdtj d |durt|jd	krtd
t|j d n(|durt|jdkrtd|j||dur>t|jtjkrtdt|j d ttj t|s>jd|j	 |jkr>td
t|j d tj d d}t  t
rtjd	krftdtrZtt  }t fdd|D }tt|d }tj|dd< |||< tjjtj d	kotd dkfddfdd}tjj||}tj|d t|t }|durtt|t }n}tjd	krjd dkrtjdd}n"jd dkrtjd }t  fdd|D }t|t|krt!|}	t!  }
td|	|
  |dur|dur|| S |dur(|S |dur6|S dS )aY  Performs sample weight validation and standardization.

    Everything gets normalized to a single sample-wise (or timestep-wise)
    weight array. If both `sample_weight` and `class_weight` are provided,
    the weights are multiplied.

    Args:
        y: Numpy array or Tensor of model targets to be weighted.
        sample_weight: User-provided `sample_weight` argument.
        class_weight: User-provided `class_weight` argument.
        sample_weight_mode: One of `None` or `"temporal"`. `"temporal"`
          indicated that we expect 2D weight data that will be applied to the
          last 2 dimensions of the targets (i.e. we are weighting timesteps, not
          samples).

    Returns:
        A numpy array of target weights, one entry per sample to weight.

    Raises:
        ValueError: In case of invalid user-provided arguments.
    r   N
samplewisetemporalz9"sample_weight_mode should be None or "temporal". Found:    z4Found a sample_weight array for an input with shape z. Timestep-wise sample weighting (use of sample_weight_mode="temporal") is restricted to outputs that are at least 3D, i.e. that have a time dimension.   z'Found a sample_weight array with shape z[. In order to use timestep-wise sample weighting, you should pass a 2D sample_weight array.r<   a  Found a sample_weight array with shape {}. In order to use timestep-wise sample weights, you should specify sample_weight_mode="temporal" in compile(); founssd "{}" instead. If you just mean to use sample-wise weights, make sure your sample_weight array is 1D.z Found a sample_weight with shapez8.Expected sample_weight with rank less than or equal to z for an input with shape z$. sample_weight cannot be broadcast.z8`class_weight` not supported for 3+ dimensional targets.c                    s   g | ]} | qS r   r   r   r   r   r   r     r   z'standardize_weights.<locals>.<listcomp>c                      s   t j ddS )Nr<   r@   )r   argmaxr   r   r   r   r     r   z%standardize_weights.<locals>.<lambda>c                      s   t t dt jS )N)rV   )r   castr   r  int64r   r(  r   r   r     r   zxInvalid classes or class weights detected. NaN values indicate that an appropriate class weight could not be determined.r@   c                    s   g | ]}| v r | qS r   r   )r   clsr&  r   r   r     r   zp`class_weight` must contain all classes in the data. The classes %s exist in the data but not in `class_weight`.)"r   rK   r=   r   r;   rW   rg   r   r   ndimr   rG   arraysortedkeyszerosrF   nanr   
smart_condr   r   r   r   gather	debuggingcheck_numericsr)  floatxconvert_to_tensorr'  r  r   r   )r   r   r   sample_weight_modeclass_sample_weightr/  rC   weight_vector	y_classesexisting_classesexisting_class_weightr   )r   r   r   standardize_weights;  s    


	
$




r>  c                 C   s   t  rdS t| S )NF)r   executing_eagerlyhas_tensorslsr   r   r   has_symbolic_tensors  s    rC  c                 C   sX   t | ttfr tdd | D S t | tr@tdd |  D S t| oVt | tj S )z&Returns true if `ls` contains tensors.c                 s   s&   | ]}t |ot|t j V  qd S r   r   r   r   r`   r   vr   r   r   r    s   zhas_tensors.<locals>.<genexpr>c                 s   s*   | ]"\}}t |o t|t j V  qd S r   rD  )r   r   rF  r   r   r   r    s   )	r   rJ   rK   r  r   itemsr   r   r`   rA  r   r   r   r@    s    
r@  c                 C   s   t jj r:t| tr| S t| } t| dr4| j	S | j
S |rBdnd}| dv rj| dv r\d}q| dv rd}n"t| }t|dr|j	}n|j
}|| }|S d	S )
zReturns the name corresponding to the given metric input.

    Args:
      metric: Metric function name or reference.
      weighted: Boolean indicating if the given metric is weighted.

    Returns:
        The metric name.
    r   	weighted_r   accuracyacccrossentropycerJ  rK  rK  )rL  rM  rM  N)r   r   tf2enabledr   r   r  r   r   r   r.   )r  weightedmetric_name_prefixsuffixr  r  r   r   r   r
    s"    




r
  c                 C   s   | dvrt | S t|tjp4t|tjo4|jtjk}t|tjpXt|tjoX|jtj	k}| dv r|d dksr|rxt j
S |rt jS t jS |d dks|rt j	S |rt jS t jS dS )a<  Returns the metric function corresponding to the given metric input.

    Args:
        metric: Metric function name or reference.
        output_shape: The shape of the output that this metric will be
          calculated for.
        loss_fn: The loss function used.

    Returns:
        The metric function.
    rI  rN  rV   r<   N)r  r   r   r   SparseCategoricalCrossentropyr   r   sparse_categorical_crossentropyr   r   binary_accuracysparse_categorical_accuracycategorical_accuracyr   )r  r  r  "is_sparse_categorical_crossentropyis_binary_crossentropyr   r   r   r    s.    


r  c                 C   sr   |durPt ||j}|du r$|}n,t j||jd}tj||d\}}}||9 }|durf| |||dS | ||dS )z=Invokes metric function and returns the metric result tensor.Nrw   )r   )r   r)  rw   r   squeeze_or_expand_dimensions)r  y_truey_predr   maskr   r   r   r   call_metric_function:  s    r`  c                 C   s   | du st | tjr| S t| r<t| tjr<td| t | tj	j
rTt| } t| rjt| dsj| S t| }tj||jtjjdS )zBReturns the loss corresponding to the loss input in `compile` API.NzeReceived uninstantiated Loss class: {}
Please call loss classes before passing them to Model.compile.r.   )r   	reduction)r   r   Lossr	   isclass
issubclassr=   rg   r   r2   r   r   callabler   r   r.   r   ReductionV2SUM_OVER_BATCH_SIZE)r   r  r   r   r   get_loss_functionQ  s$    

rh  c                 C   sT   |durt d| |f |dur0t d| |f |durP|dkrPt d| |f dS )a  Validates user input arguments when a dataset iterator is passed.

    Args:
      x: Input data. A `tf.data` dataset or iterator.
      y: Target data. It could be either Numpy array(s) or TensorFlow tensor(s).
        Expected to be `None` when `x` is a dataset iterator.
      sample_weight: An optional sample-weight array passed by the user to
        weight the importance of each sample in `x`. Expected to be `None` when
        `x` is a dataset iterator
      validation_split: Float between 0 and 1. Fraction of the training data to
        be used as validation data. Expected to be `None` when `x` is a dataset
        iterator.

    Raises:
      ValueError: if argument `y` or `sample_weight` or `validation_split` are
          provided by user.
    NzYou passed a dataset or dataset iterator (%s) as input `x` to your model. In that case, you should not specify a target (`y`) argument, since the dataset or dataset iterator generates both input data and target data. Received: %sz`sample_weight` argument is not supported when input `x` is a dataset or a dataset iterator. Instead, youcan provide sample_weight as the third element  of yourdataset, i.e. (inputs, targets, sample_weight). Received: x=%s, sample_weight=%sr:   z`validation_split` argument is not supported when input `x` is a dataset or a dataset iterator. Received: x=%s, validation_split=%fr=   )r   r   r   validation_splitr   r   r   validate_dataset_inputt  s$    rk  r   c                 C   s   t | ttfr:tdd | D s~td| dt| nDt | trX|s~td|n&t | tj	s~t
| s~td||dS )z5Helper function to validate either inputs or targets.c                 s   s$   | ]}t |tjpt|V  qd S r   )r   rG   r   r   r   rE  r   r   r   r    r   z'validate_input_types.<locals>.<genexpr>zVPlease provide as model inputs either a single array or a list of arrays. You passed: =z)You cannot pass a dictionary as model {}.z[Please provide as model inputs either a single array or a list of arrays. You passed: {}={}N)r   rJ   rK   allr=   r   r   rg   rG   r   r   r   )inporig_inp
allow_dict
field_namer   r   r   validate_input_types  s(    
rr  c                 C   s0   | durt d|dur t d|r,t ddS )z2Validates arguments passed when using a generator.Nz`y` argument is not supported when data isa generator or Sequence instance. Instead pass targets as the second element of the generator.z`sample_weight` argument is not supported when data isa generator or Sequence instance. Instead pass sample weights as the third element of the generator.zUIf your data is in the form of a Python generator, you cannot use `validation_split`.ri  )r   r   rj  r   r   r   check_generator_arguments  s    rs  c                    s   t | tjjjjtjjf}| du s<|s<t| s<t | trf| sf|du rb|rLdnd}tdj	||ddS t | tjjjj
tjj
frdS |durtjttf t |  st | trt fdd|  D rtd	 d
S )a/  Validates `steps` argument based on input data's type.

    The cases when `steps` value must be provided are when
      1. input data passed is an iterator.
      2. model was built on top of symbolic tensors, input data is not
         required and is `None`.
      3. input data passed is a symbolic tensor.

    Args:
        input_data: Input data. Can be Numpy array(s) or TensorFlow tensor(s) or
          tf.data.Dataset iterator or `None`.
        steps: Integer or `None`. Total number of steps (batches of samples) to
          execute.
        steps_name: The public API's parameter name for `steps`.

    Returns:
      boolean, True if `steps` argument is required, else False.

    Raises:
        ValueError: if `steps` argument is required for given input data type
          but not provided.
    Nza Dataset iteratorzdata tensorsz\When using {input_type} as input to a model, you should specify the `{steps_name}` argument.)
input_typer   Tc                 3   s   | ]}t | V  qd S r   )r   rE  
list_typesr   r   r    r   z'check_steps_argument.<locals>.<genexpr>zvWhen passing input data as arrays, do not specify `steps_per_epoch`/`steps` argument. Please use `batch_size` instead.F)r   r   r   r   r   IteratorrC  rJ   r=   rg   DatasetrG   r   rK   r   r  rC   loggingwarning)
input_datar   r   is_x_iteratorinput_type_strr   ru  r   r     sF    

r   c                 C   s<   t | tjrt| } |p t }| jjr8tj	| |dS | S )Nr[  )
r   rG   r   r   r7  r   r6  rw   is_floatingr)  r   rw   r   r   r   cast_single_tensor  s    
r  c                 C   sz   t | rt| |d jdS g }t| |D ]J\}}t|tjrHt |}|j|jkrj|	t||jd q*|	| q*|S )aG  Returns target data tensors using correct datatype.

    Checks that each target and output pair are the same datatype. If not, casts
    the target to the output's datatype.

    Args:
      targets: tensor or list of targets.
      outputs: tensor or list of outputs.

    Returns:
      Targets in appropriate datatype.
    r   r[  )
r   r   r  rw   r   r   rG   r   r7  rH   )r   outputsnew_targetsrL   outr   r   r   #cast_if_floating_dtype_and_mismatch  s    

r  c                 C   s   t jtjt|d| S )a  Casts the given data tensors to the default floating point type.

    Casts only if the input is already a floating point type.
    Args:
      x: tensor or list/tuple of tensors.
      dtype: The dtype to which Tensors should be cast.

    Returns:
      Converted input.
    r[  )r   r   map_structure	functoolspartialr  r  r   r   r   cast_if_floating_dtype+  s    r  c                 C   s&   t jdd |j}t jt j| |S )a   Casts the given data tensors to the dtypes of the model inputs.

    Args:
      x: tensor or list/tuple of tensors.
      model: The model.

    Returns:
      Converted input. Each tensor is casted to the corresponding input in
      `model.inputs`.
    c                 S   s   | j S r   r[  )tr   r   r   r   F  r   z,cast_to_model_input_dtypes.<locals>.<lambda>)r   r   r  r   r)  )r   r   input_dtypesr   r   r   cast_to_model_input_dtypes;  s    r  c                 C   s   t |tjjrftd|dd | D  | D ]8}| s*|j|vrTtd|j d q*|	|j|_
q*nt |ttfrt|t| krtdtt|  d tt| d t|| D ]\}}| s||_
qn| D ]}| s||_
qd	S )
a  Prepares sample weight modes for the model.

    Args:
      training_endpoints: List of model _TrainingEndpoints.
      sample_weight_mode: sample weight mode user input passed from compile API.

    Raises:
      ValueError: In case of invalid `sample_weight_mode` input.
    r8  c                 S   s   g | ]
}|j qS r   output_namer   r   r   r   r   r   Y  r   z/prepare_sample_weight_modes.<locals>.<listcomp>zOutput z.missing from `_sample_weight_modes` dictionaryzdWhen passing a list as sample_weight_mode, it should have one entry per model output. The model has z outputs, but you passed z_sample_weight_modes.N)r   r   r2   r   r   r   should_skip_target_weightsr  r=   r   r8  rJ   rK   r;   r   r   )training_endpointsr8  	end_pointmodeendpointr   r   r   prepare_sample_weight_modesJ  sL    



	
r  c                    s   t  tjjr\td | g }|D ]4}| vrBtd| d |t	 
|d q$nzt  trz fdd|D }n\t  tjjrt t|krtdt| tjt	 }n fddtt|D }|S )	a  Converts loss to a list of loss functions.

    Args:
        loss: String (name of objective function), objective function or
          `tf.keras.losses.Loss` instance. See `tf.keras.losses`.
          If the model has multiple
          outputs, you can use a different loss on each output by passing a
          dictionary or a list of losses. The loss value that will be minimized
          by the model will then be the sum of all individual losses.
        output_names: List of model output names.

    Returns:
        A list of loss objective functions.

    Raises:
        ValueError: If loss is a dict with keys not in model output names,
            or if loss is a list with len not equal to model outputs.
    r   zOutput {0} missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to .Nc                    s   g | ]}t  qS r   rh  r   r   r   r   r     r   z*prepare_loss_functions.<locals>.<listcomp>zyWhen passing a list as loss, it should have one entry per model outputs. The model has {} outputs, but you passed loss={}c                    s   g | ]}t  qS r   r  r   r  r   r   r     s   )r   r   r2   r   r   r   ry  rz  rH   rh  r   r   Sequencer;   r=   rg   r   r   r  r   )r   r   loss_functionsr   r   r  r   prepare_loss_functions|  s2    


r  c                 C   s   |du r| D ]
}d|_ qnt|tjjr\td|dd | D  | D ]}||jd|_ qDnjt|t	rt
|t
| krtdtt
|  d t| t|| D ]\}}||_ qntdt| d	 dS )
aw  Converts loss weights to a list of loss weights.

    The result loss weights will be populated on the training endpoint.

    Args:
        training_endpoints: List of model training endpoints.
        loss_weights: Optional list or dictionary specifying scalar coefficients
          (Python floats) to weight the loss contributions of different model
          outputs. The loss value that will be minimized by the model will then
          be the *weighted sum* of all individual losses, weighted by the
          `loss_weights` coefficients. If a list, it is expected to have a 1:1
          mapping to the model's outputs. If a dict, it is expected to map
          output names (strings) to scalar coefficients.

    Raises:
        ValueError: If loss weight is a dict with key not in model output names,
            or if loss is a list with len not equal to model outputs.
    N      ?loss_weightsc                 S   s   g | ]
}|j qS r   r  r  r   r   r   r     r   z(prepare_loss_weights.<locals>.<listcomp>z^When passing a list as loss_weights, it should have one entry per model output. The model has z& outputs, but you passed loss_weights=z+Could not interpret loss_weights argument: z - expected a list of dicts.)loss_weightr   r   r2   r   r   r   r   r  rJ   r;   r=   r   r   r   )r  r  r   r   r   r   r   prepare_loss_weights  s>    



r  c                 C   s   t | ddS )z1Returns whether `layer` is a FeatureLayer or not._is_feature_layerF)r   )layerr   r   r   is_feature_layer  s    r  c                 C   s(   t  o&t| t jjjjt jjt jjfS r   )r   r?  r   r   r   r   rx  rw  r   r   r   r   is_eager_dataset_or_iterator  s    
r  c                 C   s6   t  r|   }nt|  }t jj 	|S r   )
r   r?  _as_serialized_graphnumpyr   	get_valuer   r   GraphDef
FromString)datasetgraph_def_strr   r   r   get_dataset_graph_def  s    r  c                 C   st   t | tjjsJ t| }|jD ]}|jdr  dS q |jj	D ]$}|j
D ]}|jdrJ  dS qJq@td dS )zVerifies that the dataset is shuffled.

    Args:
      x: Dataset passed as an input to the model.

    Returns:
      boolean, whether the input dataset is shuffled or not.
    ShuffleDatasetTznExpected a shuffled dataset but input dataset `x` is not shuffled. Please invoke `shuffle()` on input dataset.F)r   r   r   rx  r  nodeop
startswithlibraryfunctionnode_defry  rz  )r   	graph_defr  r  r   r   r   verify_dataset_shuffled  s    	

r  c                 C   s*   t | tjjjjtjjtjjjjtjjfS r   )r   r   r   r   r   rx  rw  r   r   r   r   is_dataset_or_iterator  s    

r  c                 C   s6   t  rt jjj| }nt jjj| }t| |S )z1Create and initialize an iterator from a dataset.)r   r?  r   r   r   make_one_shot_iteratormake_initializable_iteratorinitialize_iterator)r  iteratorr   r   r   get_iterator  s
    r  c                 C   s$   t  s | j}t|f| d S r   )r   r?  initializerr   get_sessionrun)r  init_opr   r   r   r  #  s    r  c                 C   s    t | }t|\}}}|||fS )zExtract a tuple of tensors `inputs, targets, sample_weight` from a dataset.

    Args:
      dataset: Dataset instance.

    Returns:
      Tuple of tensors `x, y, weights`. `y` and `weights` entry may be None.
    )r  unpack_iterator_input)r  r  r   r   r   r   r   r   extract_tensors_from_dataset)  s    	r  c                 C   s   z|   }W n tjjy*   tdY n0 t|ttfrxt|dvrRt	d| t|dkrl|\}}d}q|\}}}n|}d}d}|||fS )zConvert a dataset iterator to a tuple of tensors `x, y, sample_weights`.

    Args:
      iterator: Instance of a dataset iterator.

    Returns:
      Tuple of tensors `x, y, weights`. `y` and `weights` entry may be None.
    zkYour dataset iterator ran out of data; Make sure that your dataset can generate required number of samples.)r%  r$  zPlease provide model inputs as a list or tuple of 2 or 3 elements: (input, target) or (input, target, sample_weights) Received %sr%  N)
get_nextr   errorsOutOfRangeErrorrE   r   rJ   rK   r;   r=   )r  next_elementr   r   r   r   r   r   r  7  s*    	
r  r<   c              
   C   s   t |tjjsJ |  r6| jjtjjj	j
kr6dS ttjj|}|tjjjkrn|du rntd|f |dkr|dur|| |kr|dkrtd|||||| ||| f ntd|||||f |du r|dkr|S dS |S )a;  Infers steps_per_epoch needed to loop through a dataset.

    Args:
        model: Keras model instance.
        dataset: Input data of type tf.data.Dataset.
        steps: Number of steps to draw from the dataset (may be None if
          unknown).
        epochs: Number of times to iterate over the dataset.
        steps_name: The string name of the steps argument, either `steps`,
          `validation_steps`, or `steps_per_epoch`. Only used for error message
          formatting.

    Returns:
      Integer or `None`. Inferred number of steps to loop through the dataset.
      `None` is returned if 1) the size of the dataset is unknown and `steps`
      was not specified, or 2) this is multi-worker training and auto sharding
      is enabled.

    Raises:
      ValueError: In case of invalid argument values.
    NzQWhen passing an infinitely repeating dataset, you must specify the `%s` argument.r   r<   zThe dataset you passed contains %s batches, but you passed `epochs=%s` and `%s=%s`, which is a total of %s steps. We cannot draw that many steps from this dataset. We suggest to set `%s=%s`.zThe dataset you passed contains %s batches, but you passed `%s=%s`. We cannot draw that many steps from this dataset. We suggest to set `%s=%s`.)r   r   r   rx  _in_multi_worker_modeoptionsexperimental_distributeauto_shard_policyexperimentalAutoShardPolicyOFFr   r  cardinalityINFINITE_CARDINALITYr=   )r   r  r   epochsr   sizer   r   r   infer_steps_for_dataset\  sN    

r  c                   @   s:   e Zd ZdZdd Zdd ZdddZd	d
 Zdd ZdS )ModelInputszkEncapsulates model inputs.

    Allows for transforming model inputs while keeping the same structure.
    c                 C   s   || _ t| j t| _t| j tttf | _g | _g | _| jrpt	| j 
 D ]"}| j| j |  | j| qJn*tj| j | _dd tt| jD | _d S )Nc                 S   s   g | ]}d |d  qS )zinput_%dr<   r   r   r   r   r   r     s   z(ModelInputs.__init__.<locals>.<listcomp>)_inputsr   r   _is_dictrJ   rK   _is_single_input_flattened_inputs_input_namesr.  r/  rH   r   r   r  r   r;   )r    r   kr   r   r   r!     s    zModelInputs.__init__c                 C   s   | j S )zReturns keys to name inputs by.

        In case inputs provided were a list, tuple or single entry, we make up a
        key 'input_%d'. For dictionary case, we return a sorted list of keys.
        )r  r,   r   r   r   get_input_names  s    zModelInputs.get_input_namesFc                 C   s0  t t| j| jD ]\}\}}t|tttfrNt	|}|j
dkrNt|d}t|tjrdt|jdd  }|dkr|d}t|j}|jrt }tj|||d}nDt|tjrdt|j dd  }|dkrd}tj|||jd}|| j|< q| jrtt| j| jS | jr*|s*| jd S | jS )z4Returns inputs to be set as self.inputs for a model.r<   r   N)Nr<   )rW   r   rw   r   )r  r   r  r  r   rJ   r   r   rG   r   r,  r   r   rK   rW   r   as_dtyperw   r~  r   r6  placeholder
TensorSpecr   r  r   r  )r    return_single_as_listr   r  rF  rW   rw   r   r   r   get_symbolic_inputs  s4    


zModelInputs.get_symbolic_inputsc                 c   s&   t | j| jD ]\}}||fV  qdS )z0An iterable over a dictionary version of inputs.N)r   r  r  )r    r  rF  r   r   r   as_dict  s    zModelInputs.as_dictc                 C   s   | j S )zReturning the inputs as a list.)r  r,   r   r   r   r     s    zModelInputs.as_listN)F)	r.   r/   r0   r1   r!   r  r  r  r   r   r   r   r   r    s   
(r  c                 C   s   dd t t| D S )Nc                 S   s   g | ]}d |d  qS )z	output_%dr<   r   r   r   r   r   r     r   z(generic_output_names.<locals>.<listcomp>)r   r;   )outputs_listr   r   r   generic_output_names  s    r  c                 C   sL   |d }t | tr.| dk r"td||  dkS t | tjjsDtd|| v S )a   Checks if validation should be run this epoch.

    Args:
      validation_freq: Integer or list. If an integer, specifies how many
        training epochs to run before a new validation run is performed. If a
        list, specifies the epochs on which to run validation.
      epoch: Integer, the number of the training epoch just completed.

    Returns:
      Bool, True if validation should be run.

    Raises:
      ValueError: if `validation_freq` is an Integer and less than 1, or if
      it is neither an Integer nor a Sequence.
    r<   z)`validation_freq` can not be less than 1.r   z\`validation_freq` must be an Integer or `collections.abc.Container` (e.g. list, tuple, etc.))r   r   r=   r   r2   	Container)validation_freqepochone_indexed_epochr   r   r   should_run_validation  s    
r  c                 C   s   t | rtdt| d dr:t| d jd d|  }ntt| d d|  }t| d|t| | } }t|d|t|| }}|rt|d|t|| }}nd}| |||||fS )zCSplit input data into train/eval section based on validation_split.zSIf your data is in the form of symbolic tensors, you cannot use `validation_split`.r   rW   r  N)rC  r=   r   r   rW   r;   r   slice_arrays)r   r   sample_weightsrj  split_atval_xval_yval_sample_weightsr   r   r   "split_training_and_validation_data  s&    


r  c                 C   s   t | tjjjjtjjtjjtjfs.t	| ds<| }d}d}nt
| dkr~z| \}}d}W q tyz   | dd  }}}Y q0 n`t
| dkrz| \}}}W q ty   | dd  }}}Y q0 n |rtd|  | dd  }}}|||fS )a^  Unpack validation data based input type.

    The validation data is not touched if its dataset or dataset iterator.
    For other type of input (Numpy or tensor), it will be unpacked into tuple of
    3 which is x, y and sample weights.

    Args:
      validation_data: dataset, dataset iterator, or numpy, tensor tuple.
      raise_if_ambiguous: boolean on whether to fail if validation_data cannot
        be parsed. Otherwise simply return validation_data, None, None and defer
        the decision to the caller.

    Returns:
      tuple of 3, (x, y, sample_weights) for numpy and tensor input.
    __len__Nr%  r$  zWhen passing a `validation_data` argument, it must contain either 2 items (x_val, y_val), or 3 items (x_val, y_val, val_sample_weights), or alternatively it could be a dataset or a dataset or a dataset iterator. However we received `validation_data=%s`)r   r   r   r   r   rw  rx  r   r  r   r;   r=   )validation_dataraise_if_ambiguousr  r  val_sample_weightr   r   r   unpack_validation_data<  sP    
	r  c                   @   s.   e Zd ZdZdddZdd	d
ZdddZdS )TrainingLoopa`  TrainingLoop is a wrapper class around the training logic.

    This class is trying to encapsulate the different logic of fit/eval/predict
    with regard to different data input and model condition.

    Note that TrainingLoop is stateless, which means it doesn't contain any
    internal field and can be reused with different model and inputs.
    Nr<   r:   Tr   c                 K   s
   t  dS )z,Train the model with the inputs and targets.Nr#   )r    r   r   r   r   r  verboser   rj  r  r  r   r   initial_epochsteps_per_epochvalidation_stepsr  kwargsr   r   r   fit  s    zTrainingLoop.fitc	           
      K   s
   t  dS )zKReturns the loss value & metrics values for the model in test
        mode.Nr#   )
r    r   r   r   r   r  r   r   r   r  r   r   r   evaluate  s    zTrainingLoop.evaluatec                 K   s
   t  d S r   r#   )r    r   r   r   r  r   r   r  r   r   r   predict  s    
zTrainingLoop.predict)NNNr<   r<   Nr:   NTNNr   NNr<   )NNNr<   NNN)Nr   NN)r.   r/   r0   r1   r  r  r  r   r   r   r   r  x  s<                  
       
    r  )T)NNr   )N)NTr   )N)FF)NNN)F)NN)NNN)N)Tr   )NNN)N)N)N)r<   r   )T)Zr1   r2   ro   r   r  multiprocessing.poolrk   r|   r   r  rG   tensorflow.compat.v2r   v2r   kerasr   r   r   r   r   r  keras.utilsr   r   r   r	   tensorflow.python.platformr
   ry  r   objectABCMetar   r5   rT   rY   rb   rc   rn   rj   rr   rs   r   r   r   r   r   r   r   r   r   r   r   r  r!  r>  rC  r@  r
  r  r`  rh  rk  rr  rs  r   r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r  r   r   r   r   <module>   s   1$54/i8

%

 
 %;
DM  
f 
 
$
- 
#
+
 
=
	
20
4
& 
HW 
<