# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Training state management."""

import os
from tensorflow.python.checkpoint import checkpoint as trackable_util
from tensorflow.python.checkpoint import checkpoint_management
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.keras import backend
from tensorflow.python.keras.distribute import distributed_file_utils
from tensorflow.python.keras.utils import mode_keys
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import variables

# Sentinel held by the saved-epoch variable while no epoch has been backed up
# yet. Any negative value is treated as "nothing saved" when deciding whether
# training is resuming from a previous failure.
CKPT_SAVED_EPOCH_UNUSED_VALUE = -1

# Name of the `tf.keras.Model` attribute that records the epoch at which the
# most recent checkpoint was written.
CKPT_SAVED_EPOCH = '_ckpt_saved_epoch'


class WorkerTrainingState(object):
  """Training state management class.

  This class provides APIs for backing up and restoring the training state.
  This allows model and epoch information to be saved periodically and
  restored for fault-tolerance, also known as preemption-recovery purpose.
  """

  def __init__(self, model, checkpoint_dir):
    """Creates the checkpoint managers used to back up and restore state.

    Args:
      model: The model being trained. Its `distribute_strategy` decides whether
        this worker writes to the shared checkpoint directory or to a temporary
        per-worker location.
      checkpoint_dir: Root directory for backup checkpoints. The checkpoint
        that every worker restores from lives in its `chief` subdirectory.
    """
    self._model = model

    # The epoch at which the checkpoint is saved. Used for fault-tolerance.
    # GPU device only has int64 dtype registered VarHandleOp.
    self._ckpt_saved_epoch = variables.Variable(
        initial_value=constant_op.constant(
            CKPT_SAVED_EPOCH_UNUSED_VALUE, dtype=dtypes.int64),
        name='ckpt_saved_epoch')

    # Variable initialization.
    backend.set_value(self._ckpt_saved_epoch, CKPT_SAVED_EPOCH_UNUSED_VALUE)

    # _ckpt_saved_epoch gets tracked and is included in the checkpoint file
    # when backing up.
    checkpoint = trackable_util.Checkpoint(
        model=self._model, ckpt_saved_epoch=self._ckpt_saved_epoch)

    # If this is single-worker training, write_checkpoint_manager and
    # read_checkpoint_manager use the same checkpoint directory.
    #
    # If this is multi-worker training, and this worker should not
    # save checkpoint, we replace the write_checkpoint_manager's checkpoint_dir
    # with a temp filepath, so it writes to a file that will be removed at the
    # end of back_up() call. This is necessary because the SyncOnReadVariable
    # needs to be synced across all the workers in order to be read, and all
    # workers need to perform `save()`.
    # But all workers should restore from the same checkpoint_dir as passed in
    # read_checkpoint_manager.
    self.read_checkpoint_manager = checkpoint_management.CheckpointManager(
        checkpoint,
        directory=os.path.join(checkpoint_dir, 'chief'),
        max_to_keep=1)
    write_checkpoint_dir = distributed_file_utils.write_dirpath(
        checkpoint_dir, self._model.distribute_strategy)
    if self._model.distribute_strategy.extended.should_checkpoint:
      self.write_checkpoint_manager = self.read_checkpoint_manager
    else:
      self.write_checkpoint_manager = checkpoint_management.CheckpointManager(
          checkpoint, directory=write_checkpoint_dir, max_to_keep=1)

  def back_up(self, epoch):
    """Back up the current state of training into a checkpoint file.

    Args:
      epoch: The current epoch information to be saved.
    """
    backend.set_value(self._ckpt_saved_epoch, epoch)
    # Save the model plus CKPT_SAVED_EPOCH variable.
    if self.write_checkpoint_manager.save():
      # Non-checkpointing workers wrote to a temporary location that must not
      # outlive this call. Presumably `remove_temp_dirpath` leaves the chief
      # directory alone based on the strategy (TODO confirm).
      distributed_file_utils.remove_temp_dirpath(
          self.write_checkpoint_manager.directory,
          self._model.distribute_strategy)

  def restore(self):
    """Restore the training state from the backed up checkpoint file.

    Returns:
      True if a backed up checkpoint was found and the training state was
      restored from it. False if there was no checkpoint to restore from; in
      that case the training state is left as initialized. Errors raised while
      restoring propagate to the caller.
    """
    # `restore_or_initialize()` returns the restored checkpoint path, or None
    # when no checkpoint exists. Convert that to the documented boolean.
    restored_path = self.read_checkpoint_manager.restore_or_initialize()
    return restored_path is not None

  def delete_backup(self):
    """Delete the backup directories.

    Delete the backup directories which should not exist after `fit()`
    successfully finishes.
    """
    # Only the worker whose write manager is the shared (read) manager owns the
    # backup directory. Other workers' temporary directories are already
    # removed at the end of each `back_up()` call.
    if self.write_checkpoint_manager is self.read_checkpoint_manager:
      try:
        file_io.delete_recursively_v2(self.write_checkpoint_manager.directory)
      except errors.NotFoundError:
        # Nothing was ever backed up (or it is already gone); deletion is
        # best-effort.
        pass

  def maybe_load_initial_epoch_from_ckpt(self, initial_epoch, mode):
    """Maybe load initial epoch from ckpt considering possible worker recovery.

    When `_ckpt_saved_epoch` attribute exists and is not
    `CKPT_SAVED_EPOCH_UNUSED_VALUE`, this is under multi-worker training setting
    and indicates the worker is recovering from previous failure. In this case,
    infer `initial_epoch` from `self._ckpt_saved_epoch` to continue previous
    unfinished training from certain epoch.

    Args:
      initial_epoch: The original initial_epoch the user passes in `fit()`.
      mode: The mode for running `model.fit()`.

    Returns:
      If the training is recovering from previous failure under multi-worker
      training setting, return the epoch the training is supposed to continue
      at. Otherwise, return the `initial_epoch` the user passes in.
    """

    epoch = backend.eval(self._ckpt_saved_epoch)
    if mode == mode_keys.ModeKeys.TRAIN and epoch >= 0:
      # The most recently saved epoch is one epoch prior to the epoch it
      # failed at, so return the value of 'self._ckpt_saved_epoch' plus one.
      return epoch + 1
    return initial_epoch
