# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops to manipulate lists of tensors."""

# pylint: disable=g-bad-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_list_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_list_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.util.lazy_loader import LazyLoader

# `control_flow_ops` is bound through `LazyLoader` instead of a regular import
# statement because of this import cycle:
# list_ops -> control_flow_ops -> tensor_array_ops -> list_ops
control_flow_ops = LazyLoader(
    "control_flow_ops", globals(),
    "tensorflow.python.ops.control_flow_ops")


# Mark the following list op types as not differentiable.
ops.NotDifferentiable("TensorListConcatLists")
ops.NotDifferentiable("TensorListElementShape")
ops.NotDifferentiable("TensorListLength")
ops.NotDifferentiable("TensorListPushBackBatch")


def empty_tensor_list(element_shape,
                      element_dtype,
                      max_num_elements=None,
                      name=None):
  """Calls `gen_list_ops.empty_tensor_list` with normalized arguments.

  Args:
    element_shape: Shape spec, normalized through `_build_element_shape`.
    element_dtype: Forwarded unchanged as `element_dtype`.
    max_num_elements: Forwarded as `max_num_elements`; `None` is replaced
      with -1.
    name: Forwarded unchanged as `name`.

  Returns:
    Whatever `gen_list_ops.empty_tensor_list` returns.
  """
  size_limit = -1 if max_num_elements is None else max_num_elements
  shape_arg = _build_element_shape(element_shape)
  return gen_list_ops.empty_tensor_list(
      element_shape=shape_arg,
      element_dtype=element_dtype,
      max_num_elements=size_limit,
      name=name)


def tensor_list_reserve(element_shape, num_elements, element_dtype, name=None):
  """Calls `gen_list_ops.tensor_list_reserve` with a normalized shape.

  Args:
    element_shape: Shape spec, normalized through `_build_element_shape`.
    num_elements: Forwarded unchanged as `num_elements`.
    element_dtype: Forwarded unchanged as `element_dtype`.
    name: Forwarded unchanged as `name`.

  Returns:
    Whatever `gen_list_ops.tensor_list_reserve` returns.
  """
  shape_arg = _build_element_shape(element_shape)
  return gen_list_ops.tensor_list_reserve(
      element_shape=shape_arg,
      num_elements=num_elements,
      element_dtype=element_dtype,
      name=name)


def tensor_list_from_tensor(tensor, element_shape, name=None):
  """Calls `gen_list_ops.tensor_list_from_tensor` with a normalized shape.

  Args:
    tensor: Forwarded unchanged as `tensor`.
    element_shape: Shape spec, normalized through `_build_element_shape`.
    name: Forwarded unchanged as `name`.

  Returns:
    Whatever `gen_list_ops.tensor_list_from_tensor` returns.
  """
  shape_arg = _build_element_shape(element_shape)
  return gen_list_ops.tensor_list_from_tensor(
      tensor=tensor, element_shape=shape_arg, name=name)


def tensor_list_get_item(input_handle, index, element_dtype, element_shape=None,
                         name=None):
  """Calls `gen_list_ops.tensor_list_get_item` with a normalized shape.

  Args:
    input_handle: Forwarded unchanged as `input_handle`.
    index: Forwarded unchanged as `index`.
    element_dtype: Forwarded unchanged as `element_dtype`.
    element_shape: Optional shape spec, normalized through
      `_build_element_shape`.
    name: Forwarded unchanged as `name`.

  Returns:
    Whatever `gen_list_ops.tensor_list_get_item` returns.
  """
  shape_arg = _build_element_shape(element_shape)
  return gen_list_ops.tensor_list_get_item(
      input_handle=input_handle,
      index=index,
      element_shape=shape_arg,
      element_dtype=element_dtype,
      name=name)


def tensor_list_pop_back(input_handle, element_dtype, name=None):
  """Calls `gen_list_ops.tensor_list_pop_back` with `element_shape=-1`.

  Args:
    input_handle: Forwarded unchanged as `input_handle`.
    element_dtype: Forwarded unchanged as `element_dtype`.
    name: Forwarded unchanged as `name`.

  Returns:
    Whatever `gen_list_ops.tensor_list_pop_back` returns.
  """
  # This wrapper exposes no element_shape argument; -1 is always passed, which
  # is the same value `_build_element_shape` yields for an unknown shape.
  unknown_shape = -1
  return gen_list_ops.tensor_list_pop_back(
      input_handle=input_handle,
      element_shape=unknown_shape,
      element_dtype=element_dtype,
      name=name)


