# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions used by values.py and ps_values.py."""

from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import distribution_strategy_context as ds_context
from tensorflow.python.distribute import reduce_util
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.saved_model import save_context
from tensorflow.python.saved_model import save_options
from tensorflow.python.training.saving import saveable_object


def write_object_proto(var, proto, options):
  """Update a SavedObject proto for the caller.

  If a DistributedVariable object supports this method, it will be called when
  saving with a pre-built `SavedObject` proto representing the object, plus an
  instance of `SaveOptions`. This method is then free to modify that proto
  instance.

  A `DistributedVariable` with `AUTO` or `ON_WRITE` synchronization optionally
  writes out information about its components to the
  `experimental_distributed_variable_components` field of a
  `SavedVariable` (depending on the `SaveOptions` variable policy).

  Args:
    var: The DistributedVariable object.
    proto: A pre-built `SavedObject` proto for this object. It is assumed this
      will be a `SavedVariable` instance.
    options: A `SaveOptions` instance.
  """
  policy = options.experimental_variable_policy
  if not policy._expand_distributed_variables():  # pylint: disable=protected-access
    return
  # A distinct loop name is used so the `var` argument is not rebound while
  # iterating over its own components.
  for component in var.values:
    component_proto = (
        proto.variable.experimental_distributed_variable_components.add())
    # Only the part of the name before the first ":" is recorded.
    component_proto.name = component.name.split(":")[0]
    component_proto.device = component.device


def get_on_write_saveable(var, primary_var, name):
  """Builds the save spec for an AUTO or ON_WRITE variable.

  Args:
    var: The variable whose value is read through its distribution strategy.
    primary_var: The variable whose `is_initialized()` state and `device` are
      used for the spec.
    name: The name given to the `SaveSpec`.

  Returns:
    A `(tensor, specs)` tuple: `tensor` is a zero-argument callable producing
    the value to save (or `None`), and `specs` is a one-element list holding
    the `SaveSpec` built around that callable.
  """

  def tensor():
    # Wrapped in a callable so the read is not evaluated when we are restoring
    # rather than saving.
    uninitialized = (
        context.executing_eagerly() and not primary_var.is_initialized())
    if uninitialized:
      # A `None` SaveSpec tensor value signals an uninitialized variable.
      return None
    return var.distribute_strategy.extended.read_var(var)

  save_spec = saveable_object.SaveSpec(
      tensor=tensor,
      slice_spec="",
      name=name,
      dtype=var.dtype,
      device=primary_var.device)
  return tensor, [save_spec]


def get_on_write_restore_ops(var, tensor):
  """Groups the assignments that restore an AUTO or ON_WRITE variable.

  Args:
    var: The distributed variable to restore.
    tensor: The value assigned on every device.

  Returns:
    The result of `control_flow_ops.group` over one assignment per device of
    the packed variable when there is one, otherwise one assignment per
    component in `var.values`.
  """
  packed = var._packed_variable  # pylint: disable=protected-access
  if packed is None:
    assignments = tuple(
        assign_on_device(component.device, component, tensor)
        for component in var.values)
  else:
    assignments = tuple(
        assign_on_device(device, packed, tensor)
        for device in packed.devices)
  return control_flow_ops.group(assignments)


def get_on_read_saveable(var, primary_var, name):
  """Builds the save spec for an ON_READ variable.

  Args:
    var: The variable whose cross-replica value is saved.
    primary_var: The variable whose `device` is used for the spec.
    name: The name given to the `SaveSpec`.

  Returns:
    A `(tensor, specs)` tuple: `tensor` is a zero-argument callable returning
    `var._get_cross_replica()`, and `specs` is a one-element list holding the
    `SaveSpec` built around that callable.
  """

  def tensor():
    # Kept lazy so nothing is evaluated when restoring instead of saving.
    return var._get_cross_replica()  # pylint: disable=protected-access

  return tensor, [
      saveable_object.SaveSpec(
          tensor=tensor,
          slice_spec="",
          name=name,
          dtype=var.dtype,
          device=primary_var.device)
  ]


