# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python wrapper for prefetching_ops."""
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import options as options_lib
from tensorflow.python.data.util import structure
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.framework import type_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import resource_variable_ops


class _PerDeviceGenerator(dataset_ops.DatasetV2):
  """A `dummy` generator dataset that yields one shard of a multi-device iterator.

  The dataset is a `gen_dataset_ops.generator_dataset` whose init, next and
  finalize functions each wrap their real work in a `functional_ops.remote_call`
  targeting `source_device`. The multi-device iterator resource is therefore
  only touched on the source device, even though callers construct this
  dataset under a per-device `ops.device(...)` scope.
  """

  def __init__(self, shard_num, multi_device_iterator_resource, incarnation_id,
               source_device, element_spec, iterator_is_anonymous):
    """Builds the generator dataset for a single shard.

    Args:
      shard_num: Index of the shard whose elements this dataset yields; passed
        to `multi_device_iterator_get_next_from_shard`.
      multi_device_iterator_resource: Resource handle of the multi-device
        iterator that owns the shared input pipeline.
      incarnation_id: Incarnation ID of the multi-device iterator; passed to
        `multi_device_iterator_get_next_from_shard`.
      source_device: Target of every `remote_call` issued by this dataset
        (callers pass a tensor created by `ops.convert_to_tensor`).
      element_spec: (Nested) structure of `tf.TypeSpec` objects describing the
        elements produced by this dataset.
      iterator_is_anonymous: If True, `multi_device_iterator_resource` is
        appended to the captured arguments of the next function.
    """
    self._element_spec = element_spec

    # Convert the resource into a string handle. The init function returns it,
    # and the next/finalize functions (which take a scalar string
    # `string_handle`) are expected to receive it from the generator dataset.
    multi_device_iterator_string_handle = (
        gen_dataset_ops.multi_device_iterator_to_string_handle(
            multi_device_iterator_resource))

    # TODO(b/124254153): Enable autograph once the overhead is low enough.
    @function.defun(autograph=False)  # Pure graph code.
    def _init_func():
      return multi_device_iterator_string_handle

    init_func_concrete = _init_func.get_concrete_function()

    # Wrapper that executes `_init_func` on `source_device`.
    # TODO(b/124254153): Enable autograph once the overhead is low enough.
    @function.defun(autograph=False)  # Pure graph code.
    def _remote_init_func():
      return functional_ops.remote_call(
          target=source_device,
          args=init_func_concrete.captured_inputs,
          Tout=[dtypes.string],
          f=init_func_concrete)

    self._init_func = _remote_init_func.get_concrete_function()
    self._init_captured_args = self._init_func.captured_inputs

    # Next function: rebuilds the multi-device iterator from its string handle
    # and fetches the next element for `shard_num` / `incarnation_id`.
    # TODO(b/124254153): Enable autograph once the overhead is low enough.
    @function.defun(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
        autograph=False)  # Pure graph code.
    def _next_func(string_handle):
      # pylint: disable=protected-access
      multi_device_iterator = (
          gen_dataset_ops.multi_device_iterator_from_string_handle(
              string_handle=string_handle,
              output_types=structure.get_flat_tensor_types(self._element_spec),
              output_shapes=structure.get_flat_tensor_shapes(
                  self._element_spec)))
      return gen_dataset_ops.multi_device_iterator_get_next_from_shard(
          multi_device_iterator=multi_device_iterator,
          shard_num=shard_num,
          incarnation_id=incarnation_id,
          output_types=structure.get_flat_tensor_types(self._element_spec),
          output_shapes=structure.get_flat_tensor_shapes(self._element_spec))

    next_func_concrete = _next_func.get_concrete_function()

    # Wrapper that executes `_next_func` on `source_device`.
    # NOTE(review): the `experimental_ints_on_device` attribute presumably
    # keeps integer outputs in device memory rather than host memory — confirm.
    # TODO(b/124254153): Enable autograph once the overhead is low enough.
    @function.defun_with_attributes(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
        attributes={"experimental_ints_on_device": True},
        autograph=False)  # Pure graph code.
    def _remote_next_func(string_handle):
      return_values = functional_ops.remote_call(
          target=source_device,
          args=[string_handle] + next_func_concrete.captured_inputs,
          Tout=structure.get_flat_tensor_types(self._element_spec),
          f=next_func_concrete)
      # Add full type information to the graph so that the RemoteCall op
      # can determine for each of its outputs whether or not they are ragged
      # tensors (or other types that use variants) that contain strings
      # (or other host memory types). Then RemoteCall can
      # appropriately set AllocatorAttributes to control copies so
      # strings/host memory types stay on CPU.
      fulltype_list = type_utils.fulltypes_for_flat_tensors(self._element_spec)
      fulltype = type_utils.fulltype_list_to_product(fulltype_list)
      for return_value in return_values:
        return_value.op.experimental_set_type(fulltype)
      return return_values

    self._next_func = _remote_next_func.get_concrete_function()
    self._next_captured_args = self._next_func.captured_inputs

    # For anonymous iterators the resource handle is passed explicitly as an
    # extra captured argument of the next function.
    # NOTE(review): presumably so the dataset holds a reference to the
    # anonymous resource — confirm.
    if iterator_is_anonymous:
      self._next_captured_args = self._next_captured_args + [
          multi_device_iterator_resource
      ]

    # Record where `incarnation_id` sits among the next function's captured
    # arguments (identity match; last match wins, -1 if absent).
    # `_ReincarnatedPerDeviceGenerator` uses this index to swap in a new
    # incarnation ID while reusing the traced functions.
    self._incarnation_id_index = -1
    for i, arg in enumerate(self._next_captured_args):
      if arg is incarnation_id:
        self._incarnation_id_index = i

    # Finalize function: returns a constant int64 0; all work happens on
    # `source_device` via the remote wrapper below.
    # TODO(b/124254153): Enable autograph once the overhead is low enough.
    @function.defun(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
        autograph=False)  # Pure graph code.
    def _finalize_func(unused_string_handle):
      return array_ops.constant(0, dtypes.int64)

    finalize_func_concrete = _finalize_func.get_concrete_function()

    # Wrapper that executes `_finalize_func` on `source_device`.
    # TODO(b/124254153): Enable autograph once the overhead is low enough.
    @function.defun(
        input_signature=[tensor_spec.TensorSpec([], dtypes.string)],
        autograph=False)  # Pure graph code.
    def _remote_finalize_func(string_handle):
      return functional_ops.remote_call(
          target=source_device,
          args=[string_handle] + finalize_func_concrete.captured_inputs,
          Tout=[dtypes.int64],
          f=finalize_func_concrete)

    self._finalize_func = _remote_finalize_func.get_concrete_function()
    self._finalize_captured_args = self._finalize_func.captured_inputs

    # Assemble the generator dataset; `_flat_structure` supplies the flat
    # output type/shape attributes (presumably derived from `element_spec`).
    variant_tensor = gen_dataset_ops.generator_dataset(
        self._init_captured_args,
        self._next_captured_args,
        self._finalize_captured_args,
        init_func=self._init_func,
        next_func=self._next_func,
        finalize_func=self._finalize_func,
        **self._flat_structure)
    super(_PerDeviceGenerator, self).__init__(variant_tensor)

  def _inputs(self):
    """Returns no input datasets; see the TODO below."""
    # TODO(b/116506223): Determine which datasets should be used as inputs here.
    return []

  @property
  def element_spec(self):
    """The element spec supplied at construction."""
    return self._element_spec


