[TF] Add support for TensorArrays to tf.data Dataset.

This includes adding support for TensorArray-based aggregation in
tf.data.experimental.scan.

Also fixes some bugs/nits in TensorArray.

PiperOrigin-RevId: 242746755
This commit is contained in:
Eugene Brevdo 2019-04-09 15:00:58 -07:00 committed by TensorFlower Gardener
parent d1e4b00bf6
commit 3a3e2833e4
24 changed files with 562 additions and 70 deletions

View File

@ -387,11 +387,18 @@ class TensorArrayTest(xla_test.XLATestCase):
def fn():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32, tensor_array_name="foo", size=3)
return ta.write(-1, np.int32(7)).flow
return ta.write(-1, constant_op.constant(7)).flow
# Test writing the wrong datatype.
with self.assertRaisesOpError(
"TensorArray dtype is float but op has dtype int32"):
# TODO(b/129870929): Remove InvalidArgumentError/second regexp after all
# callers provide proper init dtype.
with self.assertRaisesRegexp(
(ValueError, errors.InvalidArgumentError),
r"("
r"conversion requested dtype float32 for Tensor with dtype int32"
r"|"
r"TensorArray dtype is float but op has dtype int32"
r")"):
xla.compile(fn)[0].eval()
@test_util.disable_control_flow_v2("b/124334096 verify dtype")

View File

@ -39,6 +39,7 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
@@StatsOptions
@@Structure
@@TFRecordWriter
@@TensorArrayStructure
@@TensorStructure
@@ThreadingOptions
@ -137,6 +138,7 @@ from tensorflow.python.data.ops.optional_ops import OptionalStructure
from tensorflow.python.data.util.structure import NestedStructure
from tensorflow.python.data.util.structure import SparseTensorStructure
from tensorflow.python.data.util.structure import Structure
from tensorflow.python.data.util.structure import TensorArrayStructure
from tensorflow.python.data.util.structure import TensorStructure
# pylint: enable=unused-import

View File

@ -560,6 +560,7 @@ py_test(
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:script_ops",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:tensor_array_ops",
"//tensorflow/python/data/experimental/ops:scan_ops",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",

View File

@ -30,7 +30,10 @@ from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.platform import test
@ -95,6 +98,83 @@ class ScanTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element())
def testTensorArraySimple(self):
def scan_fn(ta, x):
return (ta.write(ta.size(), x), ta.stack())
start = tensor_array_ops.TensorArray(
size=0,
element_shape=[],
dtype=dtypes.int64,
dynamic_size=True)
start = start.write(0, -1)
ds = dataset_ops.Dataset.range(5).apply(scan_ops.scan(start, scan_fn))
self.assertDatasetProduces(
ds,
expected_output=[
[-1],
[-1, 0],
[-1, 0, 1],
[-1, 0, 1, 2],
[-1, 0, 1, 2, 3],
],
requires_initialization=True,
num_test_iterations=2)
def testTensorArrayWithCondReset(self):
def empty():
return tensor_array_ops.TensorArray(
size=0, element_shape=[], dtype=dtypes.int64, dynamic_size=True)
def scan_fn(ta, x):
updated = ta.write(ta.size(), x)
next_iter = control_flow_ops.cond(
math_ops.equal(x % 3, 0), empty, lambda: updated)
return (next_iter, updated.stack())
start = empty()
start = start.write(0, -1)
ds = dataset_ops.Dataset.range(6).apply(scan_ops.scan(start, scan_fn))
self.assertDatasetProduces(
ds,
expected_output=[
[-1, 0],
[1],
[1, 2],
[1, 2, 3],
[4],
[4, 5],
],
requires_initialization=True,
num_test_iterations=2)
def testTensorArrayWithCondResetByExternalCaptureBreaks(self):
empty_ta = tensor_array_ops.TensorArray(
size=0, element_shape=[], dtype=dtypes.int64, dynamic_size=True)
def scan_fn(ta, x):
updated = ta.write(ta.size(), x)
# Here, capture empty_ta from outside the function. However, it may be
# either a TF1-style TensorArray or an Eager-style TensorArray.
next_iter = control_flow_ops.cond(
math_ops.equal(x % 3, 0), lambda: empty_ta, lambda: updated)
return (next_iter, updated.stack())
start = empty_ta
start = start.write(0, -1)
with self.assertRaisesRegexp(
NotImplementedError,
r"construct a new TensorArray inside the function"):
dataset_ops.Dataset.range(6).apply(scan_ops.scan(start, scan_fn))
def testChangingStateShape(self):
# Test the fixed-point shape invariant calculations: start with
# initial values with known shapes, and use a scan function that

View File

