B
    ӻdc                 @   s   d Z ddlZddlmZ ddlmZ ddlmZ	 ddl
mZ ddl
mZ ddl
mZ dd	lmZ dd
lmZ ddlmZ G dd deZG dd deZG dd deZdd Zdd Zdd Zdd Zdd Zdd Zdd Zd d! Zd"d# ZdS )$zUtilites for `Model.compile`.    N)distribution_strategy_context)losses)metrics)generic_utils)losses_utils)tf_utils)	array_ops)math_ops)nestc               @   sB   e Zd ZdZdddZdd Zdd Zd	d
 Zdd Zdd Z	dS )	ContainerzBase Container class.Nc             C   s
   || _ d S )N)_output_names)selfoutput_names r   ^/var/www/html/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/compile_utils.py__init__!   s    zContainer.__init__c             C   s   | j d krt|| _ d S )N)r   create_pseudo_output_names)r   y_predr   r   r   build$   s    
zContainer.buildc                sD   t || j  t|  t s@t|r@t fdd|  S )a  Convenience method to conform `struct` to `outputs` structure.

    Mappings performed:

    (1) Map a dict to a list of outputs, using the output names.
    (2) Fill missing keys in a dict w/ `None`s.
    (3) Map a single item to all outputs.

    Args:
      outputs: Model predictions.
      struct: Arbitrary nested structure (e.g. of labels, sample_weights,
        losses, or metrics).

    Returns:
      Mapping of `struct` to `outputs` structure.
    c                s    S )Nr   )_)structr   r   <lambda>?       z/Container._conform_to_outputs.<locals>.<lambda>)map_to_output_namesr   map_missing_dict_keysr
   	is_nestedmap_structure)r   outputsr   r   )r   r   _conform_to_outputs*   s
    
zContainer._conform_to_outputsc                sD    sS tt|dkfdd t fdd|S )a  Determines if losses / metrics should be applied to all outputs.

    NOTE: This method should only be called for Metrics / Losses, not for
    y_true / sample_weight.

    Args:
      outputs: Model predictions.
      objects: Arbitrary nested structure (e.g. of losses or metrics)

    Returns:
      Arbitrary nested structure of objects, maybe copied to each output.

    Applies a Loss / Metric to all outputs.
       c                  s   rt j S  S )N)r
   r   _copy_objectr   )objectsr   should_copy_objectsr   r   _broadcast_fnY   s    z<Container._maybe_broadcast_to_outputs.<locals>._broadcast_fnc                s     S )Nr   )r   )r#   r   r   r   ^   r   z7Container._maybe_broadcast_to_outputs.<locals>.<lambda>)_should_broadcastlenr
   flattenr   )r   r   r!   r   )r#   r!   r   r"   r   _maybe_broadcast_to_outputsB   s
    
