"""Legacy v1 optimizer classes.

For more examples see the base class `tf.compat.v1.keras.optimizers.Optimizer`.
"""

from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.eager import backprop
from tensorflow.python.framework import ops
from tensorflow.python.keras import backend
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.trackable import base as trackable
from tensorflow.python.training import training_util
from tensorflow.python.util import nest
dd Zedd ZdS )	Optimizera  Abstract optimizer base class.

  Note: this is the parent class of all optimizers, not an actual optimizer
  that can be used for training models.

  All Keras optimizers support the following keyword arguments:

      clipnorm: float >= 0. Gradients will be clipped
          when their L2 norm exceeds this value.
      clipvalue: float >= 0. Gradients will be clipped
          when their absolute value exceeds this value.
  c             K   sj   ddh}xD|D ]<}||kr*t dt| || dk rtd||| qW | j| g | _g | _d S )Nclipnorm	clipvaluez1Unexpected keyword argument passed to optimizer: r   zExpected {} >= 0, received: {})	TypeErrorstr
ValueErrorformat__dict__updateupdatesweights)selfkwargsZallowed_kwargsk r   V/var/www/html/venv/lib/python3.7/site-packages/tensorflow/python/keras/optimizer_v1.py__init__0   s    
zOptimizer.__init__Fc             C   s   t dS )zCreates and sets all optimizer weights.

    Args:
      params: list or tuple of `Variable` objects that will be minimized
        using this optimizer.

    Returns:
      Specific weight values that are used in `get_updates`
    N)NotImplementedError)r   paramsr   r   r   _create_all_weightsA   s    
zOptimizer._create_all_weightsc             C   s   t d S )N)r   )r   lossr   r   r   r   get_updatesM   s    zOptimizer.get_updatesc                sb   t ||}tdd |D r&tdt drB fdd|D }t dr^ fdd|D }|S )	a3  Returns gradients of `loss` with respect to `params`.

    Args:
        loss: Loss tensor.
        params: List of variables.

    Returns:
        List of gradient tensors.

    Raises:
        ValueError: In case any gradient cannot be computed (e.g. if gradient
          function not implemented).
    c             s   s   | ]}|d kV  qd S )Nr   ).0gr   r   r   	<genexpr>_   s    z*Optimizer.get_gradients.<locals>.<genexpr>zAn operation has `None` for gradient. Please make sure that all of your ops have a gradient defined (i.e. are differentiable). Common ops without gradient: backend.argmax, backend.round, backend.eval.r   c                s   g | ]}t | jqS r   )r   Zclip_by_normr   )r"   r#   )r   r   r   
<listcomp>f   s    z+Optimizer.get_gradients.<locals>.<listcomp>r   c                s    g | ]}t | j  jqS r   )r   Zclip_by_valuer   )r"   r#   )r   r   r   r%   i   s   )r   Z	gradientsanyr   hasattr)r   r    r   gradsr   )r   r   get_gradientsP   s    


zOptimizer.get_gradientsc             C   s   | j }t|t|kr>tdtt| d tt| d g }t|}xTt|||D ]D\}}}|j|jkrtdt|j d t|j |||f qZW t	| dS )a  Sets the weights of the optimizer, from Numpy arrays.

    Should only be called after computing the gradients
    (otherwise the optimizer has no weights).

    Args:
        weights: a list of Numpy arrays. The number of arrays and their shape
          must match number of the dimensions of the weights of the optimizer
          (i.e. it should match the output of `get_weights`).

    Raises:
        ValueError: in case of incompatible weight shapes.
    z%Length of the specified weight list (z9) does not match the number of weights of the optimizer ()zOptimizer weight shape z+ not compatible with provided weight shape N)
r   lenr   r   r   batch_get_valuezipshapeappendZbatch_set_value)r   r   r   Zweight_value_tuplesZparam_valuespvpwr   r   r   set_weightsn   s    &
zOptimizer.set_weightsc             C   s   t | jS )zmReturns the current value of the weights of the optimizer.

    Returns:
        A list of numpy arrays.
    )r   r,   r   )r   r   r   r   get_weights   s    zOptimizer.get_weightsc             C   s0   i }t | dr| j|d< t | dr,| j|d< |S )Nr   r   )r'   r   r   )r   configr   r   r   
get_config   s    



zOptimizer.get_configc             C   s
   | f |S )Nr   )clsr5   r   r   r   from_config   s    zOptimizer.from_configN)__name__
__module____qualname____doc__r   Z_HAS_AGGREGATE_GRADr   r!   r)   r3   r4   r6   classmethodr8   r   r   r   r   r   "   s   r   c                   s>   e Zd ZdZd fdd	Zdd Zd	d
 Z fddZ  ZS )SGDa  Stochastic gradient descent optimizer.

  Includes support for momentum,
  learning rate decay, and Nesterov momentum.

  Args:
      lr: float >= 0. Learning rate.
      momentum: float >= 0. Parameter that accelerates SGD in the relevant
        direction and dampens oscillations.
      decay: float >= 0. Learning rate decay over each update.
      nesterov: boolean. Whether to apply Nesterov momentum.
  {Gz?        Fc          	      s~   t t| jf | t| jjH tjdddd| _tj|dd| _	tj|dd| _
tj|dd| _W d Q R X || _|| _d S )	Nr   int64
iterations)dtypenamelr)rD   momentumdecay)superr>   r   r   
name_scope	__class__r9   variablerB   rE   rF   rG   initial_decaynesterov)r   rE   rF   rG   rM   r   )rJ   r   r   r      s    zSGD.__init__c             C   s.   dd |D }dd |D }| j g| | _|S )Nc             S   s   g | ]}t |qS r   )r   	int_shape)r"   r1   r   r   r   r%      s    z+SGD._create_all_weights.<locals>.<listcomp>c             S   s   g | ]}t |qS r   )r   zeros)r"   r.   r   r   r   r%      s    )rB   r   )r   r   shapesmomentsr   r   r   r      s    zSGD._create_all_weightsc          
   C   s   |  ||}t| jdg| _| j}| jdkrV|dd| jt	| jt
| j    }| |}xt|||D ]\}}}| j| ||  }	| jt||	 | jr|| j|	  ||  }
n||	 }
t|dd d k	r||
}
| jt||
 qnW | jS )N   r   g      ?
constraint)r)   r   
assign_addrB   r   rE   rL   rG   r   castr   rC   r   r-   rF   r/   assignrM   getattrrS   )r   r    r   r(   rE   rQ   r1   r#   mvnew_pr   r   r   r!      s(    


zSGD.get_updatesc                s^   t t| jt t| jt t| j| jd}tt| 	 }t
t| t|  S )N)rE   rF   rG   rM   )floatr   	get_valuerE   rF   rG   rM   rH   r>   r6   dictlistitems)r   r5   base_config)rJ   r   r   r6      s    
zSGD.get_config)r?   r@   r@   F)	r9   r:   r;   r<   r   r   r!   r6   __classcell__r   r   )rJ   r   r>      s
   
r>   c                   s>   e Zd ZdZd fdd	Zdd	 Zd
d Z fddZ  ZS )RMSpropa}  RMSProp optimizer.

  It is recommended to leave the parameters of this optimizer
  at their default values
  (except the learning rate, which can be freely tuned).

  Args:
    lr: float >= 0. Learning rate.
    rho: float >= 0.
    epsilon: float >= 0. Fuzz factor.
      If `None`, defaults to `backend.epsilon()`.
    decay: float >= 0. Learning rate decay over each update.
  MbP??N        c          	      s   t t| jf | t| jjH tj|dd| _tj|dd| _	tj|dd| _
tjdddd| _W d Q R X |d kr~t }|| _|| _d S )	NrE   )rD   rhorG   r   rA   rB   )rC   rD   )rH   rb   r   r   rI   rJ   r9   rK   rE   rf   rG   rB   epsilonrL   )r   rE   rf   rg   rG   r   )rJ   r   r   r      s    zRMSprop.__init__c             C   s   dd |D }|| _ |S )Nc             S   s&   g | ]}t jt |t |d qS ))rC   )r   rO   rN   rC   )r"   r1   r   r   r   r%     s   z/RMSprop._create_all_weights.<locals>.<listcomp>)r   )r   r   accumulatorsr   r   r   r     s    zRMSprop._create_all_weightsc          
   C   s   |  ||}| |}t| jdg| _| j}| jdkr`|dd| jt	
| jt| j    }xt|||D ]\}}}| j| d| j t	|  }	| jt||	 ||| t|	| j   }
t|dd d k	r||
}
| jt||
 qnW | jS )NrR   r   g      ?rS   )r)   r   r   rT   rB   r   rE   rL   rG   r   rU   r   rC   r-   rf   squarer/   rV   sqrtrg   rW   rS   )r   r    r   r(   rh   rE   r1   r#   anew_arZ   r   r   r   r!   	  s$    


zRMSprop.get_updatesc                s^   t t| jt t| jt t| j| jd}tt| 	 }t
t| t|  S )N)rE   rf   rG   rg   )r[   r   r\   rE   rf   rG   rg   rH   rb   r6   r]   r^   r_   )r   r5   r`   )rJ   r   r   r6   #  s    
zRMSprop.get_config)rc   rd   Nre   )	r9   r:   r;   r<   r   r   r!   r6   ra   r   r   )rJ   r   rb      s
   rb   c                   s>   e Zd ZdZd fdd	Zdd Zd	d
 Z fddZ  ZS )Adagrada  Adagrad optimizer.

  Adagrad is an optimizer with parameter-specific learning rates,
  which are adapted relative to how frequently a parameter gets
  updated during training. The more updates a parameter receives,
  the smaller the updates.

  It is recommended to leave the parameters of this optimizer
  at their default values.

  # Arguments
      lr: float >= 0. Initial learning rate.
      epsilon: float >= 0. If `None`, defaults to `backend.epsilon()`.
      decay: float >= 0. Learning rate decay over each update.

  # References
      - [Adaptive Subgradient Methods for Online Learning and Stochastic
      Optimization](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
  {Gz?N        c          	      s~   t t| jf | t| jj8 tj|dd| _tj|dd| _	tjdddd| _
W d Q R X |d krnt }|| _|| _d S )NrE   )rD   rG   r   rA   rB   )rC   rD   )rH   rm   r   r   rI   rJ   r9   rK   rE   rG   rB   rg   rL   )r   rE   rg   rG   r   )rJ   r   r   r   C  s    zAdagrad.__init__c             C   s&   dd |D }dd |D }|| _ |S )Nc             S   s   g | ]}t |qS r   )r   rN   )r"   r1   r   r   r   r%   O  s    z/Adagrad._create_all_weights.<locals>.<listcomp>c             S   s   g | ]}t |qS r   )r   rO   )r"   r.   r   r   r   r%   P  s    )r   )r   r   rP   rh   r   r   r   r   N  s    zAdagrad._create_all_weightsc          
   C   s   |  ||}| |}t| jdg| _| j}| jdkr`|dd| jt	
| jt| j    }xt|||D ]v\}}}|t	| }	| jt||	 ||| t|	| j   }
t|dd d k	r||
}
| jt||
 qnW | jS )NrR   r   g      ?rS   )r)   r   r   rT   rB   r   rE   rL   rG   r   rU   r   rC   r-   ri   r/   rV   rj   rg   rW   rS   )r   r    r   r(   rh   rE   r1   r#   rk   rl   rZ   r   r   r   r!   T  s$    


zAdagrad.get_updatesc                sP   t t| jt t| j| jd}tt|  }t	t
| t
|  S )N)rE   rG   rg   )r[   r   r\   rE   rG   rg   rH   rm   r6   r]   r^   r_   )r   r5   r`   )rJ   r   r   r6   n  s
    
zAdagrad.get_config)rn   Nro   )	r9   r:   r;   r<   r   r   r!   r6   ra   r   r   )rJ   r   rm   .  s
   rm   c                   s>   e Zd ZdZd fdd	Zdd	 Zd
d Z fddZ  ZS )Adadeltaa0  Adadelta optimizer.

  Adadelta is a more robust extension of Adagrad
  that adapts learning rates based on a moving window of gradient updates,
  instead of accumulating all past gradients. This way, Adadelta continues
  learning even when many updates have been done. Compared to Adagrad, in the
  original version of Adadelta you don't have to set an initial learning
  rate. In this version, initial learning rate and decay factor can
  be set, as in most other Keras optimizers.

  It is recommended to leave the parameters of this optimizer
  at their default values.

  Arguments:
    lr: float >= 0. Initial learning rate, defaults to 1.
        It is recommended to leave it at the default value.
    rho: float >= 0. Adadelta decay factor, corresponding to fraction of
        gradient to keep at each time step.
    epsilon: float >= 0. Fuzz factor.
      If `None`, defaults to `backend.epsilon()`.
    decay: float >= 0. Initial learning rate decay.

  References:
      - [Adadelta - an adaptive learning rate
      method](http://arxiv.org/abs/1212.5701)
        ?ffffff?N        c          	      s   t t| jf | t| jj8 tj|dd| _tj|dd| _	tjdddd| _
W d Q R X |d krnt }|| _|| _|| _d S )NrE   )rD   rG   r   rA   rB   )rC   rD   )rH   rp   r   r   rI   rJ   r9   rK   rE   rG   rB   rg   rf   rL   )r   rE   rf   rg   rG   r   )rJ   r   r   r     s    zAdadelta.__init__c             C   s<   dd |D }dd |D }dd |D }|| | _ ||fS )Nc             S   s   g | ]}t |qS r   )r   rN   )r"   r1   r   r   r   r%     s    z0Adadelta._create_all_weights.<locals>.<listcomp>c             S   s   g | ]}t |qS r   )r   rO   )r"   r.   r   r   r   r%     s    c             S   s   g | ]}t |qS r   )r   rO   )r"   r.   r   r   r   r%     s    )r   )r   r   rP   rh   delta_accumulatorsr   r   r   r     s
    
zAdadelta._create_all_weightsc          
   C   sL  |  ||}t| jdg| _| |\}}| j}| jdkrd|dd| jt	
| jt| j    }xt||||D ]\}}}	}
| j|	 d| j t	|  }| jt|	| |t|
| j  t|| j  }|||  }t|dd d k	r||}| jt|| | j|
 d| j t	|  }| jt|
| qtW | jS )NrR   r   g      ?rS   )r)   r   rT   rB   r   r   rE   rL   rG   r   rU   r   rC   r-   rf   ri   r/   rV   rj   rg   rW   rS   )r   r    r   r(   rh   rt   rE   r1   r#   rk   Zd_arl   r   rZ   Znew_d_ar   r   r   r!     s,    

zAdadelta.get_updatesc                sT   t t| j| jt t| j| jd}tt| 	 }t
t| t|  S )N)rE   rf   rG   rg   )r[   r   r\   rE   rf   rG   rg   rH   rp   r6   r]   r^   r_   )r   r5   r`   )rJ   r   r   r6     s    
zAdadelta.get_config)rq   rr   Nrs   )	r9   r:   r;   r<   r   r   r!   r6   ra   r   r   )rJ   r   rp   x  s
   "rp   c                   s>   e Zd ZdZd fdd		Zd
d Zdd Z fddZ  ZS )Adama  Adam optimizer.

  Default parameters follow those provided in the original paper.

  Args:
    lr: float >= 0. Learning rate.
    beta_1: float, 0 < beta < 1. Generally close to 1.
    beta_2: float, 0 < beta < 1. Generally close to 1.
    epsilon: float >= 0. Fuzz factor.
      If `None`, defaults to `backend.epsilon()`.
    decay: float >= 0. Learning rate decay over each update.
    amsgrad: boolean. Whether to apply the AMSGrad variant of this algorithm
      from the paper "On the Convergence of Adam and Beyond".
  MbP??+?N        Fc          	      s   t t| jf | t| jjX tjdddd| _tj|dd| _	tj|dd| _
tj|dd| _tj|d	d| _W d Q R X |d krt }|| _|| _|| _d S )
Nr   rA   rB   )rC   rD   rE   )rD   beta_1beta_2rG   )rH   ru   r   r   rI   rJ   r9   rK   rB   rE   rz   r{   rG   rg   rL   amsgrad)r   rE   rz   r{   rg   rG   r|   r   )rJ   r   r   r     s    zAdam.__init__c             C   s`   dd |D }dd |D }| j r2dd |D }ndd |D }| jg| | | | _|||fS )Nc             S   s&   g | ]}t jt |t |d qS ))rC   )r   rO   rN   rC   )r"   r1   r   r   r   r%     s   z,Adam._create_all_weights.<locals>.<listcomp>c             S   s&   g | ]}t jt |t |d qS ))rC   )r   rO   rN   rC   )r"   r1   r   r   r   r%     s   c             S   s&   g | ]}t jt |t |d qS ))rC   )r   rO   rN   rC   )r"   r1   r   r   r   r%     s   c             S   s   g | ]}t d qS )rR   )r   rO   )r"   _r   r   r   r%     s    )r|   rB   r   )r   r   msvsvhatsr   r   r   r     s    
zAdam._create_all_weightsc          
   C   s  |  ||}g | _| j}| jdkrJ|dd| jt| jt	| j    }t
t| jdg t| jt }W d Q R X |tdt| j| dt| j|   }| |\}}}	xt|||||	D ]\}
}}}}| j| d| j |  }| j| d| j t|  }| jr\t||}|
|| t|| j   }| jt|| n|
|| t|| j   }| jt|| | jt|| |}t|
dd d k	r|
|}| jt|
| qW | jS )Nr   g      ?rR   rS   )r)   r   rE   rL   rG   r   rU   rB   r   rC   r   control_dependenciesr   rT   floatxrj   powr{   rz   r   r-   ri   r|   maximumrg   r/   rV   rW   rS   )r   r    r   r(   rE   tlr_tr~   r   r   r1   r#   rX   rY   Zvhatm_tv_tZvhat_tp_trZ   r   r   r   r!   	  s<    
$
zAdam.get_updatesc                sp   t t| jt t| jt t| jt t| j| j| jd}t	t
|  }tt| t|  S )N)rE   rz   r{   rG   rg   r|   )r[   r   r\   rE   rz   r{   rG   rg   r|   rH   ru   r6   r]   r^   r_   )r   r5   r`   )rJ   r   r   r6   1  s    
zAdam.get_config)rv   rw   rx   Nry   F)	r9   r:   r;   r<   r   r   r!   r6   ra   r   r   )rJ   r   ru     s        (ru   c                   s>   e Zd ZdZd fdd	Zd	d
 Zdd Z fddZ  ZS )Adamaxa  Adamax optimizer from Adam paper's Section 7.

  It is a variant of Adam based on the infinity norm.
  Default parameters follow those provided in the paper.

  Args:
    lr: float >= 0. Learning rate.
    beta_1/beta_2: floats, 0 < beta < 1. Generally close to 1.
    epsilon: float >= 0. Fuzz factor.
      If `None`, defaults to `backend.epsilon()`.
    decay: float >= 0. Learning rate decay over each update.
  Mb`??+?N        c          	      s   t t| jf | t| jjX tjdddd| _tj|dd| _	tj|dd| _
tj|dd| _tj|d	d| _W d Q R X |d krt }|| _|| _d S )
Nr   rA   rB   )rC   rD   rE   )rD   rz   r{   rG   )rH   r   r   r   rI   rJ   r9   rK   rB   rE   rz   r{   rG   rg   rL   )r   rE   rz   r{   rg   rG   r   )rJ   r   r   r   L  s    zAdamax.__init__c             C   sD   dd |D }dd |D }dd |D }| j g| | | _||fS )Nc             S   s   g | ]}t |qS r   )r   rN   )r"   r1   r   r   r   r%   a  s    z.Adamax._create_all_weights.<locals>.<listcomp>c             S   s   g | ]}t |qS r   )r   rO   )r"   r.   r   r   r   r%   c  s    c             S   s   g | ]}t |qS r   )r   rO   )r"   r.   r   r   r   r%   e  s    )rB   r   )r   r   rP   r~   usr   r   r   r   _  s
    zAdamax._create_all_weightsc          
   C   sj  |  ||}g | _| j}| jdkrJ|dd| jt| jt	| j    }t
t| jdg t| jt }W d Q R X |dt| j|  }| |\}}xt||||D ]\}	}
}}| j| d| j |
  }t| j| t|
}|	|| || j   }| jt|| | jt|| |}t|	dd d k	rL|	|}| jt|	| qW | jS )Nr   g      ?rR   rS   )r)   r   rE   rL   rG   r   rU   rB   r   rC   r   r   r   rT   r   r   rz   r   r-   r   r{   absrg   r/   rV   rW   rS   )r   r    r   r(   rE   r   r   r~   r   r1   r#   rX   ur   Zu_tr   rZ   r   r   r   r!   i  s0    

zAdamax.get_updatesc                sl   t t| jt t| jt t| jt t| j| jd}tt	| 
 }tt| t|  S )N)rE   rz   r{   rG   rg   )r[   r   r\   rE   rz   r{   rG   rg   rH   r   r6   r]   r^   r_   )r   r5   r`   )rJ   r   r   r6     s    
zAdamax.get_config)r   r   r   Nr   )	r9   r:   r;   r<   r   r   r!   r6   ra   r   r   )rJ   r   r   >  s       
#r   c                   s>   e Zd ZdZd fdd	Zd	d
 Zdd Z fddZ  ZS )Nadama  Nesterov Adam optimizer.

  Much like Adam is essentially RMSprop with momentum,
  Nadam is Adam RMSprop with Nesterov momentum.

  Default parameters follow those provided in the paper.
  It is recommended to leave the parameters of this optimizer
  at their default values.

  Args:
    lr: float >= 0. Learning rate.
    beta_1/beta_2: floats, 0 < beta < 1. Generally close to 1.
    epsilon: float >= 0. Fuzz factor.
      If `None`, defaults to `backend.epsilon()`.
  Mb`??+?NMbp?c          	      s   t t| jf | t| jjX tjdddd| _tjddd| _	tj|dd| _
tj|d	d| _tj|d
d| _W d Q R X |d krt }|| _|| _d S )Nr   rA   rB   )rC   rD   g      ?
m_schedule)rD   rE   rz   r{   )rH   r   r   r   rI   rJ   r9   rK   rB   r   rE   rz   r{   rg   schedule_decay)r   rE   rz   r{   rg   r   r   )rJ   r   r   r     s    zNadam.__init__c             C   sH   dd |D }dd |D }dd |D }| j | jg| | | _||fS )Nc             S   s   g | ]}t |qS r   )r   rN   )r"   r1   r   r   r   r%     s    z-Nadam._create_all_weights.<locals>.<listcomp>c             S   s   g | ]}t |qS r   )r   rO   )r"   r.   r   r   r   r%     s    c             S   s   g | ]}t |qS r   )r   rO   )r"   r.   r   r   r   r%     s    )rB   r   r   )r   r   rP   r~   r   r   r   r   r     s
    zNadam._create_all_weightsc          	   C   s  |  ||}g | _tt| jdg t| jt	
 }W d Q R X | jddtt	d|| j    }| jddtt	d|d | j    }| j| }| j| | }| j| j|f | |\}	}
xt|||	|
D ]\}}}}|d|  }| j| d| j |  }|d|  }| j| d| j t|  }|dt| j|  }d| | ||  }| jt|| | jt|| || j| t	|| j   }|}t|dd d k	r||}| jt|| qW | jS )NrR   g      ?g      ?gQ?rS   )r)   r   r   r   r   rT   rB   r   rU   r   r   rz   r   Zcast_to_floatxr   r   r/   r   r-   r{   ri   rV   rE   rj   rg   rW   rS   )r   r    r   r(   r   Zmomentum_cache_tZmomentum_cache_t_1Zm_schedule_newZm_schedule_nextr~   r   r1   r#   rX   rY   Zg_primer   Z	m_t_primer   Z	v_t_primeZm_t_barr   rZ   r   r   r   r!     s>    

zNadam.get_updatesc                sb   t t| jt t| jt t| j| j| jd}tt	| 
 }tt| t|  S )N)rE   rz   r{   rg   r   )r[   r   r\   rE   rz   r{   rg   r   rH   r   r6   r]   r^   r_   )r   r5   r`   )rJ   r   r   r6     s    
zNadam.get_config)r   r   r   Nr   )	r9   r:   r;   r<   r   r   r!   r6   ra   r   r   )rJ   r   r     s       ,r   c               @   s`   e Zd ZdZdddZdd ZdddZd	d
 Zdd Zdd Z	e
dd Zdd Zdd ZdS )TFOptimizerz/Wrapper class for native TensorFlow optimizers.Nc          	   C   sd   || _ | j|dd |d krJt| jj tjdddd| _W d Q R X n|| _| j| jdd d S )N	optimizer)rD   r   rA   rB   )rC   rD   global_step)r   Z_track_trackabler   rI   rJ   r9   rK   rB   )r   r   rB   r   r   r   r     s    zTFOptimizer.__init__c             C   s   |S )zBClip gradients according to the clipnorm and clipvalue attributes.r   )r   r(   r   r   r   _clip_gradients	  s    zTFOptimizer._clip_gradientsc          	   C   s   t |s|dkrtd|dk	r$|nt }t |rj|, t |sL|| | }t |r`| }W dQ R X t|}|r||||}tt	||}| 
| dS )z&Mimics the `OptimizerV2.minimize` API.Nz2`tape` is required when a `Tensor` loss is passed.)callabler   r   ZGradientTapewatchr   flattenZgradientr^   r-   apply_gradients)r   r    Zvar_listZ	grad_lossZtaper(   grads_and_varsr   r   r   minimize  s    

zTFOptimizer.minimizec             C   s   | j j|| jd d S )N)r   )r   r   rB   )r   r   r   r   r   r   "  s    zTFOptimizer.apply_gradientsc             C   s   | j ||S )N)r   compute_gradients)r   r    r   r   r   r   	get_grads%  s    zTFOptimizer.get_gradsc             C   s   t  rFg | _|s | j|}n| j||}t }| j||}nB|sbt	| j
dg| _| jS g | _| j||}| jj|| j
d}| j| | jS )NrR   )r   )r   Zhas_strategyr   r   r   r
   Zget_global_stepr   r   rT   rB   r/   )r   r    r   r(   r   Z
opt_updater   r   r   r!   (  s     zTFOptimizer.get_updatesc             C   s   t d S )N)r   )r   r   r   r   r   D  s    zTFOptimizer.weightsc             C   s   t d S )N)r   )r   r   r   r   r6   H  s    zTFOptimizer.get_configc             C   s   t d S )N)r   )r   r5   r   r   r   r8   K  s    zTFOptimizer.from_config)N)NN)r9   r:   r;   r<   r   r   r   r   r   r!   propertyr   r6   r8   r   r   r   r   r     s   


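

# Usage sketch (illustrative only): wrapping a native TF1 optimizer so that
# standalone-Keras code can drive it. The wrapped object is assumed to be a
# `tf.compat.v1.train` optimizer exposing `compute_gradients` and
# `apply_gradients`, as used by `get_updates` above.
#
#   import tensorflow.compat.v1 as tf1
#   native = tf1.train.GradientDescentOptimizer(learning_rate=0.01)
#   opt = TFOptimizer(native)
#   model.compile(optimizer=opt, loss='mse')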
r   N)&r<   Ztensorflow.python.distributer   Ztensorflow.python.eagerr   Ztensorflow.python.frameworkr   Ztensorflow.python.kerasr   Ztensorflow.python.opsr   r   r   Ztensorflow.python.trackabler	   Z	trackableZtensorflow.python.trainingr
   Ztensorflow.python.utilr   objectr   r>   rb   rm   rp   ru   r   r   Z	Trackabler   ZsgdZrmspropZadagradZadadeltaZadamZadamaxZnadamr   r   r   r   <module>   s4   FGJ\jZdU