def tensor_list_gather(input_handle,
                       indices,
                       element_dtype,
                       element_shape=None,
                       name=None):
  """Calls `gen_list_ops.tensor_list_gather` with a normalized shape.

  Args:
    input_handle: Forwarded unchanged as `input_handle`.
    indices: Forwarded unchanged as `indices`.
    element_dtype: Forwarded unchanged as `element_dtype`.
    element_shape: Optional shape spec, normalized through
      `_build_element_shape`.
    name: Forwarded unchanged as `name`.

  Returns:
    Whatever `gen_list_ops.tensor_list_gather` returns.
  """
  shape_arg = _build_element_shape(element_shape)
  return gen_list_ops.tensor_list_gather(
      input_handle=input_handle,
      indices=indices,
      element_shape=shape_arg,
      element_dtype=element_dtype,
      name=name)


def tensor_list_scatter(tensor,
                        indices,
                        element_shape=None,
                        input_handle=None,
                        name=None):
  """Scatters `tensor` at `indices` into a new or an existing list.

  Args:
    tensor: Forwarded unchanged as `tensor`.
    indices: Forwarded unchanged as `indices`.
    element_shape: Optional shape spec, normalized through
      `_build_element_shape`. Only used when `input_handle` is None.
    input_handle: Optional existing list. When given,
      `tensor_list_scatter_into_existing_list` is used instead of
      `tensor_list_scatter_v2`.
    name: Forwarded unchanged as `name`.

  Returns:
    The result of whichever generated scatter op was called.
  """
  if input_handle is None:
    # No list supplied: create one, passing num_elements=-1.
    shape_arg = _build_element_shape(element_shape)
    return gen_list_ops.tensor_list_scatter_v2(
        tensor=tensor,
        indices=indices,
        element_shape=shape_arg,
        num_elements=-1,
        name=name)
  return gen_list_ops.tensor_list_scatter_into_existing_list(
      input_handle=input_handle, tensor=tensor, indices=indices, name=name)


def tensor_list_stack(input_handle,
                      element_dtype,
                      num_elements=-1,
                      element_shape=None,
                      name=None):
  """Calls `gen_list_ops.tensor_list_stack` with a normalized shape.

  Args:
    input_handle: Forwarded unchanged as `input_handle`.
    element_dtype: Forwarded unchanged as `element_dtype`.
    num_elements: Forwarded unchanged as `num_elements`; defaults to -1.
    element_shape: Optional shape spec, normalized through
      `_build_element_shape`.
    name: Forwarded unchanged as `name`.

  Returns:
    Whatever `gen_list_ops.tensor_list_stack` returns.
  """
  shape_arg = _build_element_shape(element_shape)
  return gen_list_ops.tensor_list_stack(
      input_handle=input_handle,
      element_shape=shape_arg,
      element_dtype=element_dtype,
      num_elements=num_elements,
      name=name)


def tensor_list_concat(input_handle, element_dtype, element_shape=None,
                       name=None):
  """Returns the first output of `gen_list_ops.tensor_list_concat_v2`.

  Args:
    input_handle: Forwarded unchanged as `input_handle`.
    element_dtype: Forwarded unchanged as `element_dtype`.
    element_shape: Optional shape spec, normalized through
      `_build_element_shape`.
    name: Forwarded unchanged as `name`.

  Returns:
    Output 0 of the concat op. `leading_dims` is always passed as an empty
    int64 tensor.
  """
  shape_arg = _build_element_shape(element_shape)
  no_leading_dims = ops.convert_to_tensor([], dtype=dtypes.int64)
  concat_outputs = gen_list_ops.tensor_list_concat_v2(
      input_handle=input_handle,
      element_dtype=element_dtype,
      element_shape=shape_arg,
      leading_dims=no_leading_dims,
      name=name)
  # The lengths output of TensorListConcat is dropped here; it is only used
  # during gradient computation.
  return concat_outputs[0]


def tensor_list_split(tensor, element_shape, lengths, name=None):
  """Calls `gen_list_ops.tensor_list_split` with a normalized shape.

  Args:
    tensor: Forwarded unchanged as `tensor`.
    element_shape: Shape spec, normalized through `_build_element_shape`.
    lengths: Forwarded unchanged as `lengths`.
    name: Forwarded unchanged as `name`.

  Returns:
    Whatever `gen_list_ops.tensor_list_split` returns.
  """
  shape_arg = _build_element_shape(element_shape)
  return gen_list_ops.tensor_list_split(
      tensor=tensor, element_shape=shape_arg, lengths=lengths, name=name)


