Use the element_shape input when doing shape inference for TensorListGetItem, TensorListStack and TensorListGather.
This avoids the need to manually set the shape of the output tensor in TensorArray.read/stack/gather, and is necessary for shape inference to work correctly in a deserialized graph containing v2 TensorArray ops, e.g. when building the gradient of tf.cond/while_loop.

PiperOrigin-RevId: 243109816
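A minimal sketch of the effect, modeled on the testStackShape test added below. It uses the public TensorFlow API (tf.TensorArray, tf.function) rather than the internal test modules, and the function name stack_shape_demo is illustrative only; the expected shape mirrors the assertion in the new test.

import tensorflow as tf

@tf.function
def stack_shape_demo():
  ta = tf.TensorArray(dtype=tf.float32, size=3)
  ta = ta.write(0, tf.constant([1.0, 2.0, 3.0]))
  t = ta.stack()
  # With the element_shape input now used by TensorListStack's shape
  # function, the stacked tensor's static shape is inferred while tracing,
  # without TensorArray.stack() calling set_shape() manually.
  print(t.shape.as_list())  # expected (per testStackShape below): [None, 3]
  return t

stack_shape_demo()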
commit 826a2450d1 (parent 80a47e09c9)
@@ -215,7 +215,7 @@ REGISTER_OP("TensorListStack")
         return errors::InvalidArgument(
             "Trying to read from list with wrong element dtype. List has "
             "type ",
-            DataTypeString(list_shape_type.dtype), " but expectec type ",
+            DataTypeString(list_shape_type.dtype), " but expected type ",
             DataTypeString(element_dtype));
       }
       shape_inference::ShapeHandle ignored;
@@ -223,6 +223,11 @@ REGISTER_OP("TensorListStack")
             c->Merge(element_shape, list_shape_type.shape, &ignored));
         element_shape = list_shape_type.shape;
       }
+      shape_inference::ShapeHandle element_shape_input = c->UnknownShape();
+      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
+          1, &element_shape_input));
+      TF_RETURN_IF_ERROR(
+          c->Merge(element_shape, element_shape_input, &element_shape));
       int expected_num_elements = -1;
       TF_RETURN_IF_ERROR(c->GetAttr("num_elements", &expected_num_elements));
       shape_inference::ShapeHandle num_elements;
@@ -418,6 +423,11 @@ REGISTER_OP("TensorListGetItem")
                                        DataTypeString(list_shape_type.dtype));
         }
       }
+      shape_inference::ShapeHandle element_shape_input = c->UnknownShape();
+      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
+          2, &element_shape_input));
+      TF_RETURN_IF_ERROR(
+          c->Merge(element_shape, element_shape_input, &element_shape));
       c->set_output(0, element_shape);
       return Status::OK();
     });
@@ -486,6 +496,11 @@ REGISTER_OP("TensorListGather")
                                        DataTypeString(list_shape_type.dtype));
         }
       }
+      shape_inference::ShapeHandle element_shape_input = c->UnknownShape();
+      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
+          2, &element_shape_input));
+      TF_RETURN_IF_ERROR(
+          c->Merge(element_shape, element_shape_input, &element_shape));
       shape_inference::ShapeHandle out;
       TF_RETURN_IF_ERROR(c->Concatenate(c->input(1), element_shape, &out));
       c->set_output(0, out);
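Each of the three shape functions above follows the same pattern: the element shape recovered from the list handle is merged with the shape carried by the op's element_shape input tensor. A hedged graph-mode sketch, mirroring the new assertions added in the list ops test below (the internal module paths match those used by that test; wrapping in an explicit Graph to force graph-mode shape inference is my assumption):

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_list_ops
from tensorflow.python.ops import list_ops

with ops.Graph().as_default():
  # Reserve a list with an unknown element shape.
  l = list_ops.tensor_list_reserve(
      element_dtype=dtypes.float32, element_shape=None, num_elements=3)
  # Passing element_shape to TensorListGetItem now gives the output a fully
  # defined static shape during shape inference.
  e = gen_list_ops.tensor_list_get_item(
      l, 1, element_shape=[2, 3], element_dtype=dtypes.float32)
  print(e.shape.as_list())  # expected (per the test below): [2, 3]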
@@ -2652,6 +2652,7 @@ cuda_py_test(
         "//tensorflow/python:init_ops",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:nn_grad",
+        "//tensorflow/python:tensor_spec",
         "//tensorflow/python:training",
         "//tensorflow/python:tensor_array_grad",
         "//tensorflow/python:tensor_array_ops",
@@ -2661,6 +2662,7 @@ cuda_py_test(
         "//tensorflow/python:while_v2",
         "//tensorflow/python/eager:backprop",
         "//tensorflow/python/eager:context",
+        "//tensorflow/python/eager:def_function",
     ],
     flaky = 1,  # create_local_cluster sometimes times out.
     shard_count = 10,
@@ -287,6 +287,10 @@ class ListOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
         element_dtype=dtypes.float32, element_shape=None, num_elements=3)
     t = gen_list_ops.tensor_list_stack(
         l, element_dtype=dtypes.float32, element_shape=[])
+    if context.executing_eagerly():
+      self.assertEqual(t.shape.as_list(), [3])
+    else:
+      self.assertEqual(t.shape.as_list(), [None])
     self.assertAllEqual(self.evaluate(t), np.zeros((3,)))

   @parameterized.named_parameters(("NoMaxNumElements", None),
@@ -453,6 +457,7 @@ class ListOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
         element_dtype=dtypes.float32, element_shape=None, num_elements=3)
     t = gen_list_ops.tensor_list_gather(
         l, [0, 1, 2], element_dtype=dtypes.float32, element_shape=[])
+    self.assertEqual(t.shape.as_list(), [3])
     self.assertAllEqual(self.evaluate(t), np.zeros((3,)))

   def testScatterOutputListSize(self):
@@ -607,6 +612,8 @@ class ListOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
         l, 0, element_shape=[], element_dtype=dtypes.float32)
     e1 = gen_list_ops.tensor_list_get_item(
         l, 1, element_shape=[2, 3], element_dtype=dtypes.float32)
+    self.assertEqual(e0.shape.as_list(), [])
+    self.assertEqual(e1.shape.as_list(), [2, 3])
     self.assertEqual(self.evaluate(e0), 0.)
     self.assertAllEqual(self.evaluate(e1), np.zeros((2, 3)))

@@ -628,9 +635,16 @@ class ListOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):

     l = list_ops.tensor_list_reserve(
         element_dtype=dtypes.float32, element_shape=[None, 2], num_elements=3)
-    with self.assertRaisesRegexp(
-        errors.InvalidArgumentError,
-        r"Incompatible shapes during merge: \[1,3\] vs. \[\?,2\]"):
+    # In eager mode the shape mismatch is caught in the TensorListGetItem
+    # kernel which raises an InvalidArgumentError.
+    # In graph mode the shape mismatch is caught in the C++ shape inference
+    # which raises a ValueError.
+    if context.executing_eagerly():
+      error_type = errors.InvalidArgumentError
+    else:
+      error_type = ValueError
+    with self.assertRaisesRegexp(error_type, r"shapes"):
       e0 = gen_list_ops.tensor_list_get_item(
           l, 0, element_dtype=dtypes.float32, element_shape=[1, 3])
       self.evaluate(e0)
@@ -24,11 +24,13 @@ from tensorflow.core.protobuf import config_pb2
 from tensorflow.python.client import session as session_lib
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
+from tensorflow.python.eager import def_function
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_spec
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
@@ -1345,6 +1347,57 @@ class TensorArrayTest(test.TestCase):
       grad = gradients_impl.gradients(ys=[r], xs=[x])
       self.assertAllEqual(np.array([1.0, 1.0, 1.0]), self.evaluate(grad)[0])

+  def testStackShape(self):
+
+    @def_function.function
+    def ta_stack():
+      ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=3)
+      x = constant_op.constant([1.0, 2.0, 3.0])
+      ta = ta.write(0, x)
+      t = ta.stack()
+      self.assertEqual(t.shape.as_list(), [None, 3])
+      return t
+
+    ta_stack()
+
+  def testReadShape(self):
+
+    @def_function.function
+    def ta_read():
+      ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=3)
+      x = constant_op.constant([1.0, 2.0, 3.0])
+      ta = ta.write(0, x)
+      t = ta.read(0)
+      self.assertEqual(t.shape.as_list(), [3])
+      return t
+
+    ta_read()
+
+  def testGatherShape(self):
+
+    def ta_gather(indices):
+      ta = tensor_array_ops.TensorArray(dtype=dtypes.float32, size=3)
+      x = constant_op.constant([1.0, 2.0, 3.0])
+      ta = ta.write(0, x)
+      t = ta.gather(indices)
+      self.assertEqual(t.shape.as_list(), [first_dim, 3])
+      return t
+
+    # This propagates shape of `indices` when compiling ta_gather.
+    ta_gather_with_known_indices_shape = def_function.function(ta_gather)
+    first_dim = 1
+    ta_gather_with_known_indices_shape([0])
+
+    # Here we force the shape of `indices` to be [None] during ta_gather's
+    # compilation.
+    ta_gather_with_unknown_indices_shape = def_function.function(
+        ta_gather,
+        input_signature=[
+            tensor_spec.TensorSpec(dtype=dtypes.int32, shape=[None])
+        ])
+    first_dim = None
+    ta_gather_with_unknown_indices_shape([0])
+
   def _testTensorArrayEvalEmpty(self):
     with self.cached_session(use_gpu=True):
       ta = tensor_array_ops.TensorArray(
@@ -577,8 +577,6 @@ class _GraphTensorArrayV2(object):
           element_dtype=self._dtype,
           element_shape=element_shape,
           name=name)
-      if self._element_shape:
-        value.set_shape(self._element_shape[0].dims)
       return value

   @tf_should_use.should_use_result
@@ -610,8 +608,6 @@ class _GraphTensorArrayV2(object):
         input_handle=self._flow,
         element_dtype=self._dtype,
         element_shape=element_shape)
-    if self._element_shape and self._element_shape[0].dims is not None:
-      value.set_shape([None] + self._element_shape[0].dims)
     return value

   def gather(self, indices, name=None):
@@ -626,8 +622,6 @@ class _GraphTensorArrayV2(object):
         element_dtype=self._dtype,
         element_shape=element_shape,
         name=name)
-    if self._element_shape and self._element_shape[0].dims is not None:
-      value.set_shape([None] + self._element_shape[0].dims)
     return value

   def concat(self, name=None):