class _ReincarnatedPerDeviceGenerator(dataset_ops.DatasetV2):
  """Creates a _PerDeviceGenerator-like dataset with a new incarnation_id.

  Re-uses the functions from the provided per_device_dataset and just switches
  out the function argument corresponding to the incarnation_id.
  """

  def __init__(self, per_device_dataset, incarnation_id):
    """Builds a generator dataset bound to `incarnation_id`.

    Args:
      per_device_dataset: The prototype `_PerDeviceGenerator` whose init, next
        and finalize functions (and their captured arguments) are reused. The
        prototype is not modified.
      incarnation_id: The incarnation ID to substitute into the next function's
        captured arguments.
    """
    # pylint: disable=protected-access
    self._element_spec = per_device_dataset.element_spec
    self._init_func = per_device_dataset._init_func
    self._init_captured_args = self._init_func.captured_inputs

    self._next_func = per_device_dataset._next_func
    # Copy the prototype's captured arguments before substituting the new
    # incarnation ID. Mutating the prototype's list in place would leak this
    # incarnation into the prototype (and every earlier reincarnation sharing
    # the same list), since the prototype is reused for each new incarnation.
    self._next_captured_args = list(per_device_dataset._next_captured_args)
    # NOTE(review): if the prototype did not find `incarnation_id` among its
    # captured args, the index is -1 and the last argument is replaced —
    # confirm that cannot happen.
    self._next_captured_args[
        per_device_dataset._incarnation_id_index] = incarnation_id

    self._finalize_func = per_device_dataset._finalize_func
    self._finalize_captured_args = per_device_dataset._finalize_captured_args

    variant_tensor = gen_dataset_ops.generator_dataset(
        self._init_captured_args,
        self._next_captured_args,
        self._finalize_captured_args,
        init_func=self._init_func,
        next_func=self._next_func,
        finalize_func=self._finalize_func,
        **self._flat_structure)
    super(_ReincarnatedPerDeviceGenerator, self).__init__(variant_tensor)

  def _inputs(self):
    """Returns no input datasets; see the TODO below."""
    # TODO(b/116506223): Determine which datasets should be used as inputs here.
    return []

  @property
  def element_spec(self):
    """The element spec of the prototype dataset."""
    return self._element_spec


