Use the element_shape input when doing shape inference for TensorListGetItem, TensorListStack and TensorListGather.

This avoids the need to manually set the shape of the output tensor in TensorArray.read/stack/gather.
This is necessary for shape inference to work correctly in a deserialized graph containing v2 TensorArray ops, e.g. when building the gradient of tf.cond/while_loop.

PiperOrigin-RevId: 243109816
Saurabh Saxena 2019-04-11 11:38:18 -07:00 committed by TensorFlower Gardener
parent 80a47e09c9
commit 826a2450d1
5 changed files with 88 additions and 10 deletions
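
To illustrate the effect of the change, here is a minimal sketch (illustrative only; it uses the public tf.function / tf.TensorArray API and mirrors the assertions in the tests added below): inside a traced function, the static shapes of read/stacked values now come from the ops' element_shape inputs via C++ shape inference rather than from a Python-side set_shape.

    import tensorflow as tf

    @tf.function
    def read_and_stack():
      # TensorArray v2 is backed by TensorList ops when traced into a graph.
      ta = tf.TensorArray(dtype=tf.float32, size=3)
      ta = ta.write(0, tf.constant([1.0, 2.0, 3.0]))
      read = ta.read(0)     # static shape expected to be [3] (see testReadShape)
      stacked = ta.stack()  # static shape expected to be [None, 3] (see testStackShape)
      return read, stacked

    read_and_stack()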

View File

@ -215,7 +215,7 @@ REGISTER_OP("TensorListStack")
return errors::InvalidArgument( return errors::InvalidArgument(
"Trying to read from list with wrong element dtype. List has " "Trying to read from list with wrong element dtype. List has "
"type ", "type ",
DataTypeString(list_shape_type.dtype), " but expectec type ", DataTypeString(list_shape_type.dtype), " but expected type ",
DataTypeString(element_dtype)); DataTypeString(element_dtype));
} }
shape_inference::ShapeHandle ignored; shape_inference::ShapeHandle ignored;
@ -223,6 +223,11 @@ REGISTER_OP("TensorListStack")
c->Merge(element_shape, list_shape_type.shape, &ignored)); c->Merge(element_shape, list_shape_type.shape, &ignored));
element_shape = list_shape_type.shape; element_shape = list_shape_type.shape;
} }
shape_inference::ShapeHandle element_shape_input = c->UnknownShape();
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
1, &element_shape_input));
TF_RETURN_IF_ERROR(
c->Merge(element_shape, element_shape_input, &element_shape));
int expected_num_elements = -1; int expected_num_elements = -1;
TF_RETURN_IF_ERROR(c->GetAttr("num_elements", &expected_num_elements)); TF_RETURN_IF_ERROR(c->GetAttr("num_elements", &expected_num_elements));
shape_inference::ShapeHandle num_elements; shape_inference::ShapeHandle num_elements;
@ -418,6 +423,11 @@ REGISTER_OP("TensorListGetItem")
DataTypeString(list_shape_type.dtype)); DataTypeString(list_shape_type.dtype));
} }
} }
shape_inference::ShapeHandle element_shape_input = c->UnknownShape();
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
2, &element_shape_input));
TF_RETURN_IF_ERROR(
c->Merge(element_shape, element_shape_input, &element_shape));
c->set_output(0, element_shape); c->set_output(0, element_shape);
return Status::OK(); return Status::OK();
}); });
@ -486,6 +496,11 @@ REGISTER_OP("TensorListGather")
DataTypeString(list_shape_type.dtype)); DataTypeString(list_shape_type.dtype));
} }
} }
shape_inference::ShapeHandle element_shape_input = c->UnknownShape();
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
2, &element_shape_input));
TF_RETURN_IF_ERROR(
c->Merge(element_shape, element_shape_input, &element_shape));
shape_inference::ShapeHandle out; shape_inference::ShapeHandle out;
TF_RETURN_IF_ERROR(c->Concatenate(c->input(1), element_shape, &out)); TF_RETURN_IF_ERROR(c->Concatenate(c->input(1), element_shape, &out));
c->set_output(0, out); c->set_output(0, out);

View File

@@ -2652,6 +2652,7 @@ cuda_py_test(
         "//tensorflow/python:init_ops",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:nn_grad",
+        "//tensorflow/python:tensor_spec",
         "//tensorflow/python:training",
         "//tensorflow/python:tensor_array_grad",
         "//tensorflow/python:tensor_array_ops",
@@ -2661,6 +2662,7 @@ cuda_py_test(
         "//tensorflow/python:while_v2",
         "//tensorflow/python/eager:backprop",
         "//tensorflow/python/eager:context",
+        "//tensorflow/python/eager:def_function",
     ],
     flaky = 1,  # create_local_cluster sometimes times out.
     shard_count = 10,

View File

@@ -287,6 +287,10 @@ class ListOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
         element_dtype=dtypes.float32, element_shape=None, num_elements=3)
     t = gen_list_ops.tensor_list_stack(
         l, element_dtype=dtypes.float32, element_shape=[])
+    if context.executing_eagerly():
+      self.assertEqual(t.shape.as_list(), [3])
+    else:
+      self.assertEqual(t.shape.as_list(), [None])
     self.assertAllEqual(self.evaluate(t), np.zeros((3,)))

   @parameterized.named_parameters(("NoMaxNumElements", None),
@@ -453,6 +457,7 @@ class ListOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
         element_dtype=dtypes.float32, element_shape=None, num_elements=3)
     t = gen_list_ops.tensor_list_gather(
         l, [0, 1, 2], element_dtype=dtypes.float32, element_shape=[])
+    self.assertEqual(t.shape.as_list(), [3])
     self.assertAllEqual(self.evaluate(t), np.zeros((3,)))

   def testScatterOutputListSize(self):
@@ -607,6 +612,8 @@ class ListOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
         l, 0, element_shape=[], element_dtype=dtypes.float32)
     e1 = gen_list_ops.tensor_list_get_item(
         l, 1, element_shape=[2, 3], element_dtype=dtypes.float32)
+    self.assertEqual(e0.shape.as_list(), [])
+    self.assertEqual(e1.shape.as_list(), [2, 3])
     self.assertEqual(self.evaluate(e0), 0.)
     self.assertAllEqual(self.evaluate(e1), np.zeros((2, 3)))
@@ -628,9 +635,16 @@ class ListOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
     l = list_ops.tensor_list_reserve(
         element_dtype=dtypes.float32, element_shape=[None, 2], num_elements=3)
-    with self.assertRaisesRegexp(
-        errors.InvalidArgumentError,
-        r"Incompatible shapes during merge: \[1,3\] vs. \[\?,2\]"):
+    # In eager mode the shape mismatch is caught in the TensorListGetItem
+    # kernel which raises an InvalidArgumentError.
+    # In graph mode the shape mismatch is caught in the C++ shape inference
+    # which raises a ValueError.
+    if context.executing_eagerly():
+      error_type = errors.InvalidArgumentError
+    else:
+      error_type = ValueError
+    with self.assertRaisesRegexp(error_type, r"shapes"):
       e0 = gen_list_ops.tensor_list_get_item(
           l, 0, element_dtype=dtypes.float32, element_shape=[1, 3])
       self.evaluate(e0)

View File

@@ -24,11 +24,13 @@ from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.client import session as session_lib
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
+from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
@@ -1345,6 +1347,57 @@ class TensorArrayTest(test.TestCase):
       grad = gradients_impl.gradients(ys=[r], xs=[x])
       self.assertAllEqual(np.array([1.0, 1.0, 1.0]), self.evaluate(grad)[0])

+  def testStackShape(self):
+
+    @def_function.function
+    def ta_stack():
+      ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=3)
+      x = constant_op.constant([1.0, 2.0, 3.0])
+      ta = ta.write(0, x)
+      t = ta.stack()
+      self.assertEqual(t.shape.as_list(), [None, 3])
+      return t
+
+    ta_stack()
+
+  def testReadShape(self):
+
+    @def_function.function
+    def ta_read():
+      ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=3)
+      x = constant_op.constant([1.0, 2.0, 3.0])
+      ta = ta.write(0, x)
+      t = ta.read(0)
+      self.assertEqual(t.shape.as_list(), [3])
+      return t
+
+    ta_read()
+
+  def testGatherShape(self):
+
+    def ta_gather(indices):
+      ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=3)
+      x = constant_op.constant([1.0, 2.0, 3.0])
+      ta = ta.write(0, x)
+      t = ta.gather(indices)
+      self.assertEqual(t.shape.as_list(), [first_dim, 3])
+      return t
+
+    # This propagates the shape of `indices` when compiling ta_gather.
+    ta_gather_with_known_indices_shape = def_function.function(ta_gather)
+    first_dim = 1
+    ta_gather_with_known_indices_shape([0])
+
+    # Here we force the shape of `indices` to be [None] during ta_gather's
+    # compilation.
+    ta_gather_with_unknown_indices_shape = def_function.function(
+        ta_gather,
+        input_signature=[
+            tensor_spec.TensorSpec(dtype=dtypes.int32, shape=[None])
+        ])
+    first_dim = None
+    ta_gather_with_unknown_indices_shape([0])
+
   def _testTensorArrayEvalEmpty(self):
     with self.cached_session(use_gpu=True):
       ta = tensor_array_ops.TensorArray(

View File

@@ -577,8 +577,6 @@ class _GraphTensorArrayV2(object):
           element_dtype=self._dtype,
           element_shape=element_shape,
           name=name)
-      if self._element_shape:
-        value.set_shape(self._element_shape[0].dims)
       return value

   @tf_should_use.should_use_result
@@ -610,8 +608,6 @@ class _GraphTensorArrayV2(object):
           input_handle=self._flow,
           element_dtype=self._dtype,
           element_shape=element_shape)
-      if self._element_shape and self._element_shape[0].dims is not None:
-        value.set_shape([None] + self._element_shape[0].dims)
       return value

   def gather(self, indices, name=None):
@@ -626,8 +622,6 @@ class _GraphTensorArrayV2(object):
           element_dtype=self._dtype,
           element_shape=element_shape,
           name=name)
-      if self._element_shape and self._element_shape[0].dims is not None:
-        value.set_shape([None] + self._element_shape[0].dims)
       return value

   def concat(self, name=None):