def tensor_list_set_item(input_handle,
                         index,
                         item,
                         resize_if_index_out_of_bounds=False,
                         name=None):
  """Sets `item` at `index` in input list.

  Args:
    input_handle: The list to write into.
    index: Position at which `item` is written.
    item: The value to store.
    resize_if_index_out_of_bounds: When true, the list is first passed through
      a `cond` that resizes it to `index + 1` if `index` is greater than or
      equal to the current list length.
    name: Forwarded as `name` to the final set-item op.

  Returns:
    Whatever `gen_list_ops.tensor_list_set_item` returns.
  """
  if resize_if_index_out_of_bounds:
    current_size = gen_list_ops.tensor_list_length(input_handle)
    original_handle = input_handle

    def _grown_list():
      return gen_list_ops.tensor_list_resize(original_handle, index + 1)

    def _same_list():
      return original_handle

    # TODO(srbs): This could cause some slowdown. Consider fusing resize
    # functionality in the SetItem op.
    input_handle = control_flow_ops.cond(
        index >= current_size, _grown_list, _same_list)
  return gen_list_ops.tensor_list_set_item(
      input_handle=input_handle, index=index, item=item, name=name)


@ops.RegisterGradient("TensorListPushBack")
def _PushBackGrad(op, dresult):
  """Gradient for TensorListPushBack: pops the last element off `dresult`.

  Args:
    op: The forward TensorListPushBack op.
    dresult: Gradient with respect to the op's output list.

  Returns:
    The outputs of `gen_list_ops.tensor_list_pop_back` applied to `dresult`.
  """
  pushed_item = op.inputs[1]
  # The popped element is requested with the shape of the pushed input.
  item_shape = array_ops.shape(pushed_item)
  item_dtype = op.get_attr("element_dtype")
  return gen_list_ops.tensor_list_pop_back(
      dresult, element_shape=item_shape, element_dtype=item_dtype)


@ops.RegisterGradient("TensorListPopBack")
def _PopBackGrad(op, dlist, delement):
  """Gradient for TensorListPopBack: pushes `delement` back onto `dlist`.

  Args:
    op: The forward TensorListPopBack op.
    dlist: Gradient with respect to the output list, or None.
    delement: Gradient with respect to the popped element, or None.

  Returns:
    A pair `(list_grad, None)`: `dlist` with `delement` pushed onto it, and
    None for the second input.
  """
  # Default the element gradient first. Building a default `dlist` below reads
  # `delement.dtype`, which would raise AttributeError if `delement` were still
  # None at that point.
  if delement is None:
    delement = array_ops.zeros_like(op.outputs[1])
  if dlist is None:
    dlist = empty_tensor_list(
        element_dtype=delement.dtype,
        element_shape=gen_list_ops.tensor_list_element_shape(
            op.outputs[0], shape_type=dtypes.int32))
  return gen_list_ops.tensor_list_push_back(dlist, delement), None


@ops.RegisterGradient("TensorListStack")
def _TensorListStackGrad(unused_op, dtensor):
  """Gradient for TensorListStack: rebuilds a list from `dtensor`.

  Args:
    unused_op: The forward op (not used).
    dtensor: Gradient with respect to the stacked tensor.

  Returns:
    A pair `(list_grad, None)`.
  """
  # Element shape is the static shape of `dtensor` without its first dim.
  per_element_shape = dtensor.shape[1:]
  list_grad = tensor_list_from_tensor(dtensor, element_shape=per_element_shape)
  return list_grad, None


@ops.RegisterGradient("TensorListConcat")
@ops.RegisterGradient("TensorListConcatV2")
def _TensorListConcatGrad(op, dtensor, unused_dlengths):
  """Gradient function for TensorListConcat and TensorListConcatV2.

  Args:
    op: The forward concat op.
    dtensor: Gradient with respect to the concatenated tensor.
    unused_dlengths: Gradient with respect to the lengths output (not used).

  Returns:
    `dtensor` split back into a list using the forward op's lengths output.
    For TensorListConcatV2 this is followed by two `None` entries.
  """
  source_list = op.inputs[0]
  source_element_shape = gen_list_ops.tensor_list_element_shape(
      source_list, shape_type=dtypes.int32)
  list_grad = tensor_list_split(
      dtensor, element_shape=source_element_shape, lengths=op.outputs[1])
  if op.type != "TensorListConcatV2":
    return list_grad
  return list_grad, None, None


