Add "element_shape" attr to TensorListConcat and use it in _GraphTensorArrayV2.

When using while_v2, this change prevents errors like this:
  All except the first dimension must be fully defined when concating an empty tensor list. element_shape: <unknown>

Previously, the TensorListConcat op could only determine the element
shape from the TensorList object at runtime. In the case of a while
loop that executes zero times, the runtime never sees an element, so
it cannot determine the shape. However, the element shape may already
be known during graph construction, so we use the new attribute to
tell the runtime what it is.

Note that the original TensorArray implementation already does this.
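
As an illustration, a rough repro sketch (not part of this change; the
shapes are made up, and it assumes while_v2 is enabled):

    import tensorflow as tf

    def f(n):
      # Elements have known shape [2], but the loop may run zero times,
      # so the runtime-side TensorList never observes an element.
      ta = tf.TensorArray(tf.float32, size=0, dynamic_size=True,
                          element_shape=tf.TensorShape([2]))

      def body(i, ta):
        return i + 1, ta.write(i, tf.zeros([2]))

      _, ta = tf.while_loop(lambda i, ta: i < n, body, [0, ta])
      # With n == 0, concat() used to fail at runtime with the error
      # above; with the element_shape attr it returns an empty [0] tensor.
      return ta.concat()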

PiperOrigin-RevId: 227750868
Author:    Skye Wanderman-Milne
Date:      2019-01-03 14:29:35 -08:00
Committer: TensorFlower Gardener
Parent:    7c9323bedc
Commit:    9d78259894

5 changed files with 53 additions and 28 deletions

@@ -158,6 +158,17 @@ class TensorListConcat : public OpKernel {
       std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>;
   explicit TensorListConcat(OpKernelConstruction* c) : OpKernel(c) {
     OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
+    // TODO(skyewm): the HasAttr check can be removed once the
+    // element_shape_except_first_dim attr has been checked in for 2 weeks
+    // (around 1/14/2019).
+    if (c->HasAttr("element_shape")) {
+      PartialTensorShape element_shape;
+      OP_REQUIRES_OK(c, c->GetAttr("element_shape", &element_shape));
+      if (!element_shape.unknown_rank()) {
+        element_shape_except_first_dim_ = PartialTensorShape(
+            gtl::ArraySlice<int64>(element_shape.dim_sizes()).subspan(1));
+      }
+    }
   }
 
   ~TensorListConcat() {}
@@ -178,29 +189,33 @@ class TensorListConcat : public OpKernel {
             " but list elements ", DataTypeString(tensor_list->element_dtype)));
     // If the TensorList is empty, its element_shape must be fully defined
     // except for the first dimension.
-    PartialTensorShape shape_except_first_dim;
-    if (!tensor_list->element_shape.unknown_rank()) {
-      OP_REQUIRES(c, tensor_list->element_shape.dims() >= 1,
-                  errors::InvalidArgument(
-                      "Concat requires elements to be at least vectors, ",
-                      "found scalars instead."));
-      shape_except_first_dim = PartialTensorShape(
-          gtl::ArraySlice<int64>(tensor_list->element_shape.dim_sizes())
-              .subspan(1));
+    if (!element_shape_except_first_dim_.IsFullyDefined()) {
+      if (!tensor_list->element_shape.unknown_rank()) {
+        OP_REQUIRES(c, tensor_list->element_shape.dims() >= 1,
+                    errors::InvalidArgument(
+                        "Concat requires elements to be at least vectors, ",
+                        "found scalars instead."));
+        PartialTensorShape shape_except_first_dim(
+            gtl::ArraySlice<int64>(tensor_list->element_shape.dim_sizes())
+                .subspan(1));
+        PartialTensorShape tmp = element_shape_except_first_dim_;
+        OP_REQUIRES_OK(c, tmp.MergeWith(shape_except_first_dim,
+                                        &element_shape_except_first_dim_));
+      }
     }
     OP_REQUIRES(c,
                 !tensor_list->tensors.empty() ||
-                    shape_except_first_dim.IsFullyDefined(),
+                    element_shape_except_first_dim_.IsFullyDefined(),
                 errors::InvalidArgument(
                     "All except the first dimension must be fully defined ",
                     "when concating an empty tensor list. element_shape: ",
                     tensor_list->element_shape.DebugString()));
     // 1. Compute the shape of the output tensor.
-    // If `shape_except_first_dim` is fully-defined we just prepend the leading
-    // dim to it. Otherwise we use the shape of the first element tensor and
-    // check to make sure shapes of all tensors are compatible.
+    // If `element_shape_except_first_dim_` is fully-defined we just prepend the
+    // leading dim to it. Otherwise we use the shape of the first element tensor
+    // and check to make sure shapes of all tensors are compatible.
     TensorShape output_shape;
-    if (!shape_except_first_dim.AsTensorShape(&output_shape)) {
+    if (!element_shape_except_first_dim_.AsTensorShape(&output_shape)) {
       const Tensor& element_tensor = tensor_list->tensors[0];
       OP_REQUIRES(
           c, TensorShapeUtils::IsVectorOrHigher(element_tensor.shape()),
@@ -268,6 +283,7 @@ class TensorListConcat : public OpKernel {
 
  private:
   DataType element_dtype_;
+  PartialTensorShape element_shape_except_first_dim_;
 };
 
 template <typename Device, typename T>
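
The kernel change above merges the shape carried by the new attr (minus
its first dimension) with whatever the runtime TensorList knows; the
merged result must be fully defined before an empty list may be
concatenated. A small Python sketch of that merge rule (hypothetical
helper, not TensorFlow code; None means an unknown dimension):

    def merge_shapes(a, b):
      # Mirrors PartialTensorShape::MergeWith: dimensions must agree where
      # both are known; a known dimension fills in an unknown one.
      if a is None:  # unknown rank
        return b
      if b is None:
        return a
      if len(a) != len(b):
        raise ValueError("incompatible ranks")
      merged = []
      for x, y in zip(a, b):
        if x is not None and y is not None and x != y:
          raise ValueError("incompatible dimensions")
        merged.append(x if x is not None else y)
      return merged

    # element_shape attr [?, 3, 4] -> all-but-first is [3, 4]
    # runtime list shape [?, ?, 4] -> all-but-first is [?, 4]
    print(merge_shapes([3, 4], [None, 4]))  # [3, 4]: fully defined, so
                                            # concating an empty list is OK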

@@ -212,10 +212,16 @@ REGISTER_OP("TensorListConcat")
     .Output("tensor: element_dtype")
     .Output("lengths: int64")
     .Attr("element_dtype: type")
+    .Attr("element_shape: shape = { unknown_rank: true }")
     .SetShapeFn([](shape_inference::InferenceContext* c) {
       DataType element_dtype;
       TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &element_dtype));
-      shape_inference::ShapeHandle element_shape = c->UnknownShape();
+      PartialTensorShape raw_element_shape;
+      TF_RETURN_IF_ERROR(c->GetAttr("element_shape", &raw_element_shape));
+      shape_inference::ShapeHandle element_shape;
+      TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(raw_element_shape,
+                                                            &element_shape));
       auto* handle_data = c->input_handle_shapes_and_types(0);
       if (handle_data != nullptr && handle_data->size() != 1) {
         return errors::InvalidArgument(
@@ -231,10 +237,10 @@ REGISTER_OP("TensorListConcat")
               DataTypeString(list_shape_type.dtype), " but expected type ",
               DataTypeString(element_dtype));
         }
-        shape_inference::ShapeHandle ignored;
+        shape_inference::ShapeHandle merged;
         TF_RETURN_IF_ERROR(
-            c->Merge(element_shape, list_shape_type.shape, &ignored));
-        element_shape = list_shape_type.shape;
+            c->Merge(element_shape, list_shape_type.shape, &merged));
+        element_shape = merged;
       }
       if (c->RankKnown(element_shape)) {
         shape_inference::ShapeHandle result;
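
With the attr threaded into the shape function, the output shape can be
inferred at graph-construction time even when the list handle carries no
shape data. A hedged sketch of the expected effect (uses the internal
list_ops module; exact printed form varies by TF version):

    import tensorflow as tf
    from tensorflow.python.ops import list_ops

    l = list_ops.empty_tensor_list(element_shape=[2, 3],
                                   element_dtype=tf.float32)
    t = list_ops.tensor_list_concat(l, element_dtype=tf.float32,
                                    element_shape=[2, 3])
    print(t.shape)  # (?, 3): leading dim unknown, trailing dims from the attr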

@@ -71,11 +71,13 @@ def tensor_list_from_tensor(tensor, element_shape, name=None):
       name=name)
 
 
-def tensor_list_concat(input_handle, element_dtype, name=None):
+def tensor_list_concat(input_handle, element_dtype, element_shape=None,
+                       name=None):
   # Ignore the lengths output of TensorListConcat. It is only used during
   # gradient computation.
   return gen_list_ops.tensor_list_concat(
-      input_handle=input_handle, element_dtype=element_dtype, name=name)[0]
+      input_handle=input_handle, element_dtype=element_dtype,
+      element_shape=element_shape, name=name)[0]
 
 
 def tensor_list_split(tensor, element_shape, lengths, name=None):
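
For example, a caller that knows the element shape at graph time can now
pass it through (sketch; `handle` stands in for an existing TensorList):

    # The lengths output is discarded by the wrapper; only the
    # concatenated tensor comes back.
    t = list_ops.tensor_list_concat(
        input_handle=handle, element_dtype=tf.float32,
        element_shape=[None, 128])  # the first dim may stay unknown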

@@ -27,7 +27,6 @@ from tensorflow.core.example import example_pb2
 from tensorflow.core.example import feature_pb2
 from tensorflow.python.client import session
 from tensorflow.python.eager import backprop
-from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
@@ -105,10 +104,6 @@ class PForTest(test.TestCase):
     flags.FLAGS.op_conversion_fallback_to_while_loop = False
 
   def test_parallel_iterations(self):
-    # TODO(b/121334512): Remove this check once this passes in Eager mode.
-    if context.executing_eagerly():
-      return
     for parallel_iterations in [2, 3, 8, 10]:
       x = random_ops.random_uniform([8, 3])

@@ -588,10 +588,16 @@ class _GraphTensorArrayV2(object):
   def concat(self, name=None):
     """See TensorArray."""
-    value = list_ops.tensor_list_concat(
-        input_handle=self._flow, element_dtype=self._dtype, name=name)
     if self._element_shape and self._element_shape[0].dims is not None:
-      value.set_shape([None] + self._element_shape[0].dims[1:])
+      element_shape = [None] + self._element_shape[0].dims[1:]
+    else:
+      element_shape = None
+    value = list_ops.tensor_list_concat(
+        input_handle=self._flow,
+        element_dtype=self._dtype,
+        element_shape=element_shape,
+        name=name)
     return value
 
   @tf_should_use.should_use_result
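
End to end, a TensorList-backed TensorArray with a known element shape
can now concat even when nothing was ever written to it. A sketch under
the same assumptions as the repro above (while_v2 enabled):

    ta = tf.TensorArray(tf.float32, size=0, dynamic_size=True,
                        element_shape=tf.TensorShape([None, 8]))
    # concat() forwards element_shape=[None, 8] to TensorListConcat, so an
    # empty array yields a [0, 8] tensor instead of the runtime error.
    value = ta.concat()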