def get_on_read_restore_ops(var, tensor, aggregation):
  """Groups the assignments that restore an ON_READ variable.

  Args:
    var: The distributed variable to restore.
    tensor: The saved value.
    aggregation: The variable's `VariableAggregation`.

  Returns:
    The result of `control_flow_ops.group` over one assignment per component
    in `var.values`.
  """
  restored = tensor
  if aggregation == vs.VariableAggregation.SUM:
    # The saved value was summed over replicas, so each component gets an
    # equal share; this preserves the sum across save and restore.
    num_replicas = var.distribute_strategy.num_replicas_in_sync
    restored = math_ops.cast(tensor / num_replicas, var.dtype)
  assignments = tuple(
      assign_on_device(component.device, component, restored)
      for component in var.values)
  return control_flow_ops.group(assignments)


def in_replica_update_context():
  """Indicates if you are in an UpdateContext when running in a replica fn.

  Returns:
    `True` when `distribute_lib.get_update_replica_id()` is not `None`.
  """
  update_replica_id = distribute_lib.get_update_replica_id()
  return update_replica_id is not None


def on_write_assign(var, value, use_locking=False, name=None, read_value=True):
  """Runs `assign` on `var` through its `_update` hook.

  Args:
    var: The variable to update; it must provide `_update`.
    value: The value forwarded to `assign`.
    use_locking: Forwarded to `assign`.
    name: Forwarded to `assign`.
    read_value: Forwarded to `assign`.

  Returns:
    Whatever `var._update` returns.
  """

  def _assign(component, *args, **kwargs):
    return component.assign(*args, **kwargs)

  return var._update(  # pylint: disable=protected-access
      update_fn=_assign,
      value=value,
      use_locking=use_locking,
      name=name,
      read_value=read_value)


def on_write_assign_add(var, value, use_locking=False, name=None,
                        read_value=True):
  """Runs `assign_add` on `var` through its `_update` hook.

  Args:
    var: The variable to update; it must provide `_update`.
    value: The value forwarded to `assign_add`.
    use_locking: Forwarded to `assign_add`.
    name: Forwarded to `assign_add`.
    read_value: Forwarded to `assign_add`.

  Returns:
    Whatever `var._update` returns.
  """

  def _assign_add(component, *args, **kwargs):
    return component.assign_add(*args, **kwargs)

  return var._update(  # pylint: disable=protected-access
      update_fn=_assign_add,
      value=value,
      use_locking=use_locking,
      name=name,
      read_value=read_value)


def on_write_assign_sub(var, value, use_locking=False, name=None,
                        read_value=True):
  """Runs `assign_sub` on `var` through its `_update` hook.

  Args:
    var: The variable to update; it must provide `_update`.
    value: The value forwarded to `assign_sub`.
    use_locking: Forwarded to `assign_sub`.
    name: Forwarded to `assign_sub`.
    read_value: Forwarded to `assign_sub`.

  Returns:
    Whatever `var._update` returns.
  """

  def _assign_sub(component, *args, **kwargs):
    return component.assign_sub(*args, **kwargs)

  return var._update(  # pylint: disable=protected-access
      update_fn=_assign_sub,
      value=value,
      use_locking=use_locking,
      name=name,
      read_value=read_value)


def assign_on_each_device(var, assign_func, value, read_value):
  """Update the variable on each replica with the given assign_func and value.

  Args:
    var: The distributed variable to update.
    assign_func: Callable invoked as `assign_func(device, variable, value)`.
    value: The value handed to every `assign_func` call.
    read_value: When falsy, the grouped update is returned as is.

  Returns:
    The grouped update when `read_value` is falsy; otherwise
    `var.read_value()`, read under a control dependency on the update.
  """
  packed = var._packed_variable  # pylint: disable=protected-access
  if packed is None:
    per_device_updates = tuple(
        assign_func(component.device, component, value)
        for component in var._values)  # pylint: disable=protected-access
  else:
    per_device_updates = tuple(
        assign_func(device, packed, value)
        for device in var._devices)  # pylint: disable=protected-access
  update = control_flow_ops.group(per_device_updates)
  if not read_value:
    return update
  # A falsy grouped update contributes no control dependency.
  dependencies = [update] if update else []
  with ops.control_dependencies(dependencies):
    return var.read_value()


