diff --git a/tensorflow/core/ops/list_ops.cc b/tensorflow/core/ops/list_ops.cc index 123ffc493a9..7a0ccb11f1d 100644 --- a/tensorflow/core/ops/list_ops.cc +++ b/tensorflow/core/ops/list_ops.cc @@ -215,7 +215,7 @@ REGISTER_OP("TensorListStack") return errors::InvalidArgument( "Trying to read from list with wrong element dtype. List has " "type ", - DataTypeString(list_shape_type.dtype), " but expectec type ", + DataTypeString(list_shape_type.dtype), " but expected type ", DataTypeString(element_dtype)); } shape_inference::ShapeHandle ignored; @@ -223,6 +223,11 @@ REGISTER_OP("TensorListStack") c->Merge(element_shape, list_shape_type.shape, &ignored)); element_shape = list_shape_type.shape; } + shape_inference::ShapeHandle element_shape_input = c->UnknownShape(); + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape( + 1, &element_shape_input)); + TF_RETURN_IF_ERROR( + c->Merge(element_shape, element_shape_input, &element_shape)); int expected_num_elements = -1; TF_RETURN_IF_ERROR(c->GetAttr("num_elements", &expected_num_elements)); shape_inference::ShapeHandle num_elements; @@ -418,6 +423,11 @@ REGISTER_OP("TensorListGetItem") DataTypeString(list_shape_type.dtype)); } } + shape_inference::ShapeHandle element_shape_input = c->UnknownShape(); + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape( + 2, &element_shape_input)); + TF_RETURN_IF_ERROR( + c->Merge(element_shape, element_shape_input, &element_shape)); c->set_output(0, element_shape); return Status::OK(); }); @@ -486,6 +496,11 @@ REGISTER_OP("TensorListGather") DataTypeString(list_shape_type.dtype)); } } + shape_inference::ShapeHandle element_shape_input = c->UnknownShape(); + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape( + 2, &element_shape_input)); + TF_RETURN_IF_ERROR( + c->Merge(element_shape, element_shape_input, &element_shape)); shape_inference::ShapeHandle out; TF_RETURN_IF_ERROR(c->Concatenate(c->input(1), element_shape, &out)); 
c->set_output(0, out); diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD index 1011dc0835d..d64330e6b93 100644 --- a/tensorflow/python/kernel_tests/BUILD +++ b/tensorflow/python/kernel_tests/BUILD @@ -2652,6 +2652,7 @@ cuda_py_test( "//tensorflow/python:init_ops", "//tensorflow/python:math_ops", "//tensorflow/python:nn_grad", + "//tensorflow/python:tensor_spec", "//tensorflow/python:training", "//tensorflow/python:tensor_array_grad", "//tensorflow/python:tensor_array_ops", @@ -2661,6 +2662,7 @@ cuda_py_test( "//tensorflow/python:while_v2", "//tensorflow/python/eager:backprop", "//tensorflow/python/eager:context", + "//tensorflow/python/eager:def_function", ], flaky = 1, # create_local_cluster sometimes times out. shard_count = 10, diff --git a/tensorflow/python/kernel_tests/list_ops_test.py b/tensorflow/python/kernel_tests/list_ops_test.py index c7e5621b6d1..f5fda3c9a46 100644 --- a/tensorflow/python/kernel_tests/list_ops_test.py +++ b/tensorflow/python/kernel_tests/list_ops_test.py @@ -287,6 +287,10 @@ class ListOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase): element_dtype=dtypes.float32, element_shape=None, num_elements=3) t = gen_list_ops.tensor_list_stack( l, element_dtype=dtypes.float32, element_shape=[]) + if context.executing_eagerly(): + self.assertEqual(t.shape.as_list(), [3]) + else: + self.assertEqual(t.shape.as_list(), [None]) self.assertAllEqual(self.evaluate(t), np.zeros((3,))) @parameterized.named_parameters(("NoMaxNumElements", None), @@ -453,6 +457,7 @@ class ListOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase): element_dtype=dtypes.float32, element_shape=None, num_elements=3) t = gen_list_ops.tensor_list_gather( l, [0, 1, 2], element_dtype=dtypes.float32, element_shape=[]) + self.assertEqual(t.shape.as_list(), [3]) self.assertAllEqual(self.evaluate(t), np.zeros((3,))) def testScatterOutputListSize(self): @@ -607,6 +612,8 @@ class ListOpsTest(test_util.TensorFlowTestCase, 
parameterized.TestCase): l, 0, element_shape=[], element_dtype=dtypes.float32) e1 = gen_list_ops.tensor_list_get_item( l, 1, element_shape=[2, 3], element_dtype=dtypes.float32) + self.assertEqual(e0.shape.as_list(), []) + self.assertEqual(e1.shape.as_list(), [2, 3]) self.assertEqual(self.evaluate(e0), 0.) self.assertAllEqual(self.evaluate(e1), np.zeros((2, 3))) @@ -628,9 +635,16 @@ class ListOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase): l = list_ops.tensor_list_reserve( element_dtype=dtypes.float32, element_shape=[None, 2], num_elements=3) - with self.assertRaisesRegexp( - errors.InvalidArgumentError, - r"Incompatible shapes during merge: \[1,3\] vs. \[\?,2\]"): + + # In eager mode the shape mismatch is caught in the TensorListGetItem + # kernel which raises an InvalidArgumentError. + # In graph mode the shape mismatch is caught in the C++ shape inference + # which raises a ValueError. + if context.executing_eagerly(): + error_type = errors.InvalidArgumentError + else: + error_type = ValueError + with self.assertRaisesRegexp(error_type, r"shapes"): e0 = gen_list_ops.tensor_list_get_item( l, 0, element_dtype=dtypes.float32, element_shape=[1, 3]) self.evaluate(e0) diff --git a/tensorflow/python/kernel_tests/tensor_array_ops_test.py b/tensorflow/python/kernel_tests/tensor_array_ops_test.py index 5bae6c1ffa7..d66357b8380 100644 --- a/tensorflow/python/kernel_tests/tensor_array_ops_test.py +++ b/tensorflow/python/kernel_tests/tensor_array_ops_test.py @@ -24,11 +24,13 @@ from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session as session_lib from tensorflow.python.eager import backprop from tensorflow.python.eager import context +from tensorflow.python.eager import def_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape 
+from tensorflow.python.framework import tensor_spec from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -1345,6 +1347,57 @@ class TensorArrayTest(test.TestCase): grad = gradients_impl.gradients(ys=[r], xs=[x]) self.assertAllEqual(np.array([1.0, 1.0, 1.0]), self.evaluate(grad)[0]) + def testStackShape(self): + + @def_function.function + def ta_stack(): + ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=3) + x = constant_op.constant([1.0, 2.0, 3.0]) + ta = ta.write(0, x) + t = ta.stack() + self.assertEqual(t.shape.as_list(), [None, 3]) + return t + + ta_stack() + + def testReadShape(self): + + @def_function.function + def ta_read(): + ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=3) + x = constant_op.constant([1.0, 2.0, 3.0]) + ta = ta.write(0, x) + t = ta.read(0) + self.assertEqual(t.shape.as_list(), [3]) + return t + + ta_read() + + def testGatherShape(self): + + def ta_gather(indices): + ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=3) + x = constant_op.constant([1.0, 2.0, 3.0]) + ta = ta.write(0, x) + t = ta.gather(indices) + self.assertEqual(t.shape.as_list(), [first_dim, 3]) + return t + + # This propagates shape of `indices` when compiling ta_gather. + ta_gather_with_known_indices_shape = def_function.function(ta_gather) + first_dim = 1 + ta_gather_with_known_indices_shape([0]) + + # Here we force the shape of `indices` to be [None] during ta_gather's + # compilation. 
+ ta_gather_with_unknown_indices_shape = def_function.function( + ta_gather, + input_signature=[ + tensor_spec.TensorSpec(dtype=dtypes.int32, shape=[None]) + ]) + first_dim = None + ta_gather_with_unknown_indices_shape([0]) + def _testTensorArrayEvalEmpty(self): with self.cached_session(use_gpu=True): ta = tensor_array_ops.TensorArray( diff --git a/tensorflow/python/ops/tensor_array_ops.py b/tensorflow/python/ops/tensor_array_ops.py index dd3f9de8899..fb90c122ce1 100644 --- a/tensorflow/python/ops/tensor_array_ops.py +++ b/tensorflow/python/ops/tensor_array_ops.py @@ -577,8 +577,6 @@ class _GraphTensorArrayV2(object): element_dtype=self._dtype, element_shape=element_shape, name=name) - if self._element_shape: - value.set_shape(self._element_shape[0].dims) return value @tf_should_use.should_use_result @@ -610,8 +608,6 @@ class _GraphTensorArrayV2(object): input_handle=self._flow, element_dtype=self._dtype, element_shape=element_shape) - if self._element_shape and self._element_shape[0].dims is not None: - value.set_shape([None] + self._element_shape[0].dims) return value def gather(self, indices, name=None): @@ -626,8 +622,6 @@ class _GraphTensorArrayV2(object): element_dtype=self._dtype, element_shape=element_shape, name=name) - if self._element_shape and self._element_shape[0].dims is not None: - value.set_shape([None] + self._element_shape[0].dims) return value def concat(self, name=None):