def _create_device_dataset(prototype_ds, incarnation_id, prefetch_buffer_size,
                           experimental_slack):
  """Builds the per-device dataset for one incarnation of the iterator.

  Args:
    prototype_ds: The prototype `_PerDeviceGenerator` for this device.
    incarnation_id: Incarnation ID the new dataset should read with.
    prefetch_buffer_size: If positive, a prefetch stage of this size is added
      on top of the reincarnated generator.
    experimental_slack: If truthy (and prefetching is enabled), prefetch via
      `PrefetchDataset` with `slack_period=1` instead of `Dataset.prefetch`.

  Returns:
    The reincarnated generator, optionally wrapped in a prefetch stage.
  """
  reincarnated = _ReincarnatedPerDeviceGenerator(prototype_ds, incarnation_id)
  if prefetch_buffer_size > 0:
    if experimental_slack:
      return dataset_ops.PrefetchDataset(
          reincarnated, prefetch_buffer_size, slack_period=1)
    return reincarnated.prefetch(prefetch_buffer_size)
  return reincarnated


class MultiDeviceIterator:
  """An iterator over multiple devices.

  A single `multi_device_iterator` resource on `source_device` consumes the
  input dataset; for each device a per-device generator dataset (optionally
  prefetched) reads its shard through that resource. Works in both eager and
  graph mode: in graph mode, run `initializer` before calling `get_next`.
  """

  def __init__(self,
               dataset,
               devices,
               max_buffer_size=1,
               prefetch_buffer_size=1,
               source_device="/cpu:0"):
    """Constructs a MultiDeviceIterator.

    Args:
      dataset: The input dataset to be iterated over.
      devices: The list of devices to fetch data to.
      max_buffer_size: Maximum size of the host side per device buffer to keep.
        In order to prevent deadlocks, if the prefetch_buffer_size is greater
        than the max_buffer_size, we set the max_buffer_size to
        prefetch_buffer_size.
      prefetch_buffer_size: if > 0, then we setup a buffer on each device to
        prefetch into. If 0, the `inject_prefetch` optimization is disabled.
      source_device: The host device to place the `dataset` on.
    """
    options = options_lib.Options()
    options.experimental_distribute.num_devices = len(devices)
    # If `prefetch_buffer_size` is 0, we turn off the `inject_prefetch`
    # optimization to prevent potentially introducing asynchrony.
    if prefetch_buffer_size == 0:
      options.experimental_optimization.inject_prefetch = False
    dataset = dataset.with_options(options)
    self._dataset = dataset._apply_debug_options()  # pylint: disable=protected-access
    self._experimental_slack = dataset.options().experimental_slack
    self._devices = devices
    self._source_device = source_device
    self._source_device_tensor = ops.convert_to_tensor(source_device)
    self._max_buffer_size = max_buffer_size
    self._prefetch_buffer_size = prefetch_buffer_size

    # Never let the host-side buffer be smaller than the device prefetch buffer
    # (see the deadlock note on `max_buffer_size`).
    if self._prefetch_buffer_size > self._max_buffer_size:
      self._max_buffer_size = self._prefetch_buffer_size

    # Create the MultiDeviceIterator.
    with ops.device(self._source_device):
      # TODO(b/121378567): Get rid of this shared_name hack.
      # In eager mode each instance gets a unique anonymous shared name.
      shared_name = ""
      if context.executing_eagerly():
        shared_name = context.anonymous_name()
      self._multi_device_iterator_resource = (
          gen_dataset_ops.multi_device_iterator(
              devices=self._devices,
              shared_name=shared_name,
              container="",
              **self._dataset._flat_structure))  # pylint: disable=protected-access
      if context.executing_eagerly():
        # Delete the resource when this object is deleted
        self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
            handle=self._multi_device_iterator_resource,
            handle_device=self._source_device)

      # The incarnation ID is used to ensure consistency between the per-device
      # iterators and the multi-device iterator.
      self._incarnation_id = gen_dataset_ops.multi_device_iterator_init(
          self._dataset._variant_tensor,  # pylint: disable=protected-access
          self._multi_device_iterator_resource,
          max_buffer_size=self._max_buffer_size)

    # One prototype generator per device; kept so `_eager_reset` can rebuild
    # the per-device datasets with a new incarnation ID.
    self._prototype_device_datasets = []
    for i, device in enumerate(self._devices):
      with ops.device(device):
        ds = _PerDeviceGenerator(
            i,
            self._multi_device_iterator_resource,
            self._incarnation_id,
            self._source_device_tensor,
            self._dataset.element_spec,
            iterator_is_anonymous=False)
        self._prototype_device_datasets.append(ds)

    # TODO(rohanj): Explore the possibility of the MultiDeviceIterator to
    # initialize the device side of the pipeline. This would allow the
    # MultiDeviceIterator to choose, for example, to move some transformations
    # into the device side from its input. It might be useful in rewriting.
    # Create the per device iterators.
    self._device_iterators = []
    for i, device in enumerate(self._devices):
      with ops.device(device):
        ds = _create_device_dataset(self._prototype_device_datasets[i],
                                    self._incarnation_id,
                                    self._prefetch_buffer_size,
                                    self._experimental_slack)
        if context.executing_eagerly():
          self._device_iterators.append(dataset_ops.make_one_shot_iterator(ds))
        else:
          self._device_iterators.append(
              dataset_ops.make_initializable_iterator(ds))

    # In graph mode, expose a single op that initializes every device iterator.
    if not context.executing_eagerly():
      device_iterator_initializers = [
          iterator.initializer for iterator in self._device_iterators
      ]
      self._initializer = control_flow_ops.group(*device_iterator_initializers)

  def get_next(self, device=None):
    """Returns the next element given a `device`, else returns all in a list.

    Args:
      device: Optional device from `devices`. `self._devices.index(device)`
        raises `ValueError` if it is not present (for list/tuple `devices`).

    Returns:
      The next element for `device`, or a list with the next element for every
      device in `devices` order when `device` is None.
    """
    if device is not None:
      index = self._devices.index(device)
      return self._device_iterators[index].get_next()

    result = []
    for i, device in enumerate(self._devices):
      with ops.device(device):
        result.append(self._device_iterators[i].get_next())
    return result

  def get_next_as_optional(self):
    """Returns a list of each device iterator's `get_next_as_optional()` result.

    The list is in `devices` order.
    """
    result = []
    for i, device in enumerate(self._devices):
      with ops.device(device):
        result.append(self._device_iterators[i].get_next_as_optional())
    return result

  @property
  def initializer(self):
    """Op that initializes all device iterators (a no-op in eager mode)."""
    if context.executing_eagerly():
      return control_flow_ops.no_op()
    return self._initializer

  def _eager_reset(self):
    """Resets the MultiDeviceIterator in eager mode.

    Re-initializes the multi-device iterator resource (obtaining a new
    incarnation ID) and re-points each existing device iterator resource at a
    freshly built per-device dataset for that incarnation.

    Raises:
      ValueError: If not executing eagerly outside of functions.
    """
    if not ops.executing_eagerly_outside_functions():
      raise ValueError(
          "Resetting a multi-device iterator is only supported in the eager "
          "mode.")
    # pylint: disable=protected-access
    self._incarnation_id = gen_dataset_ops.multi_device_iterator_init(
        self._dataset._variant_tensor,
        self._multi_device_iterator_resource,
        max_buffer_size=self._max_buffer_size)
    for i, device in enumerate(self._devices):
      with ops.device(device):
        ds = _create_device_dataset(self._prototype_device_datasets[i],
                                    self._incarnation_id,
                                    self._prefetch_buffer_size,
                                    self._experimental_slack)
        # Reset the device iterator resources with the new dataset.
        ds_variant = ds._variant_tensor
        gen_dataset_ops.make_iterator(
            ds_variant, self._device_iterators[i]._iterator_resource)

  @property
  def element_spec(self):
    """The element spec of the (option-applied) input dataset."""
    return self._dataset.element_spec