def on_read_assign_sub_cross_replica(var, value, read_value=True):
  """Applies `assign_sub` on every device when in cross-replica context.

  Args:
    var: The ON_READ variable to update.
    value: The value to subtract on each device.
    read_value: Forwarded to `assign_on_each_device`.

  Returns:
    The result of `assign_on_each_device`, or `None` when not in
    cross-replica context.

  Raises:
    ValueError: If `var.aggregation` is `VariableAggregation.SUM`.
  """
  with ds_context.enter_or_assert_strategy(var.distribute_strategy):
    if not ds_context.in_cross_replica_context():
      return None
    if var.aggregation == vs.VariableAggregation.SUM:
      raise ValueError(
          "SyncOnReadVariable does not support `assign_sub` in "
          "cross-replica context when aggregation is set to "
          "`tf.VariableAggregation.SUM`.")
    return assign_on_each_device(
        var, assign_sub_on_device, value, read_value)


def on_read_assign_add_cross_replica(var, value, read_value=True):
  """Applies `assign_add` on every device when in cross-replica context.

  Args:
    var: The ON_READ variable to update.
    value: The value to add on each device.
    read_value: Forwarded to `assign_on_each_device`.

  Returns:
    The result of `assign_on_each_device`, or `None` when not in
    cross-replica context.

  Raises:
    ValueError: If `var.aggregation` is `VariableAggregation.SUM`.
  """
  with ds_context.enter_or_assert_strategy(var.distribute_strategy):
    if not ds_context.in_cross_replica_context():
      return None
    if var.aggregation == vs.VariableAggregation.SUM:
      raise ValueError(
          "SyncOnReadVariable does not support `assign_add` in "
          "cross-replica context when aggregation is set to "
          "`tf.VariableAggregation.SUM`.")
    return assign_on_each_device(
        var, assign_add_on_device, value, read_value)


def on_read_assign_cross_replica(var, value, read_value=True):
  """Assigns `value` on every device when in cross-replica context.

  Args:
    var: The ON_READ variable to update.
    value: The value to assign.
    read_value: Forwarded to `assign_on_each_device`.

  Returns:
    The result of `assign_on_each_device`, or `None` when not in
    cross-replica context.
  """
  with ds_context.enter_or_assert_strategy(var.distribute_strategy):
    if not ds_context.in_cross_replica_context():
      return None
    per_replica_value = value
    if var.aggregation == vs.VariableAggregation.SUM:
      # With SUM aggregation the total is divided across all devices, so each
      # one is assigned an equal share of `value`.
      strategy = var._distribute_strategy  # pylint: disable=protected-access
      per_replica_value = math_ops.cast(
          value / strategy.num_replicas_in_sync, var.dtype)
    return assign_on_each_device(
        var, assign_on_device, per_replica_value, read_value)


def scatter_sub(var, sparse_delta, use_locking=False, name=None):
  """Runs `scatter_sub` on `var` through its `_update` hook.

  Args:
    var: The variable to update; it must provide `_update`.
    sparse_delta: The value forwarded to `scatter_sub`.
    use_locking: Forwarded to `scatter_sub`.
    name: Forwarded to `scatter_sub`.

  Returns:
    Whatever `var._update` returns.
  """

  def _scatter_sub(component, *args, **kwargs):
    return component.scatter_sub(*args, **kwargs)

  return var._update(  # pylint: disable=protected-access
      update_fn=_scatter_sub,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)


def scatter_add(var, sparse_delta, use_locking=False, name=None):
  """Runs `scatter_add` on `var` through its `_update` hook.

  Args:
    var: The variable to update; it must provide `_update`.
    sparse_delta: The value forwarded to `scatter_add`.
    use_locking: Forwarded to `scatter_add`.
    name: Forwarded to `scatter_add`.

  Returns:
    Whatever `var._update` returns.
  """

  def _scatter_add(component, *args, **kwargs):
    return component.scatter_add(*args, **kwargs)

  return var._update(  # pylint: disable=protected-access
      update_fn=_scatter_add,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)