z%Container._maybe_broadcast_to_outputsc             C   s   t d S )N)NotImplementedError)r   r!   r   r   r   r$   `   s    zContainer._should_broadcastc             C   s   t d S )N)r(   )r   objr   r   r   r    c   s    zContainer._copy_object)N)
__name__
__module____qualname____doc__r   r   r   r'   r$   r    r   r   r   r   r      s   
r   c                   sx   e Zd ZdZd fdd	Zedd Z fddZed	d
 Zdd Z	dddZ
dd Zdd Zdd Zdd Z  ZS )LossesContainerz7A container class for losses passed to `Model.compile`.Nc                sH   t t| j|d || _|| _|| _|| _d | _tj	dd| _
d| _d S )N)r   loss)nameF)superr.   r   Z_user_lossesZ_user_loss_weights_losses_loss_weights_per_output_metricsmetrics_modMean_loss_metric_built)r   r   Zloss_weightsr   )	__class__r   r   r   j   s    zLossesContainer.__init__c             C   s,   | j s
g S dd t| jD }| jg| S )zPer-output loss metrics.c             S   s   g | ]}|d k	r|qS )Nr   ).0
metric_objr   r   r   
<listcomp>}   s    z+LossesContainer.metrics.<locals>.<listcomp>)r8   r
   r&   r4   r7   )r   Zper_output_metricsr   r   r   r   w   s    zLossesContainer.metricsc                s   t t| | | || j| _| || j| _t| j| j| _t	| j| _| || j
| _
| || j
| _
t	| j
| _
|   d| _dS )zOne-time setup of loss objects.TN)r1   r.   r   r'   r2   r   r
   r   _get_loss_objectr&   r3   _create_metricsr8   )r   r   )r9   r   r   r      s    zLossesContainer.buildc             C   s   | j S )N)r8   )r   r   r   r   built   s    zLossesContainer.builtc             C   sj   t | jdkrdg| _nNg | _xFt| j| jD ]4\}}|dkrL| jd q.| jt|d  q.W dS )zBCreates per-output loss metrics, but only for multi-output Models.r   NZ_loss)r%   r   r4   zipr2   appendr5   r6   )r   loss_objoutput_namer   r   r   r>      s    
zLossesContainer._create_metricsc             C   s  |  ||}|  ||}| js(| | t|}t|}t|}g }g }d}|||| j| j| jf}xt| D ]\}	}
}}}}|	dkst|dkrqtt	|	|
|\}	}
}t
|
|t|
}||	|
|d}|}|jtjjkr|t j9 }|dkrt|	r|	 }nt|	d }|dk	r.|j||d |dk	rH||9 }||9 }|jtjjksh|jtjjkrrt|}|| || qtW |rt|}t|}|| |t| |r t|}t|}| j j||d t|}t|}|S tj!ddS dS )aK  Computes the overall loss.

    Args:
      y_true: An arbitrary structure of Tensors representing the ground truth.
      y_pred: An arbitrary structure of Tensors representing a Model's outputs.
      sample_weight: An arbitrary structure of Tensors representing the
        per-sample loss weights. If one Tensor is passed, it is used for all
        losses. If multiple Tensors are passed, the structure should match
        `y_pred`.
      regularization_losses: Additional losses to be added to the total loss.

    Returns:
      Tuple of `(total_loss, per_output_loss_list)`
    N)sample_weightr   r   )shape)"r   r8   r   r
   r&   r2   r3   r4   r@   match_dtype_and_rank
apply_maskget_maskZ	reductionr   ZReductionV2ZSUM
ds_contextZget_strategyZnum_replicas_in_syncr   Z	is_raggedZnrowsr   rE   update_stateZSUM_OVER_BATCH_SIZEZAUTOZscale_loss_for_distributionrA   Zcast_losses_to_common_dtyper	   Zadd_nr7   Zzeros)r   y_truer   rD   Zregularization_lossesZloss_valuesZloss_metric_valuesZ	batch_dimzip_argsy_ty_pswrB   Zloss_weightr;   Z
loss_valueZloss_metric_valueZreg_lossZtotal_loss_metric_valueZ
total_lossr   r   r   __call__   sf    















zLossesContainer.__call__c             C   s@   | j s
dS | jgt| j }x|D ]}|dk	r$|  q$W dS )z!Resets the state of loss metrics.N)r8   r7   r
   r&   r4   reset_state)r   r   r;   r   r   r   rQ      s    
zLossesContainer.reset_statec             C   sX   |dkrdS t |}t|t jsNt|}|dkr@td|t j||d}d|_|S )a  Returns a `Loss` object.

    Converts the user-supplied loss to a `Loss` object. Also allows
    `SUM_OVER_BATCH_SIZE` reduction to be used for this loss.

    Args:
      loss: A string, function, or `Loss` object.

    Returns:
      A `Loss` object.
    Nz$Loss should be a callable, found: {})r0   T)	
losses_modget
isinstanceLossget_custom_object_name
ValueErrorformatZLossFunctionWrapper_allow_sum_over_batch_size)r   r/   Z	loss_namer   r   r   r=     s    
z LossesContainer._get_loss_objectc             C   s   t | S )N)r
   r   )r   r)   r   r   r   r$     s    z!LossesContainer._should_broadcastc             C   s   |S )Nr   )r   r)   r   r   r   r       s    zLossesContainer._copy_object)NN)NN)r*   r+   r,   r-   r   propertyr   r   r?   r>   rP   rQ   r=   r$   r    __classcell__r   r   )r9   r   r.   g   s    
T	r.   c                   s   e Zd ZdZd  fdd	Zedd Zedd	 Zed
d Z fddZ	edd Z
dd Zdd Zd!ddZdd Zdd Zdd Zdd Zdd Z  ZS )"MetricsContainerz8A container class for metrics passed to `Model.compile`.NFc                s:   t t| j|d || _|| _|| _|| _d| _|| _dS )a  Initializes a container for metrics.

    Arguments:
      metrics: see the `metrics` argument from `tf.keras.Model.compile`.
      weighted_metrics: see the `weighted_metrics` argument from
        `tf.keras.Model.compile`.
      output_names: A list of strings of names of outputs for the model.
      from_serialized: Whether the model being compiled is from a serialized
        model.  Used to avoid redundantly applying pre-processing renaming
        steps.
    )r   FN)	r1   r\   r   _user_metrics_user_weighted_metrics_metrics_weighted_metricsr8   _from_serialized)r   r   weighted_metricsr   Zfrom_serialized)r9   r   r   r   '  s    zMetricsContainer.__init__c             C   s   | j s
g S | jS )zAll metrics in this container.)r8   _metrics_in_order)r   r   r   r   r   @  s    zMetricsContainer.metricsc             C   s   | j s
dS t| jS )zDMetrics in this container that should not be passed `sample_weight`.N)r8   r
   r&   r_   )r   r   r   r   unweighted_metricsG  s    z#MetricsContainer.unweighted_metricsc             C   s   | j s
dS t| jS )z@Metrics in this container that should be passed `sample_weight`.N)r8   r
   r&   r`   )r   r   r   r   rb   N  s    z!MetricsContainer.weighted_metricsc                s   t t| | | || j| _| || j| _| || j| _| || j| _t|}t|}t| j| _t| j| _t	|| j