class MultiDeviceIteratorSpec(type_spec.TypeSpec):
  """Type specification for `OwnedMultiDeviceIterator`.

  Records the target devices, the source device and the element spec. The
  tensor components are the multi-device iterator resource followed by one
  per-device iterator, in `devices` order.
  """

  __slots__ = ["_devices", "_source_device", "_element_spec"]

  def __init__(self, devices, source_device, element_spec):
    self._devices = devices
    self._source_device = source_device
    self._element_spec = element_spec

  @property
  def value_type(self):
    return OwnedMultiDeviceIterator

  def _serialize(self):
    # Devices are stored as a tuple so the serialization is hashable.
    device_tuple = tuple(self._devices)
    return (device_tuple, self._source_device, self._element_spec)

  @property
  def _component_specs(self):
    resource_spec = tensor_spec.TensorSpec([], dtypes.resource)
    iterator_specs = [
        iterator_ops.IteratorSpec(self._element_spec)
        for _ in range(len(self._devices))
    ]
    return [resource_spec] + iterator_specs

  def _to_components(self, value):
    # pylint: disable=protected-access
    return [value._multi_device_iterator_resource, *value._device_iterators]

  def _from_components(self, components):
    return OwnedMultiDeviceIterator(
        dataset=None,
        devices=self._devices,
        source_device=self._source_device,
        element_spec=self._element_spec,
        components=components)

  @staticmethod
  def from_value(value):
    # pylint: disable=protected-access
    devices = value._devices
    source_device = value._source_device
    return MultiDeviceIteratorSpec(devices, source_device, value.element_spec)