def scatter_mul(var, sparse_delta, use_locking=False, name=None):
  """Runs `scatter_mul` on `var` through its `_update` hook.

  Args:
    var: The variable to update; it must provide `_update`.
    sparse_delta: The value forwarded to `scatter_mul`.
    use_locking: Forwarded to `scatter_mul`.
    name: Forwarded to `scatter_mul`.

  Returns:
    Whatever `var._update` returns.
  """

  def _scatter_mul(component, *args, **kwargs):
    return component.scatter_mul(*args, **kwargs)

  return var._update(  # pylint: disable=protected-access
      update_fn=_scatter_mul,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)


def scatter_div(var, sparse_delta, use_locking=False, name=None):
  """Runs `scatter_div` on `var` through its `_update` hook.

  Args:
    var: The variable to update; it must provide `_update`.
    sparse_delta: The value forwarded to `scatter_div`.
    use_locking: Forwarded to `scatter_div`.
    name: Forwarded to `scatter_div`.

  Returns:
    Whatever `var._update` returns.
  """

  def _scatter_div(component, *args, **kwargs):
    return component.scatter_div(*args, **kwargs)

  return var._update(  # pylint: disable=protected-access
      update_fn=_scatter_div,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)


def scatter_min(var, sparse_delta, use_locking=False, name=None):
  """Runs `scatter_min` on `var` through its `_update` hook.

  Args:
    var: The variable to update; it must provide `_update`.
    sparse_delta: The value forwarded to `scatter_min`.
    use_locking: Forwarded to `scatter_min`.
    name: Forwarded to `scatter_min`.

  Returns:
    Whatever `var._update` returns.
  """

  def _scatter_min(component, *args, **kwargs):
    return component.scatter_min(*args, **kwargs)

  return var._update(  # pylint: disable=protected-access
      update_fn=_scatter_min,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)


def scatter_max(var, sparse_delta, use_locking=False, name=None):
  """Runs `scatter_max` on `var` through its `_update` hook.

  Args:
    var: The variable to update; it must provide `_update`.
    sparse_delta: The value forwarded to `scatter_max`.
    use_locking: Forwarded to `scatter_max`.
    name: Forwarded to `scatter_max`.

  Returns:
    Whatever `var._update` returns.
  """

  def _scatter_max(component, *args, **kwargs):
    return component.scatter_max(*args, **kwargs)

  return var._update(  # pylint: disable=protected-access
      update_fn=_scatter_max,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)


def scatter_update(var, sparse_delta, use_locking=False, name=None):
  """Runs `scatter_update` on `var` through its `_update` hook.

  Args:
    var: The variable to update; it must provide `_update`.
    sparse_delta: The value forwarded to `scatter_update`.
    use_locking: Forwarded to `scatter_update`.
    name: Forwarded to `scatter_update`.

  Returns:
    Whatever `var._update` returns.
  """

  def _scatter_update(component, *args, **kwargs):
    return component.scatter_update(*args, **kwargs)

  return var._update(  # pylint: disable=protected-access
      update_fn=_scatter_update,
      value=sparse_delta,
      use_locking=use_locking,
      name=name)


def get_current_replica_id_as_int():
  """Returns the current replica ID as an integer, or `None`.

  Inside a replica context the context's `_replica_id` is used; a value that
  is not already an `int` is passed through `tensor_util.constant_value`.
  Outside a replica context the update replica id is returned instead.
  """
  replica_context = ds_context.get_replica_context()
  if not replica_context:
    return distribute_lib.get_update_replica_id()
  replica_id = replica_context._replica_id  # pylint: disable=protected-access
  if isinstance(replica_id, int):
    return replica_id
  return tensor_util.constant_value(replica_id)


def assign_on_device(device, variable, tensor):
  """Calls `variable.assign(tensor)` inside an `ops.device(device)` scope."""
  with ops.device(device):
    assigned = variable.assign(tensor)
  return assigned


def assign_add_on_device(device, variable, tensor):
  """Calls `variable.assign_add(tensor)` inside an `ops.device(device)` scope."""
  with ops.device(device):
    assigned = variable.assign_add(tensor)
  return assigned


def assign_sub_on_device(device, variable, tensor):
  """Calls `variable.assign_sub(tensor)` inside an `ops.device(device)` scope."""
  with ops.device(device):
    assigned = variable.assign_sub(tensor)
  return assigned


