# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
# pylint: disable=g-classes-have-attributes
"""Contains the InputSpec class."""

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.keras import backend
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export
from tensorflow.python.util.tf_export import tf_export


@keras_export('keras.layers.InputSpec')
@tf_export(v1=['layers.InputSpec'])
class InputSpec(object):
  """Specifies the rank, dtype and shape of every input to a layer.

  Layers can expose (if appropriate) an `input_spec` attribute:
  an instance of `InputSpec`, or a nested structure of `InputSpec` instances
  (one per input tensor). These objects enable the layer to run input
  compatibility checks for input structure, input rank, input shape, and
  input dtype.

  A None entry in a shape is compatible with any dimension,
  a None shape is compatible with any shape.

  Args:
    dtype: Expected DataType of the input.
    shape: Shape tuple, expected shape of the input
      (may include None for unchecked axes). Includes the batch size.
    ndim: Integer, expected rank of the input. Ignored when `shape` has a
      known rank; the rank of `shape` is used instead.
    max_ndim: Integer, maximum rank of the input.
    min_ndim: Integer, minimum rank of the input.
    axes: Dictionary mapping integer axes to
      a specific dimension value.
    allow_last_axis_squeeze: If True, then allow inputs of rank N+1 as long
      as the last axis of the input is 1, as well as inputs of rank N-1
      as long as the last axis of the spec is 1.
    name: Expected key corresponding to this input when passing data as
      a dictionary.

  Raises:
    TypeError: If a key of `axes` cannot be converted to an integer.
    ValueError: If the largest key of `axes` exceeds the last valid axis
      implied by `ndim` (or by `max_ndim` when `ndim` is not set).

  Example:

  ```python
  class MyLayer(Layer):
      def __init__(self):
          super(MyLayer, self).__init__()
          # The layer will accept inputs with shape (?, 28, 28) & (?, 28, 28, 1)
          # and raise an appropriate error message otherwise.
          self.input_spec = InputSpec(
              shape=(None, 28, 28, 1),
              allow_last_axis_squeeze=True)
  ```
  """

  def __init__(self,
               dtype=None,
               shape=None,
               ndim=None,
               max_ndim=None,
               min_ndim=None,
               axes=None,
               allow_last_axis_squeeze=False,
               name=None):
    # The dtype is stored as its string name (or None when unconstrained).
    self.dtype = dtypes.as_dtype(dtype).name if dtype is not None else None

    # Normalize `shape` through TensorShape. A shape of unknown rank is stored
    # as None; otherwise it is stored as a tuple and its length overrides the
    # `ndim` argument.
    shape = tensor_shape.TensorShape(shape)
    if shape.rank is None:
      self.shape = None
      self.ndim = ndim
    else:
      self.shape = tuple(shape.as_list())
      self.ndim = len(self.shape)

    self.max_ndim = max_ndim
    self.min_ndim = min_ndim
    self.name = name
    self.allow_last_axis_squeeze = allow_last_axis_squeeze
    try:
      axes = axes or {}
      self.axes = {int(k): axes[k] for k in axes}
    except (ValueError, TypeError):
      raise TypeError('The keys in axes must be integers.')

    if self.axes and (self.ndim is not None or self.max_ndim is not None):
      # Test against None explicitly: ndim == 0 (a scalar spec) is a valid
      # rank and must not fall through to `max_ndim`, which may be None or
      # may describe a looser bound than the exact rank.
      rank_bound = self.ndim if self.ndim is not None else self.max_ndim
      max_dim = rank_bound - 1
      max_axis = max(self.axes)
      if max_axis > max_dim:
        raise ValueError('Axis {} is greater than the maximum allowed value: {}'
                         .format(max_axis, max_dim))

  def __repr__(self):
    """Returns `InputSpec(...)` listing only the constraints that are set."""
    # Compare against None rather than relying on truthiness so that valid
    # falsy constraints (shape=(), ndim=0, min_ndim=0, ...) are still shown.
    fields = (('dtype', self.dtype),
              ('shape', self.shape),
              ('ndim', self.ndim),
              ('max_ndim', self.max_ndim),
              ('min_ndim', self.min_ndim))
    spec = [key + '=' + str(value) for key, value in fields
            if value is not None]
    # An empty `axes` dict carries no constraint, so it is omitted.
    if self.axes:
      spec.append('axes=' + str(self.axes))
    return 'InputSpec(%s)' % ', '.join(spec)

  def get_config(self):
    """Returns the spec's constraints as a dict of constructor arguments.

    Note: `allow_last_axis_squeeze` and `name` are not part of the returned
    config, so they are reset to their defaults after a
    `from_config(get_config())` round trip.
    """
    return {
        'dtype': self.dtype,
        'shape': self.shape,
        'ndim': self.ndim,
        'max_ndim': self.max_ndim,
        'min_ndim': self.min_ndim,
        'axes': self.axes}

  @classmethod
  def from_config(cls, config):
    """Creates an instance by passing `config` as keyword arguments."""
    return cls(**config)