@ops.RegisterGradient("TensorListSplit")
def _TensorListSplitGrad(op, dlist):
  """Gradient for TensorListSplit: concatenates `dlist` back into a tensor.

  Args:
    op: The forward TensorListSplit op.
    dlist: Gradient with respect to the output list.

  Returns:
    A triple `(tensor_grad, None, None)`.
  """
  source, _, lengths = op.inputs
  # Shape of `source` without its first dim, with -1 prepended in its place.
  trailing_dims = array_ops.slice(array_ops.shape(source), [1], [-1])
  concat_element_shape = array_ops.concat([[-1], trailing_dims], axis=0)
  concat_outputs = gen_list_ops.tensor_list_concat_v2(
      dlist,
      element_shape=concat_element_shape,
      leading_dims=lengths,
      element_dtype=source.dtype)
  return concat_outputs[0], None, None


@ops.RegisterGradient("TensorListFromTensor")
def _TensorListFromTensorGrad(op, dlist):
  """Gradient for TensorListFromTensor.

  Args:
    op: The forward TensorListFromTensor op.
    dlist: Gradient with respect to the output list, or None.

  Returns:
    A pair `(tensor_grad, None)` where `tensor_grad` is `dlist` stacked into
    a tensor.
  """
  source = op.inputs[0]
  static_dims = source.shape.dims
  # Use the static leading dimension as the element count when it is known.
  num_elements = None
  if static_dims and static_dims[0].value is not None:
    num_elements = static_dims[0].value
  if dlist is None:
    output_element_shape = gen_list_ops.tensor_list_element_shape(
        op.outputs[0], shape_type=dtypes.int32)
    dlist = empty_tensor_list(
        element_dtype=source.dtype, element_shape=output_element_shape)
  trailing_dims = array_ops.slice(array_ops.shape(source), [1], [-1])
  tensor_grad = gen_list_ops.tensor_list_stack(
      dlist,
      element_shape=trailing_dims,
      element_dtype=source.dtype,
      num_elements=num_elements)
  return tensor_grad, None


@ops.RegisterGradient("TensorListGetItem")
def _TensorListGetItemGrad(op, ditem):
  """Gradient for TensorListGetItem.

  Args:
    op: The forward TensorListGetItem op.
    ditem: Gradient with respect to the fetched item.

  Returns:
    A triple `(list_grad, None, None)`. `list_grad` is a freshly reserved list
    of the input list's length with `ditem` set at the forward index.
  """
  source_list = op.inputs[0]
  length = gen_list_ops.tensor_list_length(source_list)
  source_element_shape = gen_list_ops.tensor_list_element_shape(
      source_list, shape_type=dtypes.int32)
  reserved = gen_list_ops.tensor_list_reserve(
      source_element_shape, length, element_dtype=ditem.dtype)
  list_grad = gen_list_ops.tensor_list_set_item(
      reserved, index=op.inputs[1], item=ditem)
  return list_grad, None, None


@ops.RegisterGradient("TensorListSetItem")
def _TensorListSetItemGrad(op, dlist):
  """Gradient function for TensorListSetItem.

  Args:
    op: The forward TensorListSetItem op.
    dlist: Gradient with respect to the output list.

  Returns:
    A triple `(list_grad, None, element_grad)`. `list_grad` is `dlist` with
    zeros written at the forward index; `element_grad` is the entry of `dlist`
    at that index.
  """
  _, position, written_item = op.inputs
  zero_item = array_ops.zeros_like(written_item)
  list_grad = gen_list_ops.tensor_list_set_item(
      dlist, index=position, item=zero_item)
  written_shape = array_ops.shape(written_item)
  element_grad = tensor_list_get_item(
      dlist,
      position,
      element_shape=written_shape,
      element_dtype=written_item.dtype)
  return list_grad, None, element_grad


@ops.RegisterGradient("TensorListResize")
def _TensorListResizeGrad(op, dlist):
  """Gradient for TensorListResize: resizes `dlist` to the input list's length.

  Args:
    op: The forward TensorListResize op.
    dlist: Gradient with respect to the resized list.

  Returns:
    A pair `(list_grad, None)`.
  """
  original_list, _ = op.inputs
  original_length = gen_list_ops.tensor_list_length(original_list)
  list_grad = gen_list_ops.tensor_list_resize(dlist, original_length)
  return list_grad, None