| j||| _t	|| j
| j||| _tj|| jdd| _tj|| jdd| _| js|   |   d| _dS )z!One-time setup of metric objects.F)Zcheck_typesTN)r1   r\   r   r'   r_   r   r`   r
   Zlist_to_tupleZmap_structure_up_to_get_metric_objectsZflatten_up_tora   _set_metric_names_create_ordered_metricsr8   )r   r   rK   )r9   r   r   r   U  s0    




zMetricsContainer.buildc             C   s   | j S )N)r8   )r   r   r   r   r?   }  s    zMetricsContainer.builtc       	      C   s"  t  }t| jdk}| j| j| jf}xt| D ]\}}}xP|D ]H}|dkrLq>|r`|d |j |_|j|krztd|j|	|j q>W x|D ]}|dkrq|r|d |j |kr|d |j |_q|d |j |_n|j|krd|j |_|j|kr
td|j|	|j qW q.W dS )zSets unique metric names.r   Nr   z(Found two metrics with the same name: {}Z
_weighted_Z	weighted_)
setr%   r   r_   r`   r@   _namerW   rX   add)	r   Zmetric_namesZis_multi_outputrL   rC   output_metricsZweighted_output_metricsmwmr   r   r   rf     s4    





z"MetricsContainer._set_metric_namesc             C   sv   g | _ xjt| j| jD ]X\}}x&t|D ]}|dk	r*| j | q*W x&t|D ]}|dk	rR| j | qRW qW dS )zICache the flat order needed when returning metrics, for backwards compat.N)rc   r@   r_   r`   r
   r&   rA   )r   rk   Zoutput_weighted_metricsrl   rm   r   r   r   rg     s    z(MetricsContainer._create_ordered_metricsc             C   s0  |  ||}|  ||}| js*| || t|}|dk	rFt|ng }t|}|||| j| jf}xt| D ]\}}}}}	|dksptdd |D rtdd |	D rqpt	|||\}}}t