def to_tensor_shape(spec):
  """Builds the tf.TensorShape described by an InputSpec.

  An explicit `spec.shape` takes precedence. When only `spec.ndim` is set,
  the result has that many dimensions, each left as None except for the
  entries pinned by `spec.axes`. When neither is set, `TensorShape(None)` is
  returned.

  Args:
    spec: an InputSpec object.

  Returns:
    a tf.TensorShape object
  """
  if spec.shape is not None:
    return tensor_shape.TensorShape(spec.shape)
  if spec.ndim is None:
    return tensor_shape.TensorShape(None)

  # Only the rank is known: start with every dimension unconstrained, then
  # fill in the dimensions fixed through `axes` (expected to be a dict).
  dims = [None] * spec.ndim
  for axis, size in spec.axes.items():
    dims[axis] = size
  return tensor_shape.TensorShape(dims)


def assert_input_compatibility(input_spec, inputs, layer_name):
  """Checks compatibility between the layer and provided inputs.

  This checks that the tensor(s) `inputs` verify the input assumptions
  of a layer (if any). If not, a clear and actionable exception gets raised.

  Inputs whose rank is unknown are skipped; the remaining inputs are still
  checked.

  Args:
      input_spec: An InputSpec instance, list of InputSpec instances, a nested
          structure of InputSpec instances, or None.
      inputs: Input tensor, list of input tensors, or a nested structure of
          input tensors.
      layer_name: String, name of the layer (for error message formatting).

  Raises:
      ValueError: in case of mismatch between
          the provided inputs and the expectations of the layer.
      TypeError: if one of the inputs has no `shape` attribute.
  """
  if not input_spec:
    return

  input_spec = nest.flatten(input_spec)
  if isinstance(inputs, dict):
    # Flatten `inputs` by reference order if input spec names are provided.
    # `getattr` tolerates None entries in `input_spec` (which are allowed and
    # skipped below); such an entry simply disables the name-based ordering.
    names = [getattr(spec, 'name', None) for spec in input_spec]
    if all(names):
      list_inputs = []
      for name in names:
        if name not in inputs:
          raise ValueError('Missing data for input "%s". '
                           'You passed a data dictionary with keys %s. '
                           'Expected the following keys: %s' %
                           (name, list(inputs.keys()), names))
        list_inputs.append(inputs[name])
      inputs = list_inputs

  inputs = nest.flatten(inputs)
  for x in inputs:
    # Having a shape/dtype is the only commonality of the various tensor-like
    # objects that may be passed. The most common kind of invalid type we are
    # guarding for is a Layer instance (Functional API), which does not
    # have a `shape` attribute.
    if not hasattr(x, 'shape'):
      raise TypeError('Inputs to a layer should be tensors. Got: %s' % (x,))

  if len(inputs) != len(input_spec):
    raise ValueError('Layer ' + layer_name + ' expects ' +
                     str(len(input_spec)) + ' input(s), '
                     'but it received ' + str(len(inputs)) +
                     ' input tensors. Inputs received: ' + str(inputs))
  for input_index, (x, spec) in enumerate(zip(inputs, input_spec)):
    if spec is None:
      continue

    # Normalize once so every check below works on a TensorShape, whatever
    # kind of object `x.shape` is.
    shape = tensor_shape.TensorShape(x.shape)
    if shape.rank is None:
      # Nothing can be verified for this input, but the remaining inputs
      # must still be validated, so skip it rather than returning.
      continue
    ndim = shape.rank

    # Check ndim.
    if spec.ndim is not None and not spec.allow_last_axis_squeeze:
      if ndim != spec.ndim:
        raise ValueError('Input ' + str(input_index) + ' of layer ' +
                         layer_name + ' is incompatible with the layer: '
                         'expected ndim=' + str(spec.ndim) + ', found ndim=' +
                         str(ndim) + '. Full shape received: ' +
                         str(tuple(shape)))
    if spec.max_ndim is not None:
      if ndim > spec.max_ndim:
        raise ValueError('Input ' + str(input_index) + ' of layer ' +
                         layer_name + ' is incompatible with the layer: '
                         'expected max_ndim=' + str(spec.max_ndim) +
                         ', found ndim=' + str(ndim))
    if spec.min_ndim is not None:
      if ndim < spec.min_ndim:
        raise ValueError('Input ' + str(input_index) + ' of layer ' +
                         layer_name + ' is incompatible with the layer: '
                         'expected min_ndim=' + str(spec.min_ndim) +
                         ', found ndim=' + str(ndim) +
                         '. Full shape received: ' +
                         str(tuple(shape)))
    # Check dtype.
    if spec.dtype is not None:
      if x.dtype.name != spec.dtype:
        raise ValueError('Input ' + str(input_index) + ' of layer ' +
                         layer_name + ' is incompatible with the layer: '
                         'expected dtype=' + str(spec.dtype) +
                         ', found dtype=' + str(x.dtype))

    # Check specific shape axes.
    shape_as_list = shape.as_list()
    if spec.axes:
      for axis, value in spec.axes.items():
        # Unwrap objects that carry the dimension in a `value` attribute.
        if hasattr(value, 'value'):
          value = value.value
        if value is not None and shape_as_list[int(axis)] not in {value, None}:
          raise ValueError(
              'Input ' + str(input_index) + ' of layer ' + layer_name + ' is'
              ' incompatible with the layer: expected axis ' + str(axis) +
              ' of input shape to have value ' + str(value) +
              ' but received input with shape ' + display_shape(shape))
    # Check shape.
    if spec.shape is not None:
      spec_shape = spec.shape
      if spec.allow_last_axis_squeeze:
        # Drop a trailing dimension of size 1 on either side before comparing.
        if shape_as_list and shape_as_list[-1] == 1:
          shape_as_list = shape_as_list[:-1]
        if spec_shape and spec_shape[-1] == 1:
          spec_shape = spec_shape[:-1]
      # NOTE(review): `zip` stops at the shorter sequence and the ndim check
      # above is skipped when `allow_last_axis_squeeze` is set, so a rank
      # difference that remains after squeezing is not reported here. Confirm
      # whether that leniency is intended.
      for spec_dim, dim in zip(spec_shape, shape_as_list):
        if spec_dim is not None and dim is not None:
          if spec_dim != dim:
            raise ValueError('Input ' + str(input_index) +
                             ' is incompatible with layer ' + layer_name +
                             ': expected shape=' + str(spec.shape) +
                             ', found shape=' + display_shape(shape))


def display_shape(shape):
  """Formats a shape as the string form of a tuple of its dimensions.

  Args:
    shape: An object exposing `as_list()`, such as a tf.TensorShape.

  Returns:
    A string such as `'(None, 28, 28)'`.
  """
  dims = shape.as_list()
  return str(tuple(dims))


def to_tensor_spec(input_spec, default_dtype=None):
  """Converts a Keras InputSpec object to a TensorSpec.

  Args:
    input_spec: The object to convert. Anything that is not an `InputSpec`
      produces a TensorSpec built from a `None` shape.
    default_dtype: dtype used when `input_spec` does not set one. Falls back
      to `backend.floatx()` when not provided.

  Returns:
    a tf.TensorSpec object
  """
  if not default_dtype:
    default_dtype = backend.floatx()

  if not isinstance(input_spec, InputSpec):
    return tensor_spec.TensorSpec(None, default_dtype)

  # The spec's own dtype wins over the default when it is set.
  spec_dtype = input_spec.dtype if input_spec.dtype else default_dtype
  return tensor_spec.TensorSpec(to_tensor_shape(input_spec), spec_dtype)