@ -302,6 +302,7 @@ py_library(
"//tensorflow/python:experimental_dataset_ops_gen",
"//tensorflow/python:framework_ops",
"//tensorflow/python:function",
"//tensorflow/python:tensor_array_ops",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/util:nest",
"//tensorflow/python/data/util:sparse",

View File

@ -23,7 +23,6 @@ from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import structure
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import gen_experimental_dataset_ops
from tensorflow.python.util.tf_export import tf_export
@ -36,14 +35,7 @@ class _ScanDataset(dataset_ops.UnaryDataset):
self._input_dataset = input_dataset
with ops.name_scope("initial_state"):
# Convert any `SparseTensorValue`s to `SparseTensor`s and all other
# values to tensors.
self._initial_state = nest.pack_sequence_as(initial_state, [
sparse_tensor.SparseTensor.from_value(t)
if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(
t, name="component_%d" % i)
for i, t in enumerate(nest.flatten(initial_state))
])
self._initial_state = structure.normalize_tensors(initial_state)
# Compute initial values for the state classes, shapes and types based on
# the initial state. The shapes may be refined by running `tf_scan_func` one

View File

@ -176,6 +176,7 @@ tf_py_test(
"//tensorflow/python:session",
"//tensorflow/python:sparse_ops",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:tensor_array_ops",
"//tensorflow/python:training",
"//tensorflow/python/data/ops:dataset_ops",
],
@ -195,6 +196,7 @@ tf_py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:script_ops",
"//tensorflow/python:tensor_array_ops",
"//tensorflow/python:session",
],
)
@ -415,6 +417,7 @@ tf_py_test(
"//tensorflow/python:sparse_ops",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:string_ops",
"//tensorflow/python:tensor_array_ops",
"//tensorflow/python:tensor_util",
"//tensorflow/python:variable_scope",
],

View File

@ -24,10 +24,13 @@ import numpy as np
from tensorflow.python.client import session
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.platform import test
from tensorflow.python.training import server_lib
@ -121,6 +124,27 @@ class FlatMapTest(test_base.DatasetTestBase):
expected_output.append([i, 0] if j % 2 == 0 else [0, -i])
self.assertDatasetProduces(dataset, expected_output=expected_output)
def testTensorArray(self):
def _map_fn(i):
i = math_ops.cast(i, dtypes.int32)
return (
tensor_array_ops.TensorArray(
dtype=dtypes.int32, element_shape=(), size=i)
.unstack(math_ops.range(i)))
def _flat_map_fn(x):
self.assertIsInstance(x, tensor_array_ops.TensorArray)
return dataset_ops.Dataset.from_tensor_slices(x.stack())
dataset = dataset_ops.Dataset.range(10).map(_map_fn).flat_map(_flat_map_fn)
expected_output = []
for i in range(10):
for j in range(i):
expected_output.append(j)
self.assertDatasetProduces(dataset, expected_output=expected_output)
if __name__ == "__main__":
test.main()

View File

