# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base test class for checkpointing datasets."""

import os

import numpy as np
from tensorflow.python.checkpoint import checkpoint as tracking_util
from tensorflow.python.checkpoint import checkpoint_management
from tensorflow.python.data.experimental.ops import iterator_ops as contrib_iterator_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import options as options_lib
from tensorflow.python.eager import context
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_tensor_value
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.util import nest


def remove_variants(get_next_op):
  """Removes variant tensors from a nested structure so `sess.run` can run."""
  # TODO(b/72408568): Remove this once session.run can get variant tensors.

  def _remove_variant(x):
    if isinstance(x, ops.Tensor) and x.dtype == dtypes.variant:
      return ()
    else:
      return x

  return nest.map_structure(_remove_variant, get_next_op)
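

# For illustration (a hedged sketch, not used by the tests below): if
# `get_next_op` is {"value": <int64 Tensor>, "handle": <variant Tensor>}, then
# remove_variants(get_next_op) returns {"value": <int64 Tensor>, "handle": ()},
# which `sess.run` can fetch without touching the variant tensor.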


def default_test_combinations():
  """Returns the default test combinations for testing checkpointing."""

  def disable_optimizations(ds_fn):
    options = options_lib.Options()
    options.experimental_optimization.apply_default_optimizations = False

    def ds_fn_no_opt():
      return ds_fn().with_options(options)

    return ds_fn_no_opt

  def verify_unused_iterator(obj, ds_fn, num_outputs, sparse_tensors=False):
    obj.verify_unused_iterator(
        ds_fn=disable_optimizations(ds_fn=ds_fn),
        num_outputs=num_outputs,
        sparse_tensors=sparse_tensors)

  verify_unused_iterator_combination = combinations.combine(
      verify_fn=combinations.NamedObject("verify_unused_iterator",
                                         verify_unused_iterator))

  def verify_fully_used_iterator(obj, ds_fn, num_outputs, sparse_tensors=False):
    obj.verify_fully_used_iterator(
        ds_fn=disable_optimizations(ds_fn=ds_fn),
        num_outputs=num_outputs,
        sparse_tensors=sparse_tensors)

  verify_fully_used_iterator_combination = combinations.combine(
      verify_fn=combinations.NamedObject("verify_fully_used_iterator",
                                         verify_fully_used_iterator))

  def verify_exhausted_iterator(obj, ds_fn, num_outputs, sparse_tensors=False):
    obj.verify_exhausted_iterator(
        ds_fn=disable_optimizations(ds_fn=ds_fn),
        num_outputs=num_outputs,
        sparse_tensors=sparse_tensors)

  verify_exhausted_iterator_combination = combinations.combine(
      verify_fn=combinations.NamedObject("verify_exhausted_iterator",
                                         verify_exhausted_iterator))

  def verify_multiple_breaks(obj, ds_fn, num_outputs, sparse_tensors=False):
    obj.verify_multiple_breaks(
        ds_fn=disable_optimizations(ds_fn=ds_fn),
        num_outputs=num_outputs,
        sparse_tensors=sparse_tensors)

  verify_multiple_breaks_combination = combinations.combine(
      verify_fn=combinations.NamedObject("verify_multiple_breaks",
                                         verify_multiple_breaks))

  def verify_reset_restored_iterator(obj,
                                     ds_fn,
                                     num_outputs,
                                     sparse_tensors=False):
    obj.verify_reset_restored_iterator(
        ds_fn=disable_optimizations(ds_fn=ds_fn),
        num_outputs=num_outputs,
        sparse_tensors=sparse_tensors)

  verify_reset_restored_iterator_combination = combinations.combine(
      verify_fn=combinations.NamedObject("verify_reset_restored_iterator",
                                         verify_reset_restored_iterator))

  return (verify_unused_iterator_combination +
          verify_fully_used_iterator_combination +
          verify_exhausted_iterator_combination +
          verify_multiple_breaks_combination +
          verify_reset_restored_iterator_combination)
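

# A usage sketch (hedged): concrete tests typically subclass
# CheckpointTestBase (defined below) and parameterize over these combinations,
# usually crossed with their graph/eager mode combinations via
# `combinations.times`. The test class, dataset, and the `test_base` module
# referenced below are hypothetical stand-ins, not part of this file:
#
#   class RangeCheckpointTest(CheckpointTestBase, parameterized.TestCase):
#
#     @combinations.generate(
#         combinations.times(test_base.default_test_combinations(),
#                            default_test_combinations()))
#     def test(self, verify_fn):
#       verify_fn(self, lambda: dataset_ops.Dataset.range(10), num_outputs=10)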


# TODO(b/72657739): Remove the `sparse_tensors` argument, which exists to test
# the (deprecated) saveable `SparseTensorSliceDataset`, once the
# `from_sparse_tensor_slices()` API and related tests are deleted.
class CheckpointTestBase(test.TestCase):
  """Base test class for checkpointing datasets."""

  def tearDown(self):
    self._delete_ckpt()
    super(CheckpointTestBase, self).tearDown()

  def verify_unused_iterator(self,
                             ds_fn,
                             num_outputs,
                             sparse_tensors=False,
                             verify_exhausted=True):
    """Verifies that saving and restoring an unused iterator works.

    Args:
      ds_fn: 0-argument function that returns a Dataset.
      num_outputs: Total number of outputs expected from this Dataset.
      sparse_tensors: Whether dataset is built from SparseTensor(s).
      verify_exhausted: Whether to verify that the iterator has been exhausted
        after producing `num_outputs` elements.

    Raises:
      AssertionError if any test fails.
    """
    self.verify_run_with_breaks(
        ds_fn, [0],
        num_outputs,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

  def verify_fully_used_iterator(self,
                                 ds_fn,
                                 num_outputs,
                                 sparse_tensors=False):
    """Verifies that saving and restoring a fully used iterator works.

    Note that this only checks saving and restoring an iterator from which
    `num_outputs` items have been produced, but it does not check for an
    exhausted iterator, i.e., one from which an OutOfRange error has been
    raised.

    Args:
      ds_fn: 0-argument function that returns a Dataset.
      num_outputs: Total number of outputs expected from this Dataset.
      sparse_tensors: Whether dataset is built from SparseTensor(s).

    Raises:
      AssertionError if test fails.
    """
    self.verify_run_with_breaks(
        ds_fn, [num_outputs], num_outputs, sparse_tensors=sparse_tensors)

  def verify_exhausted_iterator(self, ds_fn, num_outputs, sparse_tensors=False):
    """Verifies that saving and restoring an exhausted iterator works.

    An exhausted iterator is one from which an OutOfRange error has been
    raised.

    Args:
      ds_fn: 0-argument function that returns a Dataset.
      num_outputs: Total number of outputs expected from this Dataset.
      sparse_tensors: Whether dataset is built from SparseTensor(s).

    Raises:
      AssertionError if any test fails.
    """
    self.gen_outputs(
        ds_fn, [],
        num_outputs,
        verify_exhausted=True,
        sparse_tensors=sparse_tensors)
    actual = self.gen_outputs(
        ds_fn, [],
        0,
        ckpt_saved=True,
        verify_exhausted=True,
        sparse_tensors=sparse_tensors)
    self.assertLen(actual, 0)

  def verify_multiple_breaks(self,
                             ds_fn,
                             num_outputs,
                             num_breaks=10,
                             sparse_tensors=False,
                             verify_exhausted=True):
    """Attempts to save/restore at multiple break points.

    Args:
      ds_fn: 0-argument function that returns a Dataset.
      num_outputs: Total number of outputs expected from this Dataset.
      num_breaks: The number of break points. These are uniformly spaced in
        [0, num_outputs], both endpoints inclusive.
      sparse_tensors: Whether dataset is built from SparseTensor(s).
      verify_exhausted: Whether to verify that the iterator has been exhausted
        after producing `num_outputs` elements.

    Raises:
      AssertionError if any test fails.
    """
    self.verify_run_with_breaks(
        ds_fn,
        self.gen_break_points(num_outputs, num_breaks),
        num_outputs,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

  def verify_reset_restored_iterator(self,
                                     ds_fn,
                                     num_outputs,
                                     break_point=None,
                                     sparse_tensors=False,
                                     verify_exhausted=True):
    """Attempts to re-initialize a restored iterator.

    This is useful when restoring a training checkpoint during validation.

    Args:
      ds_fn: 0-argument function that returns a Dataset.
      num_outputs: Total number of outputs expected from this Dataset.
      break_point: Break point. Optional. Defaults to num_outputs/2.
      sparse_tensors: Whether dataset is built from SparseTensor(s).
      verify_exhausted: Whether to verify that the iterator has been exhausted
        after producing `num_outputs` elements.

    Raises:
      AssertionError if any test fails.
    """
    if context.executing_eagerly():
      self.skipTest("Eager-mode iteration does not support re-initialization.")

    break_point = num_outputs // 2 if not break_point else break_point

    # Collect ground truth containing all outputs.
    expected = self.gen_outputs(
        ds_fn, [],
        num_outputs,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

    # Skip some items and save checkpoint.
    self.gen_outputs(
        ds_fn, [],
        break_point,
        sparse_tensors=sparse_tensors,
        verify_exhausted=False)

    actual = []
    # Restore from checkpoint and then run init_op.
    with ops.Graph().as_default() as g:
      saver = self._import_meta_graph()
      init_op, get_next_op = self._get_iterator_ops_from_collection(
          ds_fn, sparse_tensors=sparse_tensors)
      get_next_op = remove_variants(get_next_op)
      with self.session(graph=g) as sess:
        self._initialize(init_op, sess)
        self._restore(saver, sess)
        self._initialize(init_op, sess)
        for _ in range(num_outputs):
          actual.append(sess.run(get_next_op))
        if verify_exhausted:
          with self.assertRaises(errors.OutOfRangeError):
            sess.run(get_next_op)
    self.match(expected, actual)

  def verify_error_on_save(self,
                           ds_fn,
                           num_outputs,
                           error,
                           break_point=None,
                           sparse_tensors=False):
    """Attempts to save a non-saveable iterator.

    Args:
      ds_fn: 0-argument function that returns a Dataset.
      num_outputs: Total number of outputs expected from this Dataset.
      error: The error expected to be raised when saving the iterator.
      break_point: Break point. Optional. Defaults to num_outputs/2.
      sparse_tensors: Whether dataset is built from SparseTensor(s).

    Raises:
      AssertionError if any test fails.
    """
    break_point = num_outputs // 2 if not break_point else break_point
    if context.executing_eagerly():
      iterator = iter(ds_fn())
      ckpt = tracking_util.Checkpoint(iterator=iterator)
      for _ in range(break_point):
        next(iterator)
      with self.assertRaises(error):
        ckpt.save(self._ckpt_path())
    else:
      with ops.Graph().as_default() as g:
        init_op, get_next_op, saver = self._build_graph(
            ds_fn, sparse_tensors=sparse_tensors)
        get_next_op = remove_variants(get_next_op)
        with self.session(graph=g) as sess:
          self._initialize(init_op, sess)
          for _ in range(break_point):
            sess.run(get_next_op)
          with self.assertRaises(error):
            self._save(sess, saver)

  def verify_run_with_breaks(self,
                             ds_fn,
                             break_points,
                             num_outputs,
                             sparse_tensors=False,
                             verify_exhausted=True):
    """Verifies that ds_fn() produces the same outputs with and without breaks.

    1. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it
       *without* stopping at break points.
    2. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it
       with stopping at break points.

    Then deep-matches the outputs from 1 and 2.

    Args:
      ds_fn: 0-argument function that returns a Dataset.
      break_points: A list of integers. For each `break_point` in
        `break_points`, outputs are produced until `break_point` items have
        been generated and the iterator state is then checkpointed. The
        current graph and session are destroyed, and a new graph and session
        are used to produce outputs until the next break point or until
        `num_outputs` elements have been produced. Each `break_point` must be
        <= `num_outputs`.
      num_outputs: Total number of outputs expected from this Dataset.
      sparse_tensors: Whether dataset is built from SparseTensor(s).
      verify_exhausted: Whether to verify that the iterator has been exhausted
        after producing `num_outputs` elements.

    Raises:
      AssertionError if any test fails.
    """
    expected = self.gen_outputs(
        ds_fn, [],
        num_outputs,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

    actual = self.gen_outputs(
        ds_fn,
        break_points,
        num_outputs,
        sparse_tensors=sparse_tensors,
        verify_exhausted=verify_exhausted)

    self.match(expected, actual)

  def gen_outputs(self,
                  ds_fn,
                  break_points,
                  num_outputs,
                  ckpt_saved=False,
                  sparse_tensors=False,
                  verify_exhausted=True,
                  save_checkpoint_at_end=True):
    """Generates elements from input dataset while stopping at break points.

    Produces `num_outputs` outputs and saves the state of the iterator in the
    Saver checkpoint.

    Args:
      ds_fn: 0-argument function that returns the dataset.
      break_points: A list of integers. For each `break_point` in
        `break_points`, outputs are produced until `break_point` items have
        been generated and the iterator state is then checkpointed. The
        current graph and session are destroyed, and a new graph and session
        are used to produce outputs until the next break point or until
        `num_outputs` elements have been produced. Each `break_point` must be
        <= `num_outputs`.
      num_outputs: The total number of outputs to produce from the iterator.
      ckpt_saved: Whether a checkpoint already exists.
      sparse_tensors: Whether dataset is built from SparseTensor(s).
      verify_exhausted: Whether to verify that the iterator has been exhausted
        after producing `num_outputs` elements.
      save_checkpoint_at_end: Whether to save a checkpoint after producing all
        outputs. If False, checkpoints are saved at each break point but not
        at the end. Note that checkpoints overwrite each other, so there is
        always only a single checkpoint available. Defaults to True.

    Returns:
      A list of `num_outputs` items.
    """
    outputs = []
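    # For illustration (a hedged sketch): gen_outputs(ds_fn, [5, 12], 20)
    # yields 20 elements in total, checkpointing (and tearing down the
    # graph/session or eager iterator) after the 5th and 12th elements, and
    # again at the end unless `save_checkpoint_at_end` is False.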

    if context.executing_eagerly():
      for i in range(len(break_points) + 1):
        iterator = iter(ds_fn())
        ckpt = tracking_util.Checkpoint(iterator=iterator)
        if ckpt_saved:
          ckpt_path = self._latest_ckpt()
          ckpt.restore(ckpt_path)
        start = break_points[i - 1] if i > 0 else 0
        end = break_points[i] if i < len(break_points) else num_outputs
        num_iters = end - start
        for _ in range(num_iters):
          outputs.append(self.evaluate(next(iterator)))
        if i == len(break_points) and verify_exhausted:
          with self.assertRaises(StopIteration):
            next(iterator)
        if save_checkpoint_at_end or i < len(break_points):
          ckpt_path = ckpt.save(self._ckpt_path())
          ckpt_saved = True
    else:
      def get_ops():
        if ckpt_saved:
          saver = self._import_meta_graph()
          init_op, get_next_op = self._get_iterator_ops_from_collection(
              ds_fn, sparse_tensors=sparse_tensors)
        else:
          init_op, get_next_op, saver = self._build_graph(
              ds_fn, sparse_tensors=sparse_tensors)
        return init_op, get_next_op, saver

      for i in range(len(break_points) + 1):
        with ops.Graph().as_default() as g:
          init_op, get_next_op, saver = get_ops()
          get_next_op = remove_variants(get_next_op)
          with self.session(graph=g) as sess:
            if ckpt_saved:
              self._initialize(init_op, sess)
              self._restore(saver, sess)
            else:
              self._initialize(init_op, sess)
            start = break_points[i - 1] if i > 0 else 0
            end = break_points[i] if i < len(break_points) else num_outputs
            num_iters = end - start
            for _ in range(num_iters):
              outputs.append(sess.run(get_next_op))
            if i == len(break_points) and verify_exhausted:
              with self.assertRaises(errors.OutOfRangeError):
                sess.run(get_next_op)
            if save_checkpoint_at_end or i < len(break_points):
              self._save(sess, saver)
              ckpt_saved = True

    return outputs

  def match(self, expected, actual):
    """Matches nested structures.

    Recursively matches the shapes and values of `expected` and `actual`.
    Handles scalars, numpy arrays, and other Python containers (e.g. list,
    dict), as well as SparseTensorValue and RaggedTensorValue.

    Args:
      expected: Nested structure 1.
      actual: Nested structure 2.

    Raises:
      AssertionError if matching fails.
    """
    if isinstance(expected, np.ndarray):
      expected = expected.tolist()
    if isinstance(actual, np.ndarray):
      actual = actual.tolist()
    self.assertEqual(type(expected), type(actual))

    if nest.is_nested(expected):
      self.assertEqual(len(expected), len(actual))
      if isinstance(expected, dict):
        for key1, key2 in zip(sorted(expected), sorted(actual)):
          self.assertEqual(key1, key2)
          self.match(expected[key1], actual[key2])
      else:
        for item1, item2 in zip(expected, actual):
          self.match(item1, item2)
    elif isinstance(expected, sparse_tensor.SparseTensorValue):
      self.match((expected.indices, expected.values, expected.dense_shape),
                 (actual.indices, actual.values, actual.dense_shape))
    elif isinstance(expected, ragged_tensor_value.RaggedTensorValue):
      self.match((expected.values, expected.row_splits),
                 (actual.values, actual.row_splits))
    else:
      self.assertEqual(expected, actual)

  def does_not_match(self, expected, actual):
    with self.assertRaises(AssertionError):
      self.match(expected, actual)

  def gen_break_points(self, num_outputs, num_samples=10):
    """Generates `num_samples` break points in [0, num_outputs]."""
    return np.linspace(0, num_outputs, num_samples, dtype=int)
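
  # For illustration (a hedged sketch): gen_break_points(20, 5) returns
  # array([ 0,  5, 10, 15, 20]); both endpoints are always included.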

  def _build_graph(self, ds_fn, sparse_tensors=False):
    dataset = ds_fn()
    iterator = dataset_ops.make_initializable_iterator(dataset)
    external_state_policy = dataset.options().experimental_external_state_policy
    saveable = contrib_iterator_ops.make_saveable_from_iterator(
        iterator, external_state_policy=external_state_policy)
    ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
    init_op = iterator.initializer
    if sparse_tensors:
      get_next = sparse_tensor.SparseTensor(*iterator.get_next())
    else:
      get_next = iterator.get_next()
    self._add_iterator_ops_to_collection(init_op, get_next, ds_fn,
                                         sparse_tensors)
    saver = saver_lib.Saver(allow_empty=True)
    return init_op, get_next, saver

  def _add_iterator_ops_to_collection(self,
                                      init_op,
                                      get_next,
                                      ds_fn,
                                      sparse_tensors=False):
    ops.add_to_collection("iterator_ops", init_op)
    # `get_next` may be a tuple, e.g. in TensorSliceDataset. Since collections
    # do not support tuples, we flatten the tensors here and restore the
    # structure in `_get_iterator_ops_from_collection`.
    if sparse_tensors:  # specific for deprecated `from_sparse_tensor_slices`.
      ops.add_to_collection("iterator_ops", get_next.indices)
      ops.add_to_collection("iterator_ops", get_next.values)
      ops.add_to_collection("iterator_ops", get_next.dense_shape)
      return

    get_next_list = nest.flatten(get_next)
    for i, output_class in enumerate(
        nest.flatten(self._get_output_classes(ds_fn))):
      if output_class is sparse_tensor.SparseTensor:
        ops.add_to_collection("iterator_ops", get_next_list[i].indices)
        ops.add_to_collection("iterator_ops", get_next_list[i].values)
        ops.add_to_collection("iterator_ops", get_next_list[i].dense_shape)
      else:
        ops.add_to_collection("iterator_ops", get_next_list[i])
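
  # For illustration (a hedged sketch): if `get_next` is a (dense, sparse)
  # tuple, the "iterator_ops" collection holds [init_op, dense, sparse.indices,
  # sparse.values, sparse.dense_shape]. `_get_iterator_ops_from_collection`
  # below walks the flattened output classes to rebuild the original structure
  # from this flat list.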

  def _get_iterator_ops_from_collection(self, ds_fn, sparse_tensors=False):
    all_ops = ops.get_collection("iterator_ops")
    if sparse_tensors:  # specific for deprecated `from_sparse_tensor_slices`.
      init_op, indices, values, dense_shape = all_ops
      return init_op, sparse_tensor.SparseTensor(indices, values, dense_shape)
    get_next_list = []
    i = 1
    for output_class in nest.flatten(self._get_output_classes(ds_fn)):
      if output_class is sparse_tensor.SparseTensor:
        indices, values, dense_shape = all_ops[i:i + 3]
        i += 3
        get_next_list.append(
            sparse_tensor.SparseTensor(indices, values, dense_shape))
      else:
        get_next_list.append(all_ops[i])
        i += 1
    return all_ops[0], nest.pack_sequence_as(
        self._get_output_types(ds_fn), get_next_list)

  def _get_output_types(self, ds_fn):
    assert not context.executing_eagerly()
    with ops.Graph().as_default():
      return dataset_ops.get_legacy_output_types(ds_fn())

  def _get_output_shapes(self, ds_fn):
    assert not context.executing_eagerly()
    with ops.Graph().as_default():
      return dataset_ops.get_legacy_output_shapes(ds_fn())

  def _get_output_classes(self, ds_fn):
    assert not context.executing_eagerly()
    with ops.Graph().as_default():
      return dataset_ops.get_legacy_output_classes(ds_fn())

  def _ckpt_path(self):
    return os.path.join(self.get_temp_dir(), "iterator")

  def _latest_ckpt(self):
    return checkpoint_management.latest_checkpoint(self.get_temp_dir())

  def _save(self, sess, saver):
    saver.save(sess, self._ckpt_path())

  def _restore(self, saver, sess):
    sess.run(lookup_ops.tables_initializer())
    saver.restore(sess, self._latest_ckpt())

  def _initialize(self, init_op, sess):
    sess.run(variables.global_variables_initializer())
    sess.run(lookup_ops.tables_initializer())
    sess.run(init_op)

  def _import_meta_graph(self):
    meta_file_path = self._ckpt_path() + ".meta"
    return saver_lib.import_meta_graph(meta_file_path)

  def _delete_ckpt(self):
    # Remove all checkpoint files.
    prefix = self._ckpt_path()
    pattern = prefix + "*"
    files = gfile.Glob(pattern)
    for f in files:
      gfile.Remove(f)