@ops.RegisterGradient("TensorListGather")
def _TensorListGatherGrad(op, dtensor):
  """Gradient function for TensorListGather.

  Args:
    op: The forward TensorListGather op.
    dtensor: Gradient with respect to the gathered tensor.

  Returns:
    A triple `(list_grad, None, None)`. `list_grad` is a reserved list of the
    input list's length with `dtensor` scattered at the forward indices.
  """
  source_list, gather_indices, _ = op.inputs
  source_element_shape = gen_list_ops.tensor_list_element_shape(
      source_list, shape_type=dtypes.int32)
  source_length = gen_list_ops.tensor_list_length(source_list)
  reserved = tensor_list_reserve(
      source_element_shape, source_length, dtensor.dtype)
  list_grad = tensor_list_scatter(
      tensor=dtensor, indices=gather_indices, input_handle=reserved)
  return list_grad, None, None


@ops.RegisterGradient("TensorListScatter")
@ops.RegisterGradient("TensorListScatterV2")
def _TensorListScatterGrad(op, dlist):
  """Gradient function for TensorListScatter and TensorListScatterV2.

  Args:
    op: The forward scatter op.
    dlist: Gradient with respect to the output list.

  Returns:
    The tensor gathered from `dlist` at the forward indices, followed by
    `None` entries: three for TensorListScatterV2, two otherwise.
  """
  scattered, scatter_indices = op.inputs[0], op.inputs[1]
  # Gathered elements take the shape of `scattered` without its first dim.
  trailing_dims = array_ops.slice(array_ops.shape(scattered), [1], [-1])
  tensor_grad = gen_list_ops.tensor_list_gather(
      dlist,
      scatter_indices,
      element_shape=trailing_dims,
      element_dtype=scattered.dtype)
  if op.type != "TensorListScatterV2":
    return tensor_grad, None, None
  return tensor_grad, None, None, None


@ops.RegisterGradient("TensorListScatterIntoExistingList")
def _TensorListScatterIntoExistingListGrad(op, dlist):
  """Gradient function for TensorListScatterIntoExistingList.

  Args:
    op: The forward TensorListScatterIntoExistingList op.
    dlist: Gradient with respect to the output list.

  Returns:
    A triple `(list_grad, tensor_grad, None)`. `tensor_grad` holds the entries
    of `dlist` gathered at the forward indices; `list_grad` is `dlist` with
    zeros scattered at those same indices.
  """
  _, tensor, indices = op.inputs
  dtensor = gen_list_ops.tensor_list_gather(
      dlist,
      indices,
      element_shape=array_ops.slice(array_ops.shape(tensor), [1], [-1]),
      element_dtype=tensor.dtype)
  zeros = array_ops.zeros_like(tensor)
  # `element_shape` is deliberately not passed: `tensor_list_scatter` only uses
  # it when no `input_handle` is given.
  dlist = tensor_list_scatter(zeros, indices, input_handle=dlist)
  return dlist, dtensor, None


def _build_element_shape(shape):
  """Normalizes `shape` into a form list ops accept as `element_shape`.

  The cases are handled in this order:

  * A `Tensor` is returned as-is; no type check is performed.
  * `None`, or a `TensorShape` with unknown rank, yields -1.
  * A scalar shape (an empty sequence) yields an int32 tensor built from the
    empty list. The empty list is not returned directly because
    `ops.convert_to_tensor` would convert it to float32, which is not a valid
    type for element_shape.
  * Any other sequence of dims is returned as a list in which each `None` is
    replaced with -1. The dtype of the other dims is not checked.

  Args:
    shape: Could be None, Tensor, TensorShape or a list of dims (each dim could
      be a None, scalar or Tensor).

  Returns:
    A None-free shape that can be converted to a tensor.
  """
  if isinstance(shape, ops.Tensor):
    return shape

  if isinstance(shape, tensor_shape.TensorShape):
    # `TensorShape.as_list` requires rank to be known.
    if shape:
      shape = shape.as_list()
    else:
      shape = None

  if shape is None:
    # Unknown shape.
    return -1

  if not shape:
    # Scalar shape: force int32 for the empty list.
    return ops.convert_to_tensor(shape, dtype=dtypes.int32)

  def _none_free_dim(dim):
    # Unwrap a Dimension to its value, then map a missing value to -1.
    # Tensors and plain scalars pass through untouched.
    if isinstance(dim, tensor_shape.Dimension):
      dim = dim.value
    return -1 if dim is None else dim

  return [_none_free_dim(dim) for dim in shape]