@ -33,6 +33,7 @@ from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.platform import test
@ -51,6 +52,17 @@ class FromTensorsTest(test_base.DatasetTestBase):
self.assertDatasetProduces(dataset, expected_output=[components])
def testFromTensorsTensorArray(self):
"""Test a dataset that represents a TensorArray."""
components = (
tensor_array_ops.TensorArray(dtypes.float32, element_shape=(), size=2)
.unstack([1.0, 2.0]))
dataset = dataset_ops.Dataset.from_tensors(components)
self.assertDatasetProduces(
dataset, expected_output=[[1.0, 2.0]], requires_initialization=True)
def testFromTensorsSparse(self):
"""Test a dataset that represents a single tuple of tensors."""
components = (sparse_tensor.SparseTensorValue(

View File

@ -47,6 +47,7 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@ -728,6 +729,36 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
dataset,
expected_output=[self.evaluate(_check(_sparse(i))) for i in range(10)])
def testTensorArray(self):
def _tensor_array(i):
i = math_ops.cast(i, dtypes.int32)
return (
tensor_array_ops.TensorArray(dtypes.int32, element_shape=(), size=i)
.unstack(math_ops.range(i, dtype=dtypes.int32)))
dataset = dataset_ops.Dataset.range(10).map(_tensor_array)
self.assertDatasetProduces(
dataset, expected_output=[list(range(i)) for i in range(10)])
def testTensorArrayChain(self):
def _tensor_array(i):
i = math_ops.cast(i, dtypes.int32)
return (
tensor_array_ops.TensorArray(dtypes.int32, element_shape=(), size=i)
.unstack(math_ops.range(i, dtype=dtypes.int32)))
def _check(x):
self.assertIsInstance(x, tensor_array_ops.TensorArray)
return x.identity()
dataset = dataset_ops.Dataset.range(10).map(_tensor_array).map(_check)
self.assertDatasetProduces(
dataset,
expected_output=[list(range(i)) for i in range(10)])
@test_util.run_v1_only("b/123904513")
def testParallelMapOutOfRangeError(self):
def raising_py_func(i):

View File

@ -27,6 +27,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.platform import test
@ -63,11 +64,20 @@ class DatasetTestBase(test.TestCase):
mode, it should use an initializable iterator to iterate through the
dataset (e.g. when it contains stateful nodes). Defaults to False.
Returns:
A callable that returns the next element of `dataset`.
A callable that returns the next element of `dataset`. Any `TensorArray`
objects `dataset` outputs are stacked.
"""
def ta_wrapper(gn):
def _wrapper():
r = gn()
if isinstance(r, tensor_array_ops.TensorArray):
return r.stack()
else:
return r
return _wrapper
if context.executing_eagerly():
iterator = iter(dataset)
return iterator._next_internal # pylint: disable=protected-access
return ta_wrapper(iterator._next_internal) # pylint: disable=protected-access
else:
if requires_initialization:
iterator = dataset_ops.make_initializable_iterator(dataset)
@ -75,7 +85,7 @@ class DatasetTestBase(test.TestCase):
else:
iterator = dataset_ops.make_one_shot_iterator(dataset)
get_next = iterator.get_next()
return lambda: get_next
return ta_wrapper(lambda: get_next)
def _compareOutputToExpected(self, result_values, expected_values,
assert_items_equal):
@ -91,7 +101,11 @@ class DatasetTestBase(test.TestCase):
if sparse_tensor.is_sparse(result_value):
self.assertSparseValuesEqual(result_value, expected_value)
else:
self.assertAllEqual(result_value, expected_value)
self.assertAllEqual(
result_value,
expected_value,
msg=("Result value: {}. Expected value: {}"
.format(result_value, expected_value)))
def assertDatasetProduces(self,
dataset,
@ -168,6 +182,7 @@ class DatasetTestBase(test.TestCase):
next1 = self.getNext(dataset1)
next2 = self.getNext(dataset2)
while True:
try:
op1 = self.evaluate(next1())

View File

@ -2070,13 +2070,7 @@ class TensorDataset(DatasetSource):
def __init__(self, tensors):
"""See `Dataset.from_tensors()` for details."""
with ops.name_scope("tensors"):
tensors = nest.pack_sequence_as(tensors, [
sparse_tensor_lib.SparseTensor.from_value(t)
if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(
t, name="component_%d" % i)
for i, t in enumerate(nest.flatten(tensors))
])
tensors = structure_lib.normalize_tensors(tensors)
self._structure = structure_lib.Structure.from_value(tensors)
self._tensors = self._structure._to_tensor_list(tensors) # pylint: disable=protected-access

View File

@ -73,6 +73,7 @@ py_library(
"//tensorflow/python:ops",
"//tensorflow/python:sparse_ops",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:tensor_array_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:tensor_util",
"//tensorflow/python:util",
@ -91,6 +92,7 @@ py_test(
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:tensor_array_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:variables",
"//tensorflow/python/data/kernel_tests:test_base",

View File

@ -27,7 +27,9 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.util.tf_export import tf_export
@ -203,6 +205,8 @@ class Structure(object):
value,
(sparse_tensor_lib.SparseTensor, sparse_tensor_lib.SparseTensorValue)):
return SparseTensorStructure.from_value(value)
elif isinstance(value, tensor_array_ops.TensorArray):
return TensorArrayStructure.from_value(value)
elif isinstance(value, (tuple, dict)):
return NestedStructure.from_value(value)
else:
@ -241,6 +245,33 @@ class Structure(object):
raise NotImplementedError("Structure._to_legacy_output_classes()")
def normalize_tensors(tensors):
"""Converts a nested structure of tensor-like objects to tensors.
* `SparseTensor`-like inputs are converted to `SparseTensor`.
* `TensorArray` inputs are passed through.
* Everything else is converted to a dense `Tensor`.
Args:
tensors: A nested structure of tensor-like, list,
`SparseTensor`, `SparseTensorValue`, or `TensorArray` objects.
Returns:
A nested structure of tensor, `SparseTensor`, or `TensorArray` objects.
"""
flat_tensors = nest.flatten(tensors)
prepared = []
with ops.name_scope("normalize_tensors"):
for i, t in enumerate(flat_tensors):
if sparse_tensor_lib.is_sparse(t):
prepared.append(sparse_tensor_lib.SparseTensor.from_value(t))
elif isinstance(t, tensor_array_ops.TensorArray):
prepared.append(t)
else:
prepared.append(ops.convert_to_tensor(t, name="component_%d" % i))
return nest.pack_sequence_as(tensors, prepared)
def convert_legacy_structure(output_types, output_shapes, output_classes):
"""Returns a `Structure` that represents the given legacy structure.
@ -280,12 +311,19 @@ def convert_legacy_structure(output_types, output_shapes, output_classes):
flat_ret.append(SparseTensorStructure(flat_type, flat_shape))
elif issubclass(flat_class, ops.Tensor):
flat_ret.append(TensorStructure(flat_type, flat_shape))
elif issubclass(flat_class, tensor_array_ops.TensorArray):
# We sneaked the dynamic_size and infer_shape into the legacy shape.
flat_ret.append(
TensorArrayStructure(
flat_type, flat_shape[2:],
dynamic_size=tensor_shape.dimension_value(flat_shape[0]),
infer_shape=tensor_shape.dimension_value(flat_shape[1])))
else:
# NOTE(mrry): Since legacy structures produced by iterators only
# comprise Tensors, SparseTensors, and nests, we do not need to
# support all structure types here.
raise TypeError(
"Could not build a structure for output class %r" % flat_type)
"Could not build a structure for output class %r" % (flat_class,))
ret = nest.pack_sequence_as(output_classes, flat_ret)
if isinstance(ret, Structure):
@ -571,3 +609,94 @@ class SparseTensorStructure(Structure):
if self._dense_shape.ndims == 0:
raise ValueError("Unbatching a tensor is only supported for rank >= 1")
return SparseTensorStructure(self._dtype, self._dense_shape[1:])
@tf_export("data.experimental.TensorArrayStructure")
class TensorArrayStructure(Structure):
"""Represents structural information about a `tf.TensorArray`."""
def __init__(self, dtype, element_shape, dynamic_size, infer_shape):
self._dtype = dtypes.as_dtype(dtype)
self._element_shape = tensor_shape.as_shape(element_shape)
self._dynamic_size = dynamic_size
self._infer_shape = infer_shape
@property
def _flat_shapes(self):
# A TensorArray is represented via its variant object, which is a scalar.
return [tensor_shape.scalar()]
@property
def _flat_types(self):
return [dtypes.variant]
def is_compatible_with(self, other):
return (isinstance(other, TensorArrayStructure) and
self._dtype.is_compatible_with(other._dtype) and
self._element_shape.is_compatible_with(other._element_shape) and
self._dynamic_size == other._dynamic_size)
def _to_tensor_list(self, value):
if not isinstance(value, tensor_array_ops.TensorArray):
raise TypeError("value must be a TensorArray, but saw: {}"
.format(type(value)))
if value.flow is not None and value.flow.dtype == dtypes.variant:
return [value.flow]
else:
# Convert to a TF2-style TensorArray.
# TODO(ebrevdo): Add an "_as_variant" method to TensorArray class, or
# "implementation / as_variant" arg to TensorArray constructor.
with ops.name_scope("convert_tensor_array"):
flow = list_ops.tensor_list_from_tensor(
tensor=value.stack(), element_shape=value.element_shape)
return [flow]
def _to_batched_tensor_list(self, value):
raise NotImplementedError("TensorArrayStructure._to_batched_tensor_list")
def _from_tensor_list(self, flat_value):
if (len(flat_value) != 1 or flat_value[0].dtype != dtypes.variant or
not flat_value[0].shape.is_compatible_with(tensor_shape.scalar())):
raise ValueError("TensorArrayStructure corresponds to a single "
"tf.variant scalar.")
return self._from_compatible_tensor_list(flat_value)
def _from_compatible_tensor_list(self, flat_value):
# This will return a TF2 Graph-style TensorArray because flat_value[0] is
# a variant object. size == -1 implies unknown size.
ret = tensor_array_ops.TensorArray(
dtype=self._dtype,
flow=flat_value[0],
dynamic_size=self._dynamic_size,
infer_shape=self._infer_shape)
ret._element_shape = [self._element_shape]
return ret
@staticmethod
def from_value(value):
if not isinstance(value, tensor_array_ops.TensorArray):
raise TypeError("Expected value to be a TensorArray, but saw: {}".
format(type(value)))
return TensorArrayStructure(
dtype=value.dtype,
element_shape=value.element_shape,
dynamic_size=value.dynamic_size,
infer_shape=value._infer_shape)
def _to_legacy_output_types(self):
return self._dtype
def _to_legacy_output_shapes(self):
# Sneak the dynamic_size and infer_shape values into the legacy shape.
return (tensor_shape.matrix(self._dynamic_size, self._infer_shape)
.concatenate(self._element_shape))
def _to_legacy_output_classes(self):
return tensor_array_ops.TensorArray
def _batch(self, batch_size):
raise NotImplementedError("TensorArrayStructure._batch")
def _unbatch(self):
raise NotImplementedError("TensorArrayStructure._unbatch")

