STT-tensorflow/tensorflow/python/ops/list_ops.py
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops to manipulate lists of tensors."""
# pylint: disable=g-bad-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_list_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_list_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.util.lazy_loader import LazyLoader
# Lazy-load control_flow_ops to break the circular dependency:
# list_ops -> control_flow_ops -> tensor_array_ops -> list_ops
control_flow_ops = LazyLoader(
"control_flow_ops", globals(),
"tensorflow.python.ops.control_flow_ops")
ops.NotDifferentiable("TensorListConcatLists")
ops.NotDifferentiable("TensorListElementShape")
ops.NotDifferentiable("TensorListLength")
ops.NotDifferentiable("TensorListPushBackBatch")
def empty_tensor_list(element_shape,
element_dtype,
max_num_elements=None,
name=None):
  """Creates an empty `TensorList` with the given element dtype and shape."""
  if max_num_elements is None:
    # -1 signals that the number of elements in the list is unbounded.
    max_num_elements = -1
return gen_list_ops.empty_tensor_list(
element_shape=_build_element_shape(element_shape),
element_dtype=element_dtype,
max_num_elements=max_num_elements,
name=name)
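
# Example usage (a minimal sketch, assuming eager execution; TensorList ops
# operate on variant handles rather than ordinary tensors):
#
#   l = empty_tensor_list(element_shape=[], element_dtype=dtypes.float32)
#   l = gen_list_ops.tensor_list_push_back(l, 1.0)
#   l = gen_list_ops.tensor_list_push_back(l, 2.0)
#   t = tensor_list_stack(l, element_dtype=dtypes.float32)  # -> [1.0, 2.0]
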
def tensor_list_reserve(element_shape, num_elements, element_dtype, name=None):
return gen_list_ops.tensor_list_reserve(
element_shape=_build_element_shape(element_shape),
num_elements=num_elements,
element_dtype=element_dtype,
name=name)
def tensor_list_from_tensor(tensor, element_shape, name=None):
return gen_list_ops.tensor_list_from_tensor(
tensor=tensor,
element_shape=_build_element_shape(element_shape),
name=name)
def tensor_list_get_item(input_handle, index, element_dtype, element_shape=None,
name=None):
return gen_list_ops.tensor_list_get_item(
input_handle=input_handle,
index=index,
element_shape=_build_element_shape(element_shape),
element_dtype=element_dtype,
name=name)
def tensor_list_pop_back(input_handle, element_dtype, name=None):
return gen_list_ops.tensor_list_pop_back(
input_handle=input_handle,
element_shape=-1,
element_dtype=element_dtype,
name=name)
def tensor_list_gather(input_handle,
indices,
element_dtype,
element_shape=None,
name=None):
return gen_list_ops.tensor_list_gather(
input_handle=input_handle,
indices=indices,
element_shape=_build_element_shape(element_shape),
element_dtype=element_dtype,
name=name)
def tensor_list_scatter(tensor,
indices,
element_shape=None,
input_handle=None,
name=None):
  """Scatters rows of `tensor` into a new or an existing list at `indices`."""
  if input_handle is not None:
return gen_list_ops.tensor_list_scatter_into_existing_list(
input_handle=input_handle, tensor=tensor, indices=indices, name=name)
else:
return gen_list_ops.tensor_list_scatter_v2(
tensor=tensor,
indices=indices,
element_shape=_build_element_shape(element_shape),
num_elements=-1,
name=name)
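
# Example (a hedged sketch): with no `input_handle` a new list is created,
# otherwise the rows are written into the existing list at the given positions.
#
#   t = ops.convert_to_tensor([[1., 2.], [3., 4.]])
#   l = tensor_list_scatter(t, indices=[1, 0], element_shape=[2])
#   # element 0 of `l` is now [3., 4.] and element 1 is [1., 2.]
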
def tensor_list_stack(input_handle,
element_dtype,
num_elements=-1,
element_shape=None,
name=None):
return gen_list_ops.tensor_list_stack(
input_handle=input_handle,
element_shape=_build_element_shape(element_shape),
element_dtype=element_dtype,
num_elements=num_elements,
name=name)
def tensor_list_concat(input_handle, element_dtype, element_shape=None,
name=None):
# Ignore the lengths output of TensorListConcat. It is only used during
# gradient computation.
return gen_list_ops.tensor_list_concat_v2(
input_handle=input_handle,
element_dtype=element_dtype,
element_shape=_build_element_shape(element_shape),
leading_dims=ops.convert_to_tensor([], dtype=dtypes.int64),
name=name)[0]
def tensor_list_split(tensor, element_shape, lengths, name=None):
return gen_list_ops.tensor_list_split(
tensor=tensor,
element_shape=_build_element_shape(element_shape),
lengths=lengths,
name=name)
def tensor_list_set_item(input_handle,
index,
item,
resize_if_index_out_of_bounds=False,
name=None):
"""Sets `item` at `index` in input list."""
if resize_if_index_out_of_bounds:
input_list_size = gen_list_ops.tensor_list_length(input_handle)
# TODO(srbs): This could cause some slowdown. Consider fusing resize
# functionality in the SetItem op.
input_handle = control_flow_ops.cond(
index >= input_list_size,
lambda: gen_list_ops.tensor_list_resize( # pylint: disable=g-long-lambda
input_handle, index + 1),
lambda: input_handle)
return gen_list_ops.tensor_list_set_item(
input_handle=input_handle, index=index, item=item, name=name)
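
# Example (a minimal sketch): writing past the end of a reserved list.
#
#   l = tensor_list_reserve(element_shape=[], num_elements=2,
#                           element_dtype=dtypes.float32)
#   l = tensor_list_set_item(l, index=4, item=3.0,
#                            resize_if_index_out_of_bounds=True)
#   # The list now has length 5, with element 4 set to 3.0.
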
@ops.RegisterGradient("TensorListPushBack")
def _PushBackGrad(op, dresult):
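  # Popping from the gradient list undoes the push: the remaining list is the
  # gradient for the input list and the popped element is the gradient for
  # the pushed tensor.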
return gen_list_ops.tensor_list_pop_back(
dresult,
element_shape=array_ops.shape(op.inputs[1]),
element_dtype=op.get_attr("element_dtype"))
@ops.RegisterGradient("TensorListPopBack")
def _PopBackGrad(op, dlist, delement):
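  # `dlist` is None when no gradient flowed back through the output list
  # handle (e.g. only the popped element was used downstream). Rebuild an
  # empty list with the same element shape so the element gradient can be
  # pushed back onto it.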
if dlist is None:
dlist = empty_tensor_list(
element_dtype=delement.dtype,
element_shape=gen_list_ops.tensor_list_element_shape(
op.outputs[0], shape_type=dtypes.int32))
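  # Similarly, `delement` may be None when the popped element received no
  # gradient (see https://github.com/tensorflow/tensorflow/issues/37230);
  # substitute zeros with the popped element's shape.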
if delement is None:
delement = array_ops.zeros_like(op.outputs[1])
return gen_list_ops.tensor_list_push_back(dlist, delement), None
@ops.RegisterGradient("TensorListStack")
def _TensorListStackGrad(unused_op, dtensor):
return tensor_list_from_tensor(dtensor, element_shape=dtensor.shape[1:]), None
@ops.RegisterGradient("TensorListConcat")
@ops.RegisterGradient("TensorListConcatV2")
def _TensorListConcatGrad(op, dtensor, unused_dlengths):
"""Gradient function for TensorListConcat."""
dlist = tensor_list_split(
dtensor,
element_shape=gen_list_ops.tensor_list_element_shape(
op.inputs[0], shape_type=dtypes.int32),
lengths=op.outputs[1])
if op.type == "TensorListConcatV2":
return dlist, None, None
else:
return dlist
@ops.RegisterGradient("TensorListSplit")
def _TensorListSplitGrad(op, dlist):
tensor, _, lengths = op.inputs
element_shape = array_ops.slice(array_ops.shape(tensor), [1], [-1])
element_shape = array_ops.concat([[-1], element_shape], axis=0)
return gen_list_ops.tensor_list_concat_v2(
dlist,
element_shape=element_shape,
leading_dims=lengths,
element_dtype=op.inputs[0].dtype)[0], None, None
@ops.RegisterGradient("TensorListFromTensor")
def _TensorListFromTensorGrad(op, dlist):
"""Gradient for TensorListFromTensor."""
t = op.inputs[0]
if t.shape.dims and t.shape.dims[0].value is not None:
num_elements = t.shape.dims[0].value
else:
num_elements = None
if dlist is None:
dlist = empty_tensor_list(
element_dtype=t.dtype,
element_shape=gen_list_ops.tensor_list_element_shape(
op.outputs[0], shape_type=dtypes.int32))
tensor_grad = gen_list_ops.tensor_list_stack(
dlist,
element_shape=array_ops.slice(array_ops.shape(t), [1], [-1]),
element_dtype=t.dtype,
num_elements=num_elements)
shape_grad = None
return tensor_grad, shape_grad
@ops.RegisterGradient("TensorListGetItem")
def _TensorListGetItemGrad(op, ditem):
"""Gradient for TensorListGetItem."""
list_size = gen_list_ops.tensor_list_length(op.inputs[0])
list_grad = gen_list_ops.tensor_list_set_item(
gen_list_ops.tensor_list_reserve(
gen_list_ops.tensor_list_element_shape(op.inputs[0],
shape_type=dtypes.int32),
list_size, element_dtype=ditem.dtype),
index=op.inputs[1],
item=ditem)
index_grad = None
element_shape_grad = None
return list_grad, index_grad, element_shape_grad
@ops.RegisterGradient("TensorListSetItem")
def _TensorListSetItemGrad(op, dlist):
"""Gradient function for TensorListSetItem."""
_, index, item = op.inputs
list_grad = gen_list_ops.tensor_list_set_item(
dlist, index=index, item=array_ops.zeros_like(item))
index_grad = None
element_grad = tensor_list_get_item(
dlist,
index,
element_shape=array_ops.shape(item),
element_dtype=item.dtype)
return list_grad, index_grad, element_grad
@ops.RegisterGradient("TensorListResize")
def _TensorListResizeGrad(op, dlist):
input_list, _ = op.inputs
input_list_size = gen_list_ops.tensor_list_length(input_list)
return gen_list_ops.tensor_list_resize(dlist, input_list_size), None
@ops.RegisterGradient("TensorListGather")
def _TensorListGatherGrad(op, dtensor):
"""Gradient function for TensorListGather."""
input_list, indices, _ = op.inputs
element_shape = gen_list_ops.tensor_list_element_shape(
input_list, shape_type=dtypes.int32)
num_elements = gen_list_ops.tensor_list_length(input_list)
dlist = tensor_list_reserve(element_shape, num_elements, dtensor.dtype)
dlist = tensor_list_scatter(
tensor=dtensor, indices=indices, input_handle=dlist)
return dlist, None, None
@ops.RegisterGradient("TensorListScatter")
@ops.RegisterGradient("TensorListScatterV2")
def _TensorListScatterGrad(op, dlist):
"""Gradient function for TensorListScatter."""
tensor = op.inputs[0]
indices = op.inputs[1]
dtensor = gen_list_ops.tensor_list_gather(
dlist,
indices,
element_shape=array_ops.slice(array_ops.shape(tensor), [1], [-1]),
element_dtype=tensor.dtype)
if op.type == "TensorListScatterV2":
return dtensor, None, None, None
else:
return dtensor, None, None
@ops.RegisterGradient("TensorListScatterIntoExistingList")
def _TensorListScatterIntoExistingListGrad(op, dlist):
"""Gradient function for TensorListScatterIntoExistingList."""
_, tensor, indices = op.inputs
dtensor = gen_list_ops.tensor_list_gather(
dlist,
indices,
element_shape=array_ops.slice(array_ops.shape(tensor), [1], [-1]),
element_dtype=tensor.dtype)
zeros = array_ops.zeros_like(tensor)
  dlist = tensor_list_scatter(zeros, indices, input_handle=dlist)
return dlist, dtensor, None
def _build_element_shape(shape):
"""Converts shape to a format understood by list_ops for element_shape.
If `shape` is already a `Tensor` it is returned as-is. We do not perform a
type check here.
If shape is None or a TensorShape with unknown rank, -1 is returned.
If shape is a scalar, an int32 tensor with empty list is returned. Note we
do directly return an empty list since ops.convert_to_tensor would conver it
to a float32 which is not a valid type for element_shape.
If shape is a sequence of dims, None's in the list are replaced with -1. We
do not check the dtype of the other dims.
Args:
shape: Could be None, Tensor, TensorShape or a list of dims (each dim could
be a None, scalar or Tensor).
Returns:
A None-free shape that can be converted to a tensor.
"""
if isinstance(shape, ops.Tensor):
return shape
if isinstance(shape, tensor_shape.TensorShape):
# `TensorShape.as_list` requires rank to be known.
shape = shape.as_list() if shape else None
# Shape is unknown.
if shape is None:
return -1
# Shape is a scalar.
if not shape:
return ops.convert_to_tensor(shape, dtype=dtypes.int32)
# Shape is a sequence of dimensions. Convert None dims to -1.
def convert(val):
if val is None:
return -1
if isinstance(val, ops.Tensor):
return val
if isinstance(val, tensor_shape.Dimension):
return val.value if val.value is not None else -1
return val
return [convert(d) for d in shape]
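
# A few illustrative conversions performed by _build_element_shape (a sketch,
# not exhaustive):
#
#   _build_element_shape(None)                           -> -1
#   _build_element_shape(tensor_shape.TensorShape(None)) -> -1
#   _build_element_shape([])                             -> int32 tensor []
#   _build_element_shape([None, 2])                      -> [-1, 2]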