# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Private base class for layers that can merge several inputs into one."""

import tensorflow.compat.v2 as tf

from keras import backend
from keras.engine.base_layer import Layer
from keras.utils import tf_utils


class _Merge(Layer):
    """Generic merge layer for elementwise merge functions.

    Used to implement `Sum`, `Average`, etc.

    Subclasses implement `_merge_function`. It receives the list of input
    tensors, possibly expanded or transposed by `call` so that they broadcast
    against each other, and returns the merged tensor.
    """

    def __init__(self, **kwargs):
        """Initializes a Merge layer.

        Args:
          **kwargs: standard layer keyword arguments.
        """
        super().__init__(**kwargs)
        self.supports_masking = True

    def _merge_function(self, inputs):
        """Merges a list of tensors into a single tensor.

        Must be implemented by subclasses.

        Args:
          inputs: list of tensors, as prepared by `call`.

        Returns:
          The merged tensor.
        """
        raise NotImplementedError

    def _compute_elemwise_op_output_shape(self, shape1, shape2):
        """Computes the shape of the resultant of an elementwise operation.

        Args:
            shape1: tuple or None. Shape of the first tensor
            shape2: tuple or None. Shape of the second tensor

        Returns:
            expected output shape when an element-wise operation is
            carried out on 2 tensors with shapes shape1 and shape2.
            tuple or None.

        Raises:
            ValueError: if shape1 and shape2 are not compatible for
                element-wise operations.
        """
        if None in [shape1, shape2]:
            # An unknown-rank operand makes the result rank unknown too.
            return None
        elif len(shape1) < len(shape2):
            # Normalize so that shape1 is always the longer shape.
            return self._compute_elemwise_op_output_shape(shape2, shape1)
        elif not shape2:
            # Guard: `shape1[:-0]` below would be empty, not all of shape1.
            return shape1
        # Leading axes present only in the longer shape pass through as-is;
        # trailing axes are aligned from the right and broadcast pairwise.
        output_shape = list(shape1[: -len(shape2)])
        for i, j in zip(shape1[-len(shape2) :], shape2):
            if i is None or j is None:
                output_shape.append(None)
            elif i == 1:
                output_shape.append(j)
            elif j == 1:
                output_shape.append(i)
            else:
                if i != j:
                    raise ValueError(
                        "Inputs have incompatible shapes. "
                        f"Received shapes {shape1} and {shape2}"
                    )
                output_shape.append(i)
        return tuple(output_shape)

    def _compute_non_batch_output_shape(self, input_shape):
        """Folds all input shapes, minus their first axis, into one shape.

        Args:
          input_shape: non-empty list of shape tuples (or None), one per input.

        Returns:
          The elementwise-broadcast shape of every input with its first
          (batch) axis removed, or None if any input shape is None.

        Raises:
          ValueError: if two inputs have incompatible shapes.
        """
        first = input_shape[0]
        output_shape = None if first is None else first[1:]
        for shape in input_shape[1:]:
            non_batch = None if shape is None else shape[1:]
            output_shape = self._compute_elemwise_op_output_shape(
                output_shape, non_batch
            )
        return output_shape

    @tf_utils.shape_type_conversion
    def build(self, input_shape):
        """Validates the input shapes and records whether `call` must reshape.

        No weights are created; this is used purely for shape validation.

        Args:
          input_shape: list of shape tuples, one per input.

        Raises:
          ValueError: if `input_shape` is empty or is not a list of shapes,
            if the inputs have different known batch sizes, or if their
            non-batch shapes are not elementwise-compatible.
        """
        # Check emptiness first: indexing `input_shape[0]` on an empty list
        # would otherwise raise IndexError instead of this ValueError.
        if len(input_shape) < 1:
            raise ValueError(
                "A merge layer should be called "
                "on a list of at least 1 input. "
                f"Got {len(input_shape)} inputs. "
                f"Full input_shape received: {input_shape}"
            )
        if not isinstance(input_shape[0], tuple):
            raise ValueError(
                "A merge layer should be called on a list of inputs. "
                f"Received: input_shape={input_shape} (not a list of shapes)"
            )
        # Empty (scalar) shapes have no batch axis and are skipped.
        batch_sizes = {s[0] for s in input_shape if s} - {None}
        if len(batch_sizes) > 1:
            raise ValueError(
                "Cannot merge tensors with different batch sizes. "
                f"Got tensors with shapes {input_shape}"
            )
        # Called only for its ValueError on incompatible shapes; the
        # resulting shape itself is not needed here.
        self._compute_non_batch_output_shape(input_shape)
        # If the inputs have different ranks, we have to reshape them
        # to make them broadcastable.
        if None not in input_shape and len(set(map(len, input_shape))) == 1:
            self._reshape_required = False
        else:
            self._reshape_required = True

    def call(self, inputs):
        """Merges `inputs`, aligning their ranks first if `build` required it.

        Args:
          inputs: list or tuple of tensors.

        Returns:
          The tensor produced by `_merge_function`. When inputs had to be
          transposed (batch axis moved last), the result is transposed back
          so that the batch axis is first again.

        Raises:
          ValueError: if `inputs` is not a list or tuple.
        """
        if not isinstance(inputs, (list, tuple)):
            raise ValueError(
                "A merge layer should be called on a list of inputs. "
                f"Received: inputs={inputs} (not a list of tensors)"
            )
        if not self._reshape_required:
            return self._merge_function(inputs)

        reshaped_inputs = []
        input_ndims = list(map(backend.ndim, inputs))
        if None not in input_ndims:
            # If ranks of all inputs are available,
            # we simply expand each of them at axis=1
            # until all of them have the same rank.
            # Inserting size-1 axes right after the batch axis keeps the
            # trailing (non-batch) axes right-aligned for broadcasting.
            max_ndim = max(input_ndims)
            for x, x_ndim in zip(inputs, input_ndims):
                for _ in range(max_ndim - x_ndim):
                    x = tf.expand_dims(x, axis=1)
                reshaped_inputs.append(x)
            return self._merge_function(reshaped_inputs)

        # Transpose all inputs so that batch size is the last dimension.
        # (batch_size, dim1, dim2, ... ) -> (dim1, dim2, ... , batch_size)
        # With the batch axis last, right-aligned broadcasting lines the batch
        # axes of all inputs up regardless of their ranks.
        transposed = False
        for x in inputs:
            x_ndim = backend.ndim(x)
            if x_ndim is None:
                # Rank unknown statically: flatten the non-batch axes into
                # one, swap the two axes, then restore the full shape with the
                # batch axis at the end.
                x_shape = tf.shape(x)
                batch_size = x_shape[0]
                new_shape = backend.concatenate(
                    [x_shape[1:], tf.expand_dims(batch_size, axis=-1)]
                )
                x_transposed = tf.reshape(
                    x,
                    tf.stack(
                        [batch_size, tf.reduce_prod(x_shape[1:])],
                        axis=0,
                    ),
                )
                x_transposed = tf.transpose(x_transposed, perm=(1, 0))
                x_transposed = tf.reshape(x_transposed, new_shape)
                reshaped_inputs.append(x_transposed)
                transposed = True
            elif x_ndim > 1:
                dims = list(range(1, x_ndim)) + [0]
                reshaped_inputs.append(tf.transpose(x, perm=dims))
                transposed = True
            else:
                # We don't transpose inputs if they are 1D vectors or
                # scalars.
                reshaped_inputs.append(x)
        y = self._merge_function(reshaped_inputs)
        y_ndim = backend.ndim(y)
        if transposed:
            # If inputs have been transposed, we have to transpose the
            # output too, moving its last axis (batch) back to the front.
            if y_ndim is None:
                # Rank unknown statically: compute it at runtime and undo the
                # transpose with the same reshape/transpose/reshape trick.
                y_shape = tf.shape(y)
                y_ndim = tf.shape(y_shape)[0]
                batch_size = y_shape[y_ndim - 1]
                new_shape = backend.concatenate(
                    [
                        tf.expand_dims(batch_size, axis=-1),
                        y_shape[: y_ndim - 1],
                    ]
                )
                y = tf.reshape(y, (-1, batch_size))
                y = tf.transpose(y, perm=(1, 0))
                y = tf.reshape(y, new_shape)
            elif y_ndim > 1:
                dims = [y_ndim - 1] + list(range(y_ndim - 1))
                y = tf.transpose(y, perm=dims)
        return y

    @tf_utils.shape_type_conversion
    def compute_output_shape(self, input_shape):
        """Computes the merged output shape, including the batch axis.

        Args:
          input_shape: list of shape tuples (or None), one per input.

        Returns:
          A shape tuple whose first entry is the common known batch size, or
          None if no single known batch size exists. Returns None if the
          merged non-batch shape is unknown.

        Raises:
          ValueError: if two inputs have incompatible shapes.
        """
        output_shape = self._compute_non_batch_output_shape(input_shape)
        if output_shape is None:
            # At least one input has unknown rank, so the output rank is
            # unknown as well; `(None,) + None` would raise a TypeError.
            return None
        # Skip None and empty (scalar) shapes, matching `build`: an empty
        # tuple has no batch axis, so `s[0]` would raise IndexError.
        batch_sizes = {s[0] for s in input_shape if s} - {None}
        if len(batch_sizes) == 1:
            batch_size = next(iter(batch_sizes))
        else:
            batch_size = None
        return (batch_size,) + tuple(output_shape)

    def compute_mask(self, inputs, mask=None):
        """Combines the input masks with a logical AND.

        Args:
          inputs: list or tuple of input tensors.
          mask: None, or a list/tuple with one mask (or None) per input.

        Returns:
          None if `mask` is None or every entry is None. Otherwise, the
          elementwise AND across all non-None masks. These masks must have
          matching shapes for the concatenation below to succeed.

        Raises:
          ValueError: if `mask` or `inputs` is not a list/tuple, or if their
            lengths differ.
        """
        if mask is None:
            return None
        if not isinstance(mask, (tuple, list)):
            raise ValueError(f"`mask` should be a list. Received: mask={mask}")
        if not isinstance(inputs, (tuple, list)):
            raise ValueError(
                f"`inputs` should be a list. Received: inputs={inputs}"
            )
        if len(mask) != len(inputs):
            raise ValueError(
                "The lists `inputs` and `mask` should have the same length. "
                f"Received: inputs={inputs} of length {len(inputs)}, and "
                f"mask={mask} of length {len(mask)}"
            )
        if all(m is None for m in mask):
            return None
        # Stack the masks along a new leading axis, then reduce with `all`
        # over that axis.
        masks = [tf.expand_dims(m, axis=0) for m in mask if m is not None]
        return backend.all(
            backend.concatenate(masks, axis=0), axis=0, keepdims=False
        )

    def get_config(self):
        """Returns the base `Layer` config; this class adds no fields."""
        return super().get_config()