class OwnedMultiDeviceIterator(composite_tensor.CompositeTensor):
  """An iterator over multiple devices.

  The multi-device iterator resource created through `OwnedMultiDeviceIterator`
  is owned by the Python object and the life time of the underlying resource is
  tied to the life time of the `OwnedMultiDeviceIterator` object. This makes
  `OwnedMultiDeviceIterator` appropriate for use in eager mode and inside of
  tf.functions.
  """

  def __init__(self,
               dataset=None,
               devices=None,
               max_buffer_size=1,
               prefetch_buffer_size=1,
               source_device="/cpu:0",
               components=None,
               element_spec=None):
    """Constructs an owned MultiDeviceIterator object.

    Args:
      dataset: The input dataset to be iterated over.
      devices: (Required.) The list of devices to fetch data to.
      max_buffer_size: Maximum size of the host side per device buffer to keep.
        In order to prevent deadlocks, if the prefetch_buffer_size is greater
        than the max_buffer_size, we set the max_buffer_size to
        prefetch_buffer_size.
      prefetch_buffer_size: if > 0, then we setup a buffer on each device to
        prefetch into. If 0, the `inject_prefetch` optimization is disabled.
      source_device: The host device to place the `dataset` on.
      components: Tensor components to construct the MultiDeviceIterator from:
        the multi-device iterator resource followed by one iterator per device.
      element_spec: A (nested) structure of `tf.TypeSpec` objects that
        represents the type specification of elements of the iterator.

    Raises:
      RuntimeError: If executed in graph mode or outside of function building
        mode.
      ValueError: If any of the following happens:
        - `devices` is `None`
        - `dataset` is `None` and either `components` or `element_spec` is
          `None`
        - `dataset` is `None` and `components` does not hold exactly one
          element more than `devices`
        - `dataset` is not None and either `components` or `element_spec` is
          provided
    """
    if not context.executing_eagerly() and not ops.inside_function():
      raise RuntimeError("OwnedMultiDeviceIterator is only supported inside of "
                         "tf.function or when eager execution is enabled.")
    if devices is None:
      raise ValueError("`devices` must be provided.")

    if dataset is None:
      if (components is None or element_spec is None):
        raise ValueError(
            "When `dataset` is not provided, both `components` and "
            "`element_spec` must be specified.")
      # `components` mirrors `MultiDeviceIteratorSpec._component_specs`: the
      # resource followed by one iterator per device. Reject mismatches early
      # instead of failing later with an IndexError (or ignoring devices).
      if len(components) != len(devices) + 1:
        raise ValueError(
            "Expected {} components (the multi-device iterator resource plus "
            "one iterator per device), but got {}.".format(
                len(devices) + 1, len(components)))
      self._element_spec = element_spec
      self._devices = devices
      self._source_device = source_device
      self._multi_device_iterator_resource = components[0]
      self._device_iterators = components[1:]
    else:
      if (components is not None or element_spec is not None):
        raise ValueError(
            "When `dataset` is provided, `element_spec` and `components` must "
            "not be specified.")
      options = options_lib.Options()
      options.experimental_distribute.num_devices = len(devices)
      # If `prefetch_buffer_size` is 0, we turn off the `inject_prefetch`
      # optimization to prevent potentially introducing asynchrony.
      if prefetch_buffer_size == 0:
        options.experimental_optimization.inject_prefetch = False
      dataset = dataset.with_options(options)
      dataset = dataset._apply_debug_options()  # pylint: disable=protected-access
      self._element_spec = dataset.element_spec
      experimental_slack = dataset.options().experimental_slack
      self._devices = devices
      self._source_device = source_device
      source_device_tensor = ops.convert_to_tensor(self._source_device)

      # Never let the host-side buffer be smaller than the device prefetch
      # buffer (see the deadlock note on `max_buffer_size`).
      if prefetch_buffer_size > max_buffer_size:
        max_buffer_size = prefetch_buffer_size

      # Create the MultiDeviceIterator.
      with ops.device(self._source_device):
        self._multi_device_iterator_resource = (
            gen_dataset_ops.anonymous_multi_device_iterator_v3(
                devices=self._devices, **dataset._flat_structure))  # pylint: disable=protected-access

        # The incarnation ID is used to ensure consistency between the
        # per-device iterators and the multi-device iterator.
        incarnation_id = gen_dataset_ops.multi_device_iterator_init(
            dataset._variant_tensor,  # pylint: disable=protected-access
            self._multi_device_iterator_resource,
            max_buffer_size=max_buffer_size)

      prototype_device_datasets = []
      for i, device in enumerate(self._devices):
        with ops.device(device):
          ds = _PerDeviceGenerator(
              i,
              self._multi_device_iterator_resource,
              incarnation_id,
              source_device_tensor,
              dataset.element_spec,
              iterator_is_anonymous=True,
          )
          prototype_device_datasets.append(ds)

      # TODO(rohanj): Explore the possibility of the MultiDeviceIterator to
      # initialize the device side of the pipeline. This would allow the
      # MultiDeviceIterator to choose, for example, to move some transformations
      # into the device side from its input. It might be useful in rewriting.
      # Create the per device iterators.
      self._device_iterators = []

      for i, device in enumerate(self._devices):
        with ops.device(device):
          ds = _create_device_dataset(prototype_device_datasets[i],
                                      incarnation_id, prefetch_buffer_size,
                                      experimental_slack)
          iterator = iter(ds)
          self._device_iterators.append(iterator)

  def get_next(self, device=None):
    """Returns the next element given a `device`, else returns all in a list.

    Args:
      device: Optional device from `devices`. `self._devices.index(device)`
        raises `ValueError` if it is not present (for list/tuple `devices`).

    Returns:
      The next element for `device`, or a list with the next element for every
      device in `devices` order when `device` is None.
    """
    if device is not None:
      index = self._devices.index(device)
      return self._device_iterators[index].get_next()

    result = []
    for i, device in enumerate(self._devices):
      with ops.device(device):
        result.append(self._device_iterators[i].get_next())
    return result

  def __iter__(self):
    return self

  def next(self):
    """Alias of `__next__`."""
    return self.__next__()

  def __next__(self):
    """Returns the next elements for all devices (see `get_next`).

    Raises:
      StopIteration: When any device iterator raises `OutOfRangeError`.
    """
    try:
      return self.get_next()
    except errors.OutOfRangeError:
      # Translate end-of-sequence into the Python iterator protocol; the
      # internal TensorFlow error is not useful context for callers.
      raise StopIteration from None

  def get_next_as_optional(self):
    """Returns a list of each device iterator's `get_next_as_optional()` result.

    The list is in `devices` order.
    """
    result = []
    for i, device in enumerate(self._devices):
      with ops.device(device):
        result.append(self._device_iterators[i].get_next_as_optional())
    return result

  @property
  def element_spec(self):
    """The element spec of the iterated dataset."""
    return self._element_spec

  @property
  def _type_spec(self):
    return MultiDeviceIteratorSpec(self._devices, self._source_device,
                                   self._element_spec)