def assert_replica_context(strategy):
  """Checks that we are in a replica context belonging to `strategy`.

  Args:
    strategy: The strategy the current replica context must belong to.

  Raises:
    RuntimeError: If there is no replica context, or if its `strategy` is not
      the given `strategy` object.
  """
  replica_context = ds_context.get_replica_context()
  # Both failure modes report the same message.
  if not replica_context or replica_context.strategy is not strategy:
    raise RuntimeError(
        "Replica-local variables may only be assigned in a replica context.")


def apply_aggregation(strategy, value, aggregation, destinations):
  """Aggregates `value` to `destinations` according to `aggregation`.

  Args:
    strategy: The distribution strategy performing the aggregation.
    value: The value to aggregate.
    aggregation: A `VariableAggregation` selecting how to aggregate.
    destinations: Where the aggregated result is sent.

  Returns:
    For `ONLY_FIRST_REPLICA`, the first local result of `value` broadcast to
    `destinations`; otherwise the result of `reduce_to` with the reduce op
    derived from `aggregation`.
  """
  if aggregation == vs.VariableAggregation.ONLY_FIRST_REPLICA:
    first_replica_value = strategy.experimental_local_results(value)[0]
    return strategy.extended.broadcast_to(
        first_replica_value, destinations=destinations)
  return strategy.extended.reduce_to(
      reduce_util.ReduceOp.from_variable_aggregation(aggregation),
      value,
      destinations)


# Error message template for updating a variable in replica context without an
# aggregation method. Format it with `variable_type`. Every fragment that ends
# a sentence carries a trailing space so the sentences do not run together
# once the adjacent literals are concatenated.
aggregation_error_msg = (
    "You must specify an aggregation method to update a "
    "{variable_type} in Replica Context. You can do so by passing "
    "an explicit value for argument `aggregation` to tf.Variable(..). "
    "e.g. `tf.Variable(..., aggregation=tf.VariableAggregation.SUM)`. "
    "`tf.VariableAggregation` lists the possible aggregation methods. "
    "This is required because {variable_type} should always be "
    "kept in sync. When updating them or assigning to them in a "
    "replica context, we automatically try to aggregate the values "
    "before updating the variable. For this aggregation, we need to "
    "know the aggregation method. "
    "Another alternative is to not try to update such "
    "{variable_type} in replica context, but in cross replica "
    "context. You can enter cross replica context by calling "
    "`tf.distribute.get_replica_context().merge_call(merge_fn, ..)`. "
    "Inside `merge_fn`, you can then update the {variable_type} "
    "using `tf.distribute.StrategyExtended.update()`.")


# Error message template for scatter ops used with an unsupported aggregation.
# Format it with `op_name` and `aggregation`.
scatter_error_msg = (
    "{op_name} is only supported for mirrored variable "
    "(variable created within certain `tf.distribute.Strategy` scope) "
    "with NONE or `ONLY_FIRST_REPLICA` aggregation, "
    "got: {aggregation}.")


def is_saving_non_distributed():
  """Returns whether we're saving a non-distributed version of the model.

  It returns True iff we are in a save context and the save options'
  `experimental_variable_policy` is anything other than
  `VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES`.

  Returns:
    A boolean.
  """
  if save_context.in_save_context():
    policy = save_context.get_save_options().experimental_variable_policy
    return policy != save_options.VariablePolicy.EXPAND_DISTRIBUTED_VARIABLES
  return False


def mark_as_unsaveable():
  """Marks the function as unsaveable if not inside save context.

  Nothing happens unless we are inside a function and outside a save context.
  In that case the default graph is marked as unsaveable, with a message that
  suggests saving a `tf.function` with an `input_signature` instead of a
  concrete function.
  """
  if ops.inside_function() and not save_context.in_save_context():
    # The message is user-facing, so the suggested replacement call is written
    # without any stray markup characters.
    ops.get_default_graph().mark_as_unsaveable("""
ConcreteFunction that uses distributed variables in certain way cannot be saved.
If you're saving with

tf.saved_model.save(..., signatures=f.get_concrete_function())

do

@tf.function(input_signature=...)
def f_with_input_signature():
  ...

tf.saved_model.save(..., signatures=f_with_input_signature)

instead.""")
