Allows TensorListScatter to scatter at non-contiguous indices to make it consistent with TensorArray.scatter.

PiperOrigin-RevId: 227874653
This commit is contained in:
Saurabh Saxena 2019-01-04 10:17:31 -08:00 committed by TensorFlower Gardener
parent 07816cc1e1
commit daa75aff18
4 changed files with 77 additions and 14 deletions

View File

@ -521,14 +521,31 @@ class TensorListScatter : public OpKernel {
"Specified a list with shape ", element_shape.DebugString(),
" from a tensor with shape ", output_shape.DebugString()));
output_list.element_shape = element_shape;
output_list.tensors.reserve(indices.NumElements());
OP_REQUIRES(c, indices.NumElements() == input_tensor.shape().dim_size(0),
errors::InvalidArgument(
"Invalid number of rows in input tensor. Expected: ",
indices.NumElements(),
" Actual: ", input_tensor.shape().dim_size(0)));
// Validate indices and resize output_list.tensors to fit the highest index.
{
size_t list_size = 0;
for (int index = 0; index < indices.NumElements(); ++index) {
const int i = indices.flat<int32>()(index);
OP_REQUIRES(c, i >= 0,
errors::InvalidArgument(
"Indices in TensorListScatter must all be positive."));
if (i >= list_size) {
list_size = i + 1;
}
}
output_list.tensors.resize(list_size, Tensor(DT_INVALID));
}
for (int index = 0; index < indices.NumElements(); ++index) {
const int i = indices.flat<int32>()(index);
OP_REQUIRES(c, i < input_tensor.shape().dim_size(0),
errors::InvalidArgument(
"Trying to scatter index ", i, " from tensor with ",
input_tensor.shape().dim_size(0), " rows."));
Tensor tmp = input_tensor.Slice(i, i + 1);
Tensor tmp = input_tensor.Slice(index, index + 1);
TensorShape tmp_shape = tmp.shape();
tmp_shape.RemoveDim(0);
OP_REQUIRES(c, tmp.CopyFrom(tmp, tmp_shape),
@ -541,7 +558,7 @@ class TensorListScatter : public OpKernel {
// many small ones.
aligned.flat<T>().device(c->eigen_device<Device>()) =
tmp.unaligned_flat<T>();
output_list.tensors.push_back(aligned);
std::swap(output_list.tensors[i], aligned);
}
output_tensor->scalar<Variant>()() = std::move(output_list);
}

View File

@ -290,6 +290,47 @@ class ListOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
t = list_ops.tensor_list_gather(l, [], element_dtype=dtypes.float32)
self.evaluate(t)
def testGatherGradWithNonContiguousIndices(self):
# Verifies that the gradient of tensor_list_gather works when gathering a
# single, non-contiguous index (here index 1 of a length-3 list).
with backprop.GradientTape(persistent=True) as tape:
t = constant_op.constant([1.0, 2.0, 3.0])
l = list_ops.tensor_list_from_tensor(t, element_shape=[])
c = constant_op.constant(5.0)
tape.watch(c)
# Overwrite element 1 with the watched scalar so a gradient flows to it.
l = list_ops.tensor_list_set_item(l, 1, c)
t = list_ops.tensor_list_gather(l, [1], element_dtype=dtypes.float32)
self.assertAllEqual(self.evaluate(t), [5.0])
# s = c * c, so ds/dc = 2 * c = 10.0.
s = t[0] * t[0]
dt = tape.gradient(s, c)
self.assertAllEqual(self.evaluate(dt), 10.0)
# The gradient list must have the same length as the forward list (3),
# not just the number of gathered indices.
dl = tape.gradient(t, l)
dl_length = list_ops.tensor_list_length(dl)
self.assertAllEqual(self.evaluate(dl_length), 3)
def testScatterOutputListSize(self):
# Scattering rows [1.0, 2.0] to indices [1, 3] must produce a list sized
# to fit the largest index, i.e. max(indices) + 1 = 4.
c0 = constant_op.constant([1.0, 2.0])
l = list_ops.tensor_list_scatter(
c0, [1, 3], ops.convert_to_tensor([], dtype=dtypes.int32))
# TensorListScatter should return a list with size largest index + 1.
self.assertEqual(self.evaluate(list_ops.tensor_list_length(l)), 4)
def testScatterWithInvalidRowsInInputTensorFails(self):
# The number of indices (3) must match the number of rows in the input
# tensor (2); a mismatch is an InvalidArgumentError.
c0 = constant_op.constant([1.0, 2.0])
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
"Invalid number of rows in input tensor. Expected: 3 Actual: 2"):
l = list_ops.tensor_list_scatter(
c0, [1, 0, 2], ops.convert_to_tensor([], dtype=dtypes.int32))
self.evaluate(l)
def testScatterWithNegativeIndicesFails(self):
# Negative scatter indices are rejected by the kernel.
# NOTE(review): the kernel's check is `i >= 0`, so the error message's
# "positive" actually means "non-negative" — wording worth confirming
# upstream; the regexp here must match the kernel's exact string.
c0 = constant_op.constant([1.0, 2.0])
with self.assertRaisesRegexp(
errors.InvalidArgumentError,
"Indices in TensorListScatter must all be positive."):
l = list_ops.tensor_list_scatter(
c0, [-1, -2], ops.convert_to_tensor([], dtype=dtypes.int32))
self.evaluate(l)
def testScatterGrad(self):
with backprop.GradientTape() as tape:
c0 = constant_op.constant([1.0, 2.0])

View File

@ -1359,7 +1359,6 @@ class TensorArrayTest(test.TestCase):
def testSkipEagerTensorArrayEvalEmptyWithDefault(self):
self._testTensorArrayEvalEmptyWithDefault()
@test_util.disable_control_flow_v2("b/117943286")
@test_util.run_v1_only("b/117943489")
def testSkipEagerTensorArrayScatterReadAndGradients(self):
with self.session(use_gpu=True) as session:
@ -1387,8 +1386,8 @@ class TensorArrayTest(test.TestCase):
self.assertAllEqual([10.0, -10.0], read_vals[1])
self.assertAllEqual([[2.0, 3.0], [4.0, 5.0]], grad_vals[0])
@test_util.disable_control_flow_v2("b/117943286")
@test_util.run_v1_only("b/117943286")
@test_util.disable_control_flow_v2("b/118890905")
@test_util.run_v1_only("b/118890905")
def testTensorArrayWriteGatherAndGradients(self):
with self.session(use_gpu=True) as session:
ta = tensor_array_ops.TensorArray(

View File

@ -200,10 +200,16 @@ def _TensorListResizeGrad(op, dlist):
@ops.RegisterGradient("TensorListGather")
def _TensorListGatherGrad(op, dtensor):
# NOTE(review): this is a rendered diff, so the lines below interleave the
# pre-change body with the post-change body without +/- markers.
# Pre-change body (removed): scatter the incoming gradient and return it
# immediately.
_, indices = op.inputs
return gen_list_ops.tensor_list_scatter(
tensor=dtensor, indices=indices,
element_shape=ops.convert_to_tensor(-1, dtype=dtypes.int32)), None
# Post-change body (added): also capture the forward input list so the
# gradient list can be resized to match its length.
input_list, indices = op.inputs
dlist = gen_list_ops.tensor_list_scatter(
tensor=dtensor,
indices=indices,
element_shape=ops.convert_to_tensor(-1, dtype=dtypes.int32))
# TensorListScatter returns a list with size `max(indices) + 1`
# so we manually resize it to match the size of the input list.
input_list_size = gen_list_ops.tensor_list_length(input_list)
dlist = gen_list_ops.tensor_list_resize(dlist, input_list_size)
return dlist, None
@ops.RegisterGradient("TensorListScatter")