|}
t|||
}x&|D ]}|dkrq|j|||
d qW x,|	D ]$}|dkrq |j|||d q W qpW dS )z(Updates the state of per-output metrics.Nc             s   s   | ]}|d kV  qd S )Nr   )r:   rl   r   r   r   	<genexpr>  s    z0MetricsContainer.update_state.<locals>.<genexpr>c             s   s   | ]}|d kV  qd S )Nr   )r:   rm   r   r   r   rn     s    )rD   )r   r8   r   r
   r&   r_   r`   r@   allrF   rH   rG   rJ   )r   rK   r   rD   rL   rM   rN   rO   Zmetric_objsZweighted_metric_objsmaskr;   Zweighted_metric_objr   r   r   rJ     s0    





zMetricsContainer.update_statec             C   sL   | j r| j}nt| jt| j }x |D ]}t|tjr,|	  q,W dS )z4Resets the state of all `Metric`s in this container.N)
r8   rc   r
   r&   r]   r^   rT   r5   MetricrQ   )r   r   r;   r   r   r   rQ     s    

zMetricsContainer.reset_statec                s    t |} fdd|D S )z2Convert user-supplied metrics to `Metric` objects.c                s   g | ]}  |qS r   )_get_metric_object)r:   rl   )r   rN   rM   r   r   r<     s    z8MetricsContainer._get_metric_objects.<locals>.<listcomp>)r
   r&   )r   r   rM   rN   r   )r   rN   rM   r   re     s    
z$MetricsContainer._get_metric_objectsc             C   s0  |dkrdS t | dkr(t|}nt|j }t|j }|j d }|j d }|dk}	||k p~|dko~|dk}
t | dkr|	rtj}q|
rtj}qtj	}n|	rtj
}n|
rtj}ntj}t|tjrd|_t|tjs,t|t r|}n t|}|dkrtd|tj||d}|S )	zConverts user-supplied metric to a `Metric` object.

    Args:
      metric: A string, function, or `Metric` object.
      y_t: Sample of label.
      y_p: Sample of output.

    Returns:
      A `Metric` object.
    N)accuracyaccZcrossentropyZcer   )rs   rt   Tz&Metric should be a callable, found: {})r0   )strlowerr5   rS   r%   rE   as_listZbinary_accuracyZsparse_categorical_accuracyZcategorical_accuracyZbinary_crossentropyZsparse_categorical_crossentropyZcategorical_crossentropyrT   rR   rU   rY   rq   rV   rW   rX   ZMeanMetricWrapper)r   ZmetricrM   rN   r;   Zy_t_rankZy_p_rankZy_t_last_dimZy_p_last_dimZ	is_binaryZis_sparse_categoricalZmetric_namer   r   r   rr     s@    

z#MetricsContainer._get_metric_objectc             C   s0   t |sdS t|ttfo.tdd |D  S )NTc             s   s   | ]}t |V  qd S )N)r
   r   )r:   or   r   r   rn   (  s    z5MetricsContainer._should_broadcast.<locals>.<genexpr>)r
   r   rT   listtupleany)r   r)   r   r   r   r$   "  s    
z"MetricsContainer._should_broadcastc             C   s    t |tjr|j| S |S )N)rT   r5   rq   r9   from_configZ
get_config)r   r)   r   r   r   r    *  s    zMetricsContainer._copy_object)NNNF)N)r*   r+   r,   r-   r   rZ   r   rd   rb   r   r?   rf   rg   rJ   rQ   re   rr   r$   r    r[   r   r   )r9   r   r\   $  s     ($
";r\   c             C   s   t | ddS )z2Create pseudo output names for a subclassed Model.Zoutput_)prefix)_create_pseudo_names)r   r   r   r   r   0  s    r   c             C   s   t | ddS )z1Create pseudo input names for a subclassed Model.Zinput_)r~   )r   )inputsr   r   r   create_pseudo_input_names5  s    r   c             C   sz   dd }t t| }t||}g }xN|D ]F}|s>|d }n*ddd |D }t|d trh|| }|| q,W |S )a%  Creates pseudo {input | output} names for subclassed Models.

  Warning: this function should only be used to define default
  names for `Metics` and `SavedModel`. No other use cases should
  rely on a `Model`'s input or output names.

  Example with dict:

  `{'a': [x1, x2], 'b': x3}` becomes:
  `['a_1', 'a_2', 'b']`

  Example with list:

  `[x, y]` becomes:
  `['output_1', 'output_2']`

  Args:
    tensors: `Model`'s outputs or inputs.
    prefix: 'output_' for outputs, 'input_' for inputs.

  Returns:
    Flattened list of pseudo names.
  c             S   s   t | tr| d S | S )Nr   )rT   int)Zeler   r   r   	one_indexS  s    
z'_create_pseudo_names.<locals>.one_index1r   c             s   s   | ]}t |V  qd S )N)rv   )r:   pr   r   r   rn   `  s    z'_create_pseudo_names.<locals>.<genexpr>r   )rz   r
   Zyield_flat_pathsr   joinrT   r   rA   )Ztensorsr~   r   Z
flat_pathsnamespathr0   r   r   r   r   :  s    

r   c                s   t |  }| o2t| ttfo2tdd | D  }|s<|rt tr|pPt| }t   fdd|D } rt	d
  |t|dkr|d S |S  S dS )	a  Maps a dict to a list using `output_names` as keys.

  This is a convenience feature only. When a `Model`'s outputs
  are a list, you can specify per-output losses and metrics as
  a dict, where the keys are the output names. If you specify
  per-output losses and metrics via the same structure as the
  `Model`'s outputs (recommended), no mapping is performed.

  For the Functional API, the output names are the names of the
  last layer of each output. For the Subclass API, the output names
  are determined by `create_pseudo_output_names` (For example:
  `['output_1', 'output_2']` for a list of outputs).

  This mapping preserves backwards compatibility for `compile` and
  `fit`.

  Args:
    y_pred: Sample outputs of the Model, to determine if this convenience
      feature should be applied (`struct` is returned unmodified if `y_pred`
      isn't a flat list).
    output_names: List. The names of the outputs of the Model.
    struct: The structure to map.

  Returns:
    `struct` mapped to a list in same order as `output_names`.
  c             s   s   | ]}t |V  qd S )N)r
   r   )r:   rN   r   r   r   rn     s    z&map_to_output_names.<locals>.<genexpr>c                s   g | ]}  |d qS )N)pop)r:   r0   )r   r   r   r<     s    z'map_to_output_names.<locals>.<listcomp>zRFound unexpected keys that do not correspond to any Model output: {}. Expected: {}r   r   N)r
   r   rT   rz   r{   r|   dictr   copyrW   rX   keysr%   )r   r   r   Zsingle_outputZoutputs_are_flat_listZ
new_structr   )r   r   r   g  s    
r   c             C   s>   t | trt |ts|S x |  D ]}||kr"d||< q"W |S )z@Replaces missing dict keys in `struct` with `None` placeholders.N)rT   r   r   )r   r   kr   r   r   r     s    r   c             C   s   | j jdkr&|j jdkr&tj| dd} |dk	rT|j jdkrT|j jdkrTtj|dd}| jjrd|jjst| jjr|jjrt| |j} |dk	rt||j}| ||fS )z$Match dtype and rank of predictions.r      ru   )ZaxisN)	rE   Zrankr   Zexpand_dims_v2dtypeZis_floating
is_integerr	   cast)rM   rN   rO   r   r   r   rF     s    rF   c             C   s   t | ddS )zReturns Keras mask from tensor.Z_keras_maskN)getattr)rN   r   r   r   rH     s    rH   c             C   sD   |dk	r@t || j}|dk	r<tj||d\}}}||9 }n|}|S )z2Applies any mask on predictions to sample weights.N)rD   )r	   r   r   r   Zsqueeze_or_expand_dimensions)rN   rO   rp   r   r   r   r   rG     s    
rG   c             C   s@   t | dr| jS t | dr | jS t | dr8t| jjS dS dS )zReturns the name to use for a custom loss or metric callable.

  Args:
    obj: Custom loss of metric callable

  Returns:
    Name to use, or `None` if the object was not recognized.
  r0   r*   r9   N)hasattrr0   r*   r   Zto_snake_caser9   )r)   r   r   r   rV     s    	


rV   ) r-   r   Ztensorflow.python.distributer   rI   Ztensorflow.python.kerasr   rR   r   r5   Ztensorflow.python.keras.utilsr   r   r   Ztensorflow.python.opsr   r	   Ztensorflow.python.utilr
   objectr   r.   r\   r   r   r   r   r   rF   rH   rG   rV   r   r   r   r   <module>   s2   I >  -/