View File

@ -31,6 +31,7 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@ -44,6 +45,9 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.parameters(
(lambda: constant_op.constant(37.0), structure.TensorStructure,
[dtypes.float32], [[]]),
(lambda: tensor_array_ops.TensorArray(
dtype=dtypes.float32, element_shape=(3,), size=0),
structure.TensorArrayStructure, [dtypes.variant], [None, 3]),
(lambda: sparse_tensor.SparseTensor(
indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
structure.SparseTensorStructure, [dtypes.variant], [None]),
@ -79,6 +83,20 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase):
variables.Variable(100.0), 42.0,
np.array(42.0, dtype=np.float32)
], lambda: [constant_op.constant([1.0, 2.0]), constant_op.constant(37)]),
(lambda: tensor_array_ops.TensorArray(
dtype=dtypes.float32, element_shape=(3,), size=0),
lambda: [
tensor_array_ops.TensorArray(
dtype=dtypes.float32, element_shape=(3,), size=0),
tensor_array_ops.TensorArray(
dtype=dtypes.float32, element_shape=(3,), size=10)
],
lambda: [
tensor_array_ops.TensorArray(
dtype=dtypes.int32, element_shape=(3,), size=0),
tensor_array_ops.TensorArray(
dtype=dtypes.float32, element_shape=(), size=0)
]),
(lambda: sparse_tensor.SparseTensor(
indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
lambda: [
@ -137,6 +155,8 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase):
(lambda: constant_op.constant(37.0),),
(lambda: sparse_tensor.SparseTensor(
indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),),
(lambda: tensor_array_ops.TensorArray(
dtype=dtypes.float32, element_shape=(), size=1).write(0, 7)),
(lambda: {"a": constant_op.constant(37.0),
"b": constant_op.constant([1, 2, 3])},),
(lambda: {"a": constant_op.constant(37.0),
@ -149,8 +169,15 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase):
def testRoundTripConversion(self, value_fn):
value = value_fn()
s = structure.Structure.from_value(value)
before = self.evaluate(value)
after = self.evaluate(s._from_tensor_list(s._to_tensor_list(value)))
def maybe_stack_ta(v):
if isinstance(v, tensor_array_ops.TensorArray):
return v.stack()
else:
return v
before = self.evaluate(maybe_stack_ta(value))
after = self.evaluate(
maybe_stack_ta(s._from_tensor_list(s._to_tensor_list(value))))
flat_before = nest.flatten(before)
flat_after = nest.flatten(after)
@ -343,6 +370,18 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase):
("SparseTensor", dtypes.int32, tensor_shape.matrix(2, 2),
sparse_tensor.SparseTensor,
structure.SparseTensorStructure(dtypes.int32, [2, 2])),
("TensorArray0", dtypes.int32, tensor_shape.as_shape([None, True, 2, 2]),
tensor_array_ops.TensorArray,
structure.TensorArrayStructure(
dtypes.int32, [2, 2], dynamic_size=None, infer_shape=True)),
("TensorArray1", dtypes.int32, tensor_shape.as_shape([True, None, 2, 2]),
tensor_array_ops.TensorArray,
structure.TensorArrayStructure(
dtypes.int32, [2, 2], dynamic_size=True, infer_shape=None)),
("TensorArray2", dtypes.int32, tensor_shape.as_shape([True, False, 2, 2]),
tensor_array_ops.TensorArray,
structure.TensorArrayStructure(
dtypes.int32, [2, 2], dynamic_size=True, infer_shape=False)),
("Nest",
{"a": dtypes.float32, "b": (dtypes.int32, dtypes.string)},
{"a": tensor_shape.scalar(),

View File

@ -437,16 +437,22 @@ class TensorArrayTest(test.TestCase):
def testTensorArrayWriteWrongIndexOrDataTypeFails(self):
with self.session(use_gpu=True):
ta = _make_ta(3, "foo", dtype=dtypes.float32)
# Test writing the wrong datatype
if (control_flow_util.ENABLE_CONTROL_FLOW_V2 and
not context.executing_eagerly()):
error_msg = ("Invalid data types; op elements string but list elements "
"float")
else:
error_msg = (
"TensorArray dtype is (float|float32) but Op is trying to write "
"dtype string")
with self.assertRaisesOpError(error_msg):
# TODO(b/129870929): Remove the last 2 checks (runtime checks) after
# back back from preferred_dtype= to dtype= in convert_to_tensor. Also
# restrict error check to only TypeError.
error_msg_regex = (
"("
"Expected float32, got 'wrong_type_scalar' of type 'str' instead."
"|"
"Cannot convert provided value to EagerTensor. Provided value: "
"wrong_type_scalar Requested dtype: float"
"|"
"TensorArray dtype is float.* but Op is trying to write dtype string"
"|"
"Invalid data types; op elements string but list elements float"
")")
with self.assertRaisesRegexp(
(TypeError, errors.InvalidArgumentError), error_msg_regex):
self.evaluate(ta.write(0, "wrong_type_scalar").flow)
if (control_flow_util.ENABLE_CONTROL_FLOW_V2 and

View File

@ -75,6 +75,7 @@ def for_loop(loop_fn, loop_fn_dtypes, iters, parallel_iterations=None):
for out, ta in zip(fn_output, ta_list):
# TODO(agarwal): support returning Operation objects from loop_fn.
if out is not None:
# out may be a ref tensor, wrap it in identity to get a non-ref tensor.
ta = ta.write(i, array_ops.expand_dims(out, 0))
outputs.append(ta)
return tuple([i + 1] + outputs)
@ -86,7 +87,7 @@ def for_loop(loop_fn, loop_fn_dtypes, iters, parallel_iterations=None):
ta_list = control_flow_ops.while_loop(
lambda i, *ta: i < iters,
while_body,
[0] + [tensor_array_ops.TensorArray(dtype, iters)
[0] + [tensor_array_ops.TensorArray(dtype.base_dtype, iters)
for dtype in flat_loop_fn_dtypes],
**extra_args)[1:]

View File

@ -20,6 +20,7 @@ from __future__ import division
from __future__ import print_function
import contextlib
import traceback
import weakref
from tensorflow.python.eager import context
@ -35,6 +36,7 @@ from tensorflow.python.ops import gen_control_flow_ops
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_should_use
from tensorflow.python.util.tf_export import tf_export
@ -114,10 +116,8 @@ class _GraphTensorArray(object):
if clear_after_read is None:
clear_after_read = True
self._dynamic_size = None
dynamic_size = dynamic_size or False
self._dtype = dtype
self._dynamic_size = dynamic_size or False
self._dtype = dtypes.as_dtype(dtype).base_dtype
# Used to keep track of what tensors the TensorArray should be
# colocated with. We choose to colocate the TensorArray with the
@ -137,7 +137,7 @@ class _GraphTensorArray(object):
self._element_shape = []
else:
self._infer_shape = True
self._element_shape = [tensor_shape.TensorShape(element_shape)]
self._element_shape = [tensor_shape.as_shape(element_shape)]
with ops.name_scope(name, "TensorArray", [handle, size, flow]) as scope:
if handle is not None:
self._handle = handle
@ -155,7 +155,7 @@ class _GraphTensorArray(object):
size=size,
element_shape=element_shape,
identical_element_shapes=infer_shape,
dynamic_size=dynamic_size,
dynamic_size=self._dynamic_size,
clear_after_read=clear_after_read,
tensor_array_name=tensor_array_name,
name=scope)
@ -177,6 +177,13 @@ class _GraphTensorArray(object):
def handle(self):
return self._handle
@property
def element_shape(self):
if self._element_shape:
return self._element_shape[0]
else:
return tensor_shape.unknown_shape(None)
def _merge_element_shape(self, shape):
"""Changes the element shape of the array given a shape to merge with.
@ -229,6 +236,7 @@ class _GraphTensorArray(object):
colocate_with_first_write_call=self._colocate_with_first_write_call)
ta._element_shape = self._element_shape
ta._colocate_with = self._colocate_with
ta._dynamic_size = self._dynamic_size
return ta
def grad(self, source, flow=None, name=None):
@ -270,7 +278,10 @@ class _GraphTensorArray(object):
def write(self, index, value, name=None):
"""See TensorArray."""
with ops.name_scope(name, "TensorArrayWrite", [self._handle, index, value]):
value = ops.convert_to_tensor(value, name="value")
# TODO(b/129870929): Fix after all callers provide proper init dtype.
value = ops.convert_to_tensor(
value, preferred_dtype=self._dtype, name="value")
_check_dtypes(value, self._dtype)
if self._infer_shape:
self._merge_element_shape(value.shape)
with self._maybe_colocate_with(value):
@ -288,6 +299,7 @@ class _GraphTensorArray(object):
ta._infer_shape = self._infer_shape
ta._element_shape = self._element_shape
ta._colocate_with = self._colocate_with
ta._dynamic_size = self._dynamic_size
return ta
def stack(self, name=None):
@ -301,7 +313,7 @@ class _GraphTensorArray(object):
if self._element_shape:
element_shape = self._element_shape[0]
else:
element_shape = tensor_shape.TensorShape(None)
element_shape = tensor_shape.unknown_shape(None)
value = gen_data_flow_ops.tensor_array_gather_v3(
handle=self._handle,
indices=indices,
@ -343,7 +355,10 @@ class _GraphTensorArray(object):
"""See TensorArray."""
with ops.name_scope(name, "TensorArrayScatter",
[self._handle, value, indices]):
value = ops.convert_to_tensor(value, name="value")
# TODO(b/129870929): Fix after all callers provide proper init dtype.
value = ops.convert_to_tensor(
value, preferred_dtype=self._dtype, name="value")
_check_dtypes(value, self._dtype)
if self._infer_shape and not context.executing_eagerly():
self._merge_element_shape(value.shape[1:])
with self._maybe_colocate_with(value):
@ -361,6 +376,7 @@ class _GraphTensorArray(object):
ta._infer_shape = self._infer_shape
ta._element_shape = self._element_shape
ta._colocate_with = self._colocate_with
ta._dynamic_size = self._dynamic_size
return ta
@tf_should_use.should_use_result
@ -368,7 +384,7 @@ class _GraphTensorArray(object):
"""See TensorArray."""
with ops.name_scope(name, "TensorArraySplit",
[self._handle, value, lengths]):
value = ops.convert_to_tensor(value, name="value")
value = ops.convert_to_tensor(value, dtype=self._dtype, name="value")
with self._maybe_colocate_with(value):
lengths_64 = math_ops.cast(lengths, dtypes.int64)
if self._infer_shape and not context.executing_eagerly():
@ -392,6 +408,7 @@ class _GraphTensorArray(object):
ta._infer_shape = self._infer_shape
ta._element_shape = self._element_shape
ta._colocate_with = self._colocate_with
ta._dynamic_size = self._dynamic_size
return ta
def size(self, name=None):
@ -471,7 +488,7 @@ class _GraphTensorArrayV2(object):
raise ValueError("Cannot provide both a flow and element_shape "
"at the same time")
self._dtype = dtype
self._dtype = dtypes.as_dtype(dtype).base_dtype
# Record the current static shape for the array elements. The element
# shape is defined either by `element_shape` or the shape of the tensor
@ -482,7 +499,7 @@ class _GraphTensorArrayV2(object):
self._element_shape = []
else:
self._infer_shape = True
self._element_shape = [tensor_shape.TensorShape(element_shape)]
self._element_shape = [tensor_shape.as_shape(element_shape)]
with ops.name_scope(name, "TensorArrayV2", [size, flow]) as scope:
if flow is None:
self._flow = list_ops.tensor_list_reserve(
@ -505,6 +522,13 @@ class _GraphTensorArrayV2(object):
def dtype(self):
return self._dtype
@property
def element_shape(self):
if self._element_shape:
return self._element_shape[0]
else:
return tensor_shape.unknown_shape(None)
@property
def handle(self):
# We intentionally do not raise an error so that legacy while_loop does not
@ -546,7 +570,7 @@ class _GraphTensorArrayV2(object):
if self._element_shape:
element_shape = self._element_shape[0]
else:
element_shape = tensor_shape.TensorShape(None)
element_shape = tensor_shape.unknown_shape(None)
value = list_ops.tensor_list_get_item(
input_handle=self._flow,
index=index,
@ -561,7 +585,10 @@ class _GraphTensorArrayV2(object):
def write(self, index, value, name=None):
"""See TensorArray."""
with ops.name_scope(name, "TensorArrayV2Write", [self._flow, index, value]):
value = ops.convert_to_tensor(value, name="value")
# TODO(b/129870929): Fix after all callers provide proper init dtype.
value = ops.convert_to_tensor(
value, preferred_dtype=self._dtype, name="value")
_check_dtypes(value, self._dtype)
if self._infer_shape:
self._merge_element_shape(value.shape)
flow_out = list_ops.tensor_list_set_item(
@ -578,7 +605,7 @@ class _GraphTensorArrayV2(object):
if self._element_shape:
element_shape = self._element_shape[0]
else:
element_shape = tensor_shape.TensorShape(None)
element_shape = tensor_shape.unknown_shape(None)
value = list_ops.tensor_list_stack(
input_handle=self._flow,
element_dtype=self._dtype,
@ -592,7 +619,7 @@ class _GraphTensorArrayV2(object):
if self._element_shape:
element_shape = self._element_shape[0]
else:
element_shape = tensor_shape.TensorShape(None)
element_shape = tensor_shape.unknown_shape(None)
value = list_ops.tensor_list_gather(
input_handle=self._flow,
indices=indices,
@ -621,7 +648,10 @@ class _GraphTensorArrayV2(object):
def unstack(self, value, name=None):
"""See TensorArray."""
with ops.name_scope(name, "TensorArrayUnstack", [self._flow, value]):
value = ops.convert_to_tensor(value, name="value")
# TODO(b/129870929): Fix after all callers provide proper init dtype.
value = ops.convert_to_tensor(
value, preferred_dtype=self._dtype, name="value")
_check_dtypes(value, self._dtype)
if self._infer_shape and not context.executing_eagerly():
self._merge_element_shape(value.shape[1:])
flow_out = list_ops.tensor_list_from_tensor(
@ -633,7 +663,10 @@ class _GraphTensorArrayV2(object):
"""See TensorArray."""
with ops.name_scope(name, "TensorArrayScatter",
[self._flow, value, indices]):
value = ops.convert_to_tensor(value, name="value")
# TODO(b/129870929): Fix after all callers provide proper init dtype.
value = ops.convert_to_tensor(
value, preferred_dtype=self._dtype, name="value")
_check_dtypes(value, self._dtype)
if self._infer_shape and not context.executing_eagerly():
self._merge_element_shape(value.shape[1:])
element_shape = self._element_shape[0] if self._element_shape else None
@ -645,7 +678,10 @@ class _GraphTensorArrayV2(object):
def split(self, value, lengths, name=None):
"""See TensorArray."""
with ops.name_scope(name, "TensorArraySplit", [self._flow, value, lengths]):
value = ops.convert_to_tensor(value, name="value")
# TODO(b/129870929): Fix after all callers provide proper init dtype.
value = ops.convert_to_tensor(
value, preferred_dtype=self._dtype, name="value")
_check_dtypes(value, self._dtype)
lengths_64 = math_ops.cast(lengths, dtypes.int64)
if self._infer_shape and not context.executing_eagerly():
clengths = tensor_util.constant_value(lengths_64)
@ -730,10 +766,10 @@ class _EagerTensorArray(object):
# a Tensor
self._flow = constant_op.constant(0, dtype=dtypes.int32)
self._infer_shape = infer_shape
self._element_shape = element_shape
self._element_shape = tensor_shape.as_shape(element_shape)
self._colocate_with_first_write_call = colocate_with_first_write_call
self._dtype = dtype
self._dtype = dtypes.as_dtype(dtype).base_dtype
self._dynamic_size = dynamic_size or False
self._clear_after_read = (
True if clear_after_read is None else clear_after_read)
@ -757,6 +793,13 @@ class _EagerTensorArray(object):
"""For compatibility; handles are not meaningful when eager is enabled."""
return self._handle
@property
def element_shape(self):
if not self._element_shape:
return tensor_shape.unknown_shape(None)
else:
return
def identity(self):
"""See TensorArray."""
return self.parent()
@ -831,14 +874,17 @@ class _EagerTensorArray(object):
self._tensor_array.extend([None for _ in range(index - size + 1)])
if not isinstance(value, ops.EagerTensor):
value = ops.convert_to_tensor(value)
# TODO(b/129870929): Fix after all callers provide proper init dtype.
value = ops.convert_to_tensor(
value, preferred_dtype=self._dtype, name="value")
_check_dtypes(value, self._dtype)
if self._infer_shape:
if self._element_shape is None:
self._element_shape = value.shape
elif not self._element_shape.is_compatible_with(value.shape):
if not self._element_shape.is_compatible_with(value.shape):
raise ValueError("Incompatible shape for value (%s), expected (%s)" %
(value.shape.as_list(), self._element_shape.as_list()))
(value.shape, self._element_shape))
else:
self._element_shape = self._element_shape.merge_with(value.shape)
if self._dtype != value.dtype:
raise errors_impl.InvalidArgumentError(
@ -914,8 +960,10 @@ class _EagerTensorArray(object):
def split(self, value, lengths, name=None):
"""See TensorArray."""
# error checking to match graph-mode errors
value = ops.convert_to_tensor(value)
# TODO(b/129870929): Fix after all callers provide proper init dtype.
value = ops.convert_to_tensor(
value, preferred_dtype=self._dtype, name="value")
_check_dtypes(value, self._dtype)
lengths = ops.convert_to_tensor(lengths)
sum_lengths = math_ops.reduce_sum(lengths)
if lengths.shape.ndims != 1:
@ -1017,13 +1065,19 @@ class TensorArray(object):
ValueError: if both handle and tensor_array_name are provided.
TypeError: if handle is provided but is not a Tensor.
"""
if context.executing_eagerly():
if (context.executing_eagerly() and
(flow is None or flow.dtype != dtypes.variant)):
# It is possible to create a Variant-style TensorArray even in eager mode,
# and this is fine but can have performance implications in eager.
# An example of when this happens is if a tf.function returns a
# TensorArray in its output; its flow variant object is returned to Eager.
# This can be wrapped back up in a Variant-style TensorArray.
implementation = _EagerTensorArray
elif (flow is not None and flow.dtype == dtypes.variant or
control_flow_util.EnableControlFlowV2(ops.get_default_graph())):
implementation = _GraphTensorArrayV2
else:
if control_flow_util.EnableControlFlowV2(ops.get_default_graph()):
implementation = _GraphTensorArrayV2
else:
implementation = _GraphTensorArray
implementation = _GraphTensorArray
self._implementation = implementation(
dtype,
size=size,
@ -1054,10 +1108,24 @@ class TensorArray(object):
"""The reference to the TensorArray."""
return self._implementation.handle
@property
def element_shape(self):
"""The `tf.TensorShape` of elements in this TensorArray."""
return self._implementation.element_shape
@property
def dynamic_size(self):
"""Python bool; if `True` the TensorArray can grow dynamically."""
return self._implementation._dynamic_size
@property
def _dynamic_size(self):
return self._implementation._dynamic_size
@_dynamic_size.setter
def _dynamic_size(self, dynamic_size):
self._implementation._dynamic_size = dynamic_size
@property
def _infer_shape(self):
return self._implementation._infer_shape
@ -1244,6 +1312,21 @@ class TensorArray(object):
def build_ta_with_new_flow(old_ta, flow):
"""Builds a TensorArray with a new `flow` tensor."""
if not context.executing_eagerly():
# Sometimes we get old_ta as the implementation, sometimes it's the
# TensorArray wrapper object.
impl = (old_ta._implementation if isinstance(old_ta, TensorArray)
else old_ta)
if (not isinstance(impl, _GraphTensorArrayV2) and
control_flow_util.EnableControlFlowV2(ops.get_default_graph())):
raise NotImplementedError("Attempting to build a graph-mode TF2-style "
"TensorArray from either an eager-mode "
"TensorArray or a TF1-style TensorArray. "
"This is not currently supported. You may be "
"attempting to capture a TensorArray "
"inside a tf.function or tf.data map function. "
"Instead, construct a new TensorArray inside "
"the function.")
ta = TensorArray(
dtype=old_ta.dtype,
dynamic_size=old_ta._dynamic_size,
@ -1256,3 +1339,13 @@ def build_ta_with_new_flow(old_ta, flow):
return ta
# pylint: enable=protected-access
def _check_dtypes(value, dtype):
if value.dtype != dtype:
logging.error(
"Error: Input value {} has dtype {}, but expected dtype {}. "
"This leads to undefined behavior and will be an error "
"in future versions of TensorFlow. Traceback:\n{}".format(
value, str(value.dtype), str(dtype),
"".join(traceback.format_stack())))

View File

@ -6,6 +6,14 @@ tf_class {
name: "dtype"
mtype: "<type \'property\'>"
}
member {
name: "dynamic_size"
mtype: "<type \'property\'>"
}
member {
name: "element_shape"
mtype: "<type \'property\'>"
}
member {
name: "flow"
mtype: "<type \'property\'>"

View File

@ -0,0 +1,18 @@
path: "tensorflow.data.experimental.TensorArrayStructure"
tf_class {
is_instance: "<class \'tensorflow.python.data.util.structure.TensorArrayStructure\'>"
is_instance: "<class \'tensorflow.python.data.util.structure.Structure\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'dtype\', \'element_shape\', \'dynamic_size\', \'infer_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_value"
argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "is_compatible_with"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -72,6 +72,10 @@ tf_module {
name: "TFRecordWriter"
mtype: "<type \'type\'>"
}
member {
name: "TensorArrayStructure"
mtype: "<type \'type\'>"
}
member {
name: "TensorStructure"
mtype: "<type \'type\'>"

View File

@ -6,6 +6,14 @@ tf_class {
name: "dtype"
mtype: "<type \'property\'>"
}
member {
name: "dynamic_size"
mtype: "<type \'property\'>"
}
member {
name: "element_shape"
mtype: "<type \'property\'>"
}
member {
name: "flow"
mtype: "<type \'property\'>"

View File

@ -0,0 +1,18 @@
path: "tensorflow.data.experimental.TensorArrayStructure"
tf_class {
is_instance: "<class \'tensorflow.python.data.util.structure.TensorArrayStructure\'>"
is_instance: "<class \'tensorflow.python.data.util.structure.Structure\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'dtype\', \'element_shape\', \'dynamic_size\', \'infer_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_value"
argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "is_compatible_with"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -72,6 +72,10 @@ tf_module {
name: "TFRecordWriter"
mtype: "<type \'type\'>"
}
member {
name: "TensorArrayStructure"
mtype: "<type \'type\'>"
}
member {
name: "TensorStructure"
mtype: "<type \'type\'>"