Allows TensorListScatter to scatter at non-contiguous indices to make it consistent with TensorArray.scatter.
PiperOrigin-RevId: 227874653
This commit is contained in:
parent
07816cc1e1
commit
daa75aff18
@ -521,14 +521,31 @@ class TensorListScatter : public OpKernel {
|
||||
"Specified a list with shape ", element_shape.DebugString(),
|
||||
" from a tensor with shape ", output_shape.DebugString()));
|
||||
output_list.element_shape = element_shape;
|
||||
output_list.tensors.reserve(indices.NumElements());
|
||||
|
||||
OP_REQUIRES(c, indices.NumElements() == input_tensor.shape().dim_size(0),
|
||||
errors::InvalidArgument(
|
||||
"Invalid number of rows in input tensor. Expected: ",
|
||||
indices.NumElements(),
|
||||
" Actual: ", input_tensor.shape().dim_size(0)));
|
||||
|
||||
// Validate indices and resize output_list.tensors to fit the highest index.
|
||||
{
|
||||
size_t list_size = 0;
|
||||
for (int index = 0; index < indices.NumElements(); ++index) {
|
||||
const int i = indices.flat<int32>()(index);
|
||||
OP_REQUIRES(c, i >= 0,
|
||||
errors::InvalidArgument(
|
||||
"Indices in TensorListScatter must all be positive."));
|
||||
if (i >= list_size) {
|
||||
list_size = i + 1;
|
||||
}
|
||||
}
|
||||
output_list.tensors.resize(list_size, Tensor(DT_INVALID));
|
||||
}
|
||||
|
||||
for (int index = 0; index < indices.NumElements(); ++index) {
|
||||
const int i = indices.flat<int32>()(index);
|
||||
OP_REQUIRES(c, i < input_tensor.shape().dim_size(0),
|
||||
errors::InvalidArgument(
|
||||
"Trying to scatter index ", i, " from tensor with ",
|
||||
input_tensor.shape().dim_size(0), " rows."));
|
||||
Tensor tmp = input_tensor.Slice(i, i + 1);
|
||||
Tensor tmp = input_tensor.Slice(index, index + 1);
|
||||
TensorShape tmp_shape = tmp.shape();
|
||||
tmp_shape.RemoveDim(0);
|
||||
OP_REQUIRES(c, tmp.CopyFrom(tmp, tmp_shape),
|
||||
@ -541,7 +558,7 @@ class TensorListScatter : public OpKernel {
|
||||
// many small ondes.
|
||||
aligned.flat<T>().device(c->eigen_device<Device>()) =
|
||||
tmp.unaligned_flat<T>();
|
||||
output_list.tensors.push_back(aligned);
|
||||
std::swap(output_list.tensors[i], aligned);
|
||||
}
|
||||
output_tensor->scalar<Variant>()() = std::move(output_list);
|
||||
}
|
||||
|
@ -290,6 +290,47 @@ class ListOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
t = list_ops.tensor_list_gather(l, [], element_dtype=dtypes.float32)
|
||||
self.evaluate(t)
|
||||
|
||||
def testGatherGradWithNonContiguousIndices(self):
|
||||
with backprop.GradientTape(persistent=True) as tape:
|
||||
t = constant_op.constant([1.0, 2.0, 3.0])
|
||||
l = list_ops.tensor_list_from_tensor(t, element_shape=[])
|
||||
c = constant_op.constant(5.0)
|
||||
tape.watch(c)
|
||||
l = list_ops.tensor_list_set_item(l, 1, c)
|
||||
t = list_ops.tensor_list_gather(l, [1], element_dtype=dtypes.float32)
|
||||
self.assertAllEqual(self.evaluate(t), [5.0])
|
||||
s = t[0] * t[0]
|
||||
dt = tape.gradient(s, c)
|
||||
self.assertAllEqual(self.evaluate(dt), 10.0)
|
||||
dl = tape.gradient(t, l)
|
||||
dl_length = list_ops.tensor_list_length(dl)
|
||||
self.assertAllEqual(self.evaluate(dl_length), 3)
|
||||
|
||||
def testScatterOutputListSize(self):
|
||||
c0 = constant_op.constant([1.0, 2.0])
|
||||
l = list_ops.tensor_list_scatter(
|
||||
c0, [1, 3], ops.convert_to_tensor([], dtype=dtypes.int32))
|
||||
# TensorListScatter should return a list with size largest index + 1.
|
||||
self.assertEqual(self.evaluate(list_ops.tensor_list_length(l)), 4)
|
||||
|
||||
def testScatterWithInvalidRowsInInputTensorFails(self):
|
||||
c0 = constant_op.constant([1.0, 2.0])
|
||||
with self.assertRaisesRegexp(
|
||||
errors.InvalidArgumentError,
|
||||
"Invalid number of rows in input tensor. Expected: 3 Actual: 2"):
|
||||
l = list_ops.tensor_list_scatter(
|
||||
c0, [1, 0, 2], ops.convert_to_tensor([], dtype=dtypes.int32))
|
||||
self.evaluate(l)
|
||||
|
||||
def testScatterWithNegativeIndicesFails(self):
|
||||
c0 = constant_op.constant([1.0, 2.0])
|
||||
with self.assertRaisesRegexp(
|
||||
errors.InvalidArgumentError,
|
||||
"Indices in TensorListScatter must all be positive."):
|
||||
l = list_ops.tensor_list_scatter(
|
||||
c0, [-1, -2], ops.convert_to_tensor([], dtype=dtypes.int32))
|
||||
self.evaluate(l)
|
||||
|
||||
def testScatterGrad(self):
|
||||
with backprop.GradientTape() as tape:
|
||||
c0 = constant_op.constant([1.0, 2.0])
|
||||
|
@ -1359,7 +1359,6 @@ class TensorArrayTest(test.TestCase):
|
||||
def testSkipEagerTensorArrayEvalEmptyWithDefault(self):
|
||||
self._testTensorArrayEvalEmptyWithDefault()
|
||||
|
||||
@test_util.disable_control_flow_v2("b/117943286")
|
||||
@test_util.run_v1_only("b/117943489")
|
||||
def testSkipEagerTensorArrayScatterReadAndGradients(self):
|
||||
with self.session(use_gpu=True) as session:
|
||||
@ -1387,8 +1386,8 @@ class TensorArrayTest(test.TestCase):
|
||||
self.assertAllEqual([10.0, -10.0], read_vals[1])
|
||||
self.assertAllEqual([[2.0, 3.0], [4.0, 5.0]], grad_vals[0])
|
||||
|
||||
@test_util.disable_control_flow_v2("b/117943286")
|
||||
@test_util.run_v1_only("b/117943286")
|
||||
@test_util.disable_control_flow_v2("b/118890905")
|
||||
@test_util.run_v1_only("b/118890905")
|
||||
def testTensorArrayWriteGatherAndGradients(self):
|
||||
with self.session(use_gpu=True) as session:
|
||||
ta = tensor_array_ops.TensorArray(
|
||||
|
@ -200,10 +200,16 @@ def _TensorListResizeGrad(op, dlist):
|
||||
|
||||
@ops.RegisterGradient("TensorListGather")
|
||||
def _TensorListGatherGrad(op, dtensor):
|
||||
_, indices = op.inputs
|
||||
return gen_list_ops.tensor_list_scatter(
|
||||
tensor=dtensor, indices=indices,
|
||||
element_shape=ops.convert_to_tensor(-1, dtype=dtypes.int32)), None
|
||||
input_list, indices = op.inputs
|
||||
dlist = gen_list_ops.tensor_list_scatter(
|
||||
tensor=dtensor,
|
||||
indices=indices,
|
||||
element_shape=ops.convert_to_tensor(-1, dtype=dtypes.int32))
|
||||
# TensorListScatter returns a list with size `max(indices) + 1`
|
||||
# so we manually resize it to match the size of the input list.
|
||||
input_list_size = gen_list_ops.tensor_list_length(input_list)
|
||||
dlist = gen_list_ops.tensor_list_resize(dlist, input_list_size)
|
||||
return dlist, None
|
||||
|
||||
|
||||
@ops.RegisterGradient("TensorListScatter")
|
||||
|
Loading…
Reference in New Issue
Block a user