From 3a3e2833e42c167a50c18636a8ff3604e6050f5b Mon Sep 17 00:00:00 2001 From: Eugene Brevdo Date: Tue, 9 Apr 2019 15:00:58 -0700 Subject: [PATCH] [TF] Add support for TensorArrays to tf.data Dataset. This includes adding support for TensorArray-based aggregation in tf.data.experimental.scan. Also fixes some bugs/nits in TensorArray. PiperOrigin-RevId: 242746755 --- .../compiler/tests/tensor_array_ops_test.py | 13 +- .../python/data/experimental/__init__.py | 2 + .../data/experimental/kernel_tests/BUILD | 1 + .../experimental/kernel_tests/scan_test.py | 80 +++++++++ tensorflow/python/data/experimental/ops/BUILD | 1 + .../python/data/experimental/ops/scan_ops.py | 10 +- tensorflow/python/data/kernel_tests/BUILD | 3 + .../python/data/kernel_tests/flat_map_test.py | 24 +++ .../data/kernel_tests/from_tensors_test.py | 12 ++ .../python/data/kernel_tests/map_test.py | 31 ++++ .../python/data/kernel_tests/test_base.py | 23 ++- tensorflow/python/data/ops/dataset_ops.py | 8 +- tensorflow/python/data/util/BUILD | 2 + tensorflow/python/data/util/structure.py | 131 ++++++++++++++- tensorflow/python/data/util/structure_test.py | 43 ++++- .../kernel_tests/tensor_array_ops_test.py | 26 +-- .../ops/parallel_for/control_flow_ops.py | 3 +- tensorflow/python/ops/tensor_array_ops.py | 159 ++++++++++++++---- .../golden/v1/tensorflow.-tensor-array.pbtxt | 8 + ...experimental.-tensor-array-structure.pbtxt | 18 ++ .../v1/tensorflow.data.experimental.pbtxt | 4 + .../golden/v2/tensorflow.-tensor-array.pbtxt | 8 + ...experimental.-tensor-array-structure.pbtxt | 18 ++ .../v2/tensorflow.data.experimental.pbtxt | 4 + 24 files changed, 562 insertions(+), 70 deletions(-) create mode 100644 tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-tensor-array-structure.pbtxt create mode 100644 tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-tensor-array-structure.pbtxt diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py index 8e075df719c..2debbbee8a2 100644 --- a/tensorflow/compiler/tests/tensor_array_ops_test.py +++ b/tensorflow/compiler/tests/tensor_array_ops_test.py @@ -387,11 +387,18 @@ class TensorArrayTest(xla_test.XLATestCase): def fn(): ta = tensor_array_ops.TensorArray( dtype=dtypes.float32, tensor_array_name="foo", size=3) - return ta.write(-1, np.int32(7)).flow + return ta.write(-1, constant_op.constant(7)).flow # Test writing the wrong datatype. - with self.assertRaisesOpError( - "TensorArray dtype is float but op has dtype int32"): + # TODO(b/129870929): Remove InvalidArgumentError/second regexp after all + # callers provide proper init dtype. + with self.assertRaisesRegexp( + (ValueError, errors.InvalidArgumentError), + r"(" + r"conversion requested dtype float32 for Tensor with dtype int32" + r"|" + r"TensorArray dtype is float but op has dtype int32" + r")"): xla.compile(fn)[0].eval() @test_util.disable_control_flow_v2("b/124334096 verify dtype") diff --git a/tensorflow/python/data/experimental/__init__.py b/tensorflow/python/data/experimental/__init__.py index 3013eff4d7b..519e8d5541a 100644 --- a/tensorflow/python/data/experimental/__init__.py +++ b/tensorflow/python/data/experimental/__init__.py @@ -39,6 +39,7 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview. @@StatsOptions @@Structure @@TFRecordWriter +@@TensorArrayStructure @@TensorStructure @@ThreadingOptions @@ -137,6 +138,7 @@ from tensorflow.python.data.ops.optional_ops import OptionalStructure from tensorflow.python.data.util.structure import NestedStructure from tensorflow.python.data.util.structure import SparseTensorStructure from tensorflow.python.data.util.structure import Structure +from tensorflow.python.data.util.structure import TensorArrayStructure from tensorflow.python.data.util.structure import TensorStructure # pylint: enable=unused-import diff --git a/tensorflow/python/data/experimental/kernel_tests/BUILD b/tensorflow/python/data/experimental/kernel_tests/BUILD index 317b54adca8..205337cc67a 100644 --- a/tensorflow/python/data/experimental/kernel_tests/BUILD +++ b/tensorflow/python/data/experimental/kernel_tests/BUILD @@ -560,6 +560,7 @@ py_test( "//tensorflow/python:framework_test_lib", "//tensorflow/python:script_ops", "//tensorflow/python:sparse_tensor", + "//tensorflow/python:tensor_array_ops", "//tensorflow/python/data/experimental/ops:scan_ops", "//tensorflow/python/data/kernel_tests:test_base", "//tensorflow/python/data/ops:dataset_ops", diff --git a/tensorflow/python/data/experimental/kernel_tests/scan_test.py b/tensorflow/python/data/experimental/kernel_tests/scan_test.py index 24221a1f0f0..0932a25488a 100644 --- a/tensorflow/python/data/experimental/kernel_tests/scan_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/scan_test.py @@ -30,7 +30,10 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import script_ops +from tensorflow.python.ops import tensor_array_ops from tensorflow.python.platform import test @@ -95,6 +98,83 @@ class ScanTest(test_base.DatasetTestBase): with self.assertRaises(errors.OutOfRangeError): self.evaluate(next_element()) + def testTensorArraySimple(self): + + def scan_fn(ta, x): + return (ta.write(ta.size(), x), ta.stack()) + + start = tensor_array_ops.TensorArray( + size=0, + element_shape=[], + dtype=dtypes.int64, + dynamic_size=True) + start = start.write(0, -1) + + ds = dataset_ops.Dataset.range(5).apply(scan_ops.scan(start, scan_fn)) + + self.assertDatasetProduces( + ds, + expected_output=[ + [-1], + [-1, 0], + [-1, 0, 1], + [-1, 0, 1, 2], + [-1, 0, 1, 2, 3], + ], + requires_initialization=True, + num_test_iterations=2) + + def testTensorArrayWithCondReset(self): + + def empty(): + return tensor_array_ops.TensorArray( + size=0, element_shape=[], dtype=dtypes.int64, dynamic_size=True) + + def scan_fn(ta, x): + updated = ta.write(ta.size(), x) + next_iter = control_flow_ops.cond( + math_ops.equal(x % 3, 0), empty, lambda: updated) + return (next_iter, updated.stack()) + + start = empty() + start = start.write(0, -1) + + ds = dataset_ops.Dataset.range(6).apply(scan_ops.scan(start, scan_fn)) + + self.assertDatasetProduces( + ds, + expected_output=[ + [-1, 0], + [1], + [1, 2], + [1, 2, 3], + [4], + [4, 5], + ], + requires_initialization=True, + num_test_iterations=2) + + def testTensorArrayWithCondResetByExternalCaptureBreaks(self): + + empty_ta = tensor_array_ops.TensorArray( + size=0, element_shape=[], dtype=dtypes.int64, dynamic_size=True) + + def scan_fn(ta, x): + updated = ta.write(ta.size(), x) + # Here, capture empty_ta from outside the function. However, it may be + # either a TF1-style TensorArray or an Eager-style TensorArray. + next_iter = control_flow_ops.cond( + math_ops.equal(x % 3, 0), lambda: empty_ta, lambda: updated) + return (next_iter, updated.stack()) + + start = empty_ta + start = start.write(0, -1) + + with self.assertRaisesRegexp( + NotImplementedError, + r"construct a new TensorArray inside the function"): + dataset_ops.Dataset.range(6).apply(scan_ops.scan(start, scan_fn)) + def testChangingStateShape(self): # Test the fixed-point shape invariant calculations: start with # initial values with known shapes, and use a scan function that diff --git a/tensorflow/python/data/experimental/ops/BUILD b/tensorflow/python/data/experimental/ops/BUILD index d79d45272e8..faf4c2d7bd6 100644 --- a/tensorflow/python/data/experimental/ops/BUILD +++ b/tensorflow/python/data/experimental/ops/BUILD @@ -302,6 +302,7 @@ py_library( "//tensorflow/python:experimental_dataset_ops_gen", "//tensorflow/python:framework_ops", "//tensorflow/python:function", + "//tensorflow/python:tensor_array_ops", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/util:nest", "//tensorflow/python/data/util:sparse", diff --git a/tensorflow/python/data/experimental/ops/scan_ops.py b/tensorflow/python/data/experimental/ops/scan_ops.py index 7662626c3a0..0708b64be36 100644 --- a/tensorflow/python/data/experimental/ops/scan_ops.py +++ b/tensorflow/python/data/experimental/ops/scan_ops.py @@ -23,7 +23,6 @@ from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.data.util import structure from tensorflow.python.framework import ops -from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import gen_experimental_dataset_ops from tensorflow.python.util.tf_export import tf_export @@ -36,14 +35,7 @@ class _ScanDataset(dataset_ops.UnaryDataset): self._input_dataset = input_dataset with ops.name_scope("initial_state"): - # Convert any `SparseTensorValue`s to `SparseTensor`s and all other - # values to tensors. - self._initial_state = nest.pack_sequence_as(initial_state, [ - sparse_tensor.SparseTensor.from_value(t) - if sparse_tensor.is_sparse(t) else ops.convert_to_tensor( - t, name="component_%d" % i) - for i, t in enumerate(nest.flatten(initial_state)) - ]) + self._initial_state = structure.normalize_tensors(initial_state) # Compute initial values for the state classes, shapes and types based on # the initial state. The shapes may be refined by running `tf_scan_func` one diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD index b56049f32da..2a4226be8ff 100644 --- a/tensorflow/python/data/kernel_tests/BUILD +++ b/tensorflow/python/data/kernel_tests/BUILD @@ -176,6 +176,7 @@ tf_py_test( "//tensorflow/python:session", "//tensorflow/python:sparse_ops", "//tensorflow/python:sparse_tensor", + "//tensorflow/python:tensor_array_ops", "//tensorflow/python:training", "//tensorflow/python/data/ops:dataset_ops", ], @@ -195,6 +196,7 @@ tf_py_test( "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:script_ops", + "//tensorflow/python:tensor_array_ops", "//tensorflow/python:session", ], ) @@ -415,6 +417,7 @@ tf_py_test( "//tensorflow/python:sparse_ops", "//tensorflow/python:sparse_tensor", "//tensorflow/python:string_ops", + "//tensorflow/python:tensor_array_ops", "//tensorflow/python:tensor_util", "//tensorflow/python:variable_scope", ], diff --git a/tensorflow/python/data/kernel_tests/flat_map_test.py b/tensorflow/python/data/kernel_tests/flat_map_test.py index 69b5fd0d77f..6872f51e529 100644 --- a/tensorflow/python/data/kernel_tests/flat_map_test.py +++ b/tensorflow/python/data/kernel_tests/flat_map_test.py @@ -24,10 +24,13 @@ import numpy as np from tensorflow.python.client import session from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util +from tensorflow.python.ops import math_ops from tensorflow.python.ops import sparse_ops +from tensorflow.python.ops import tensor_array_ops from tensorflow.python.platform import test from tensorflow.python.training import server_lib @@ -121,6 +124,27 @@ class FlatMapTest(test_base.DatasetTestBase): expected_output.append([i, 0] if j % 2 == 0 else [0, -i]) self.assertDatasetProduces(dataset, expected_output=expected_output) + def testTensorArray(self): + def _map_fn(i): + i = math_ops.cast(i, dtypes.int32) + return ( + tensor_array_ops.TensorArray( + dtype=dtypes.int32, element_shape=(), size=i) + .unstack(math_ops.range(i))) + + def _flat_map_fn(x): + self.assertIsInstance(x, tensor_array_ops.TensorArray) + return dataset_ops.Dataset.from_tensor_slices(x.stack()) + + dataset = dataset_ops.Dataset.range(10).map(_map_fn).flat_map(_flat_map_fn) + + expected_output = [] + for i in range(10): + for j in range(i): + expected_output.append(j) + + self.assertDatasetProduces(dataset, expected_output=expected_output) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/kernel_tests/from_tensors_test.py b/tensorflow/python/data/kernel_tests/from_tensors_test.py index e9f1084e042..ce8ba6d517e 100644 --- a/tensorflow/python/data/kernel_tests/from_tensors_test.py +++ b/tensorflow/python/data/kernel_tests/from_tensors_test.py @@ -33,6 +33,7 @@ from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import tensor_array_ops from tensorflow.python.platform import test @@ -51,6 +52,17 @@ class FromTensorsTest(test_base.DatasetTestBase): self.assertDatasetProduces(dataset, expected_output=[components]) + def testFromTensorsTensorArray(self): + """Test a dataset that represents a TensorArray.""" + components = ( + tensor_array_ops.TensorArray(dtypes.float32, element_shape=(), size=2) + .unstack([1.0, 2.0])) + + dataset = dataset_ops.Dataset.from_tensors(components) + + self.assertDatasetProduces( + dataset, expected_output=[[1.0, 2.0]], requires_initialization=True) + def testFromTensorsSparse(self): """Test a dataset that represents a single tuple of tensors.""" components = (sparse_tensor.SparseTensorValue( diff --git a/tensorflow/python/data/kernel_tests/map_test.py b/tensorflow/python/data/kernel_tests/map_test.py index e20867d02ec..fefebeb79c6 100644 --- a/tensorflow/python/data/kernel_tests/map_test.py +++ b/tensorflow/python/data/kernel_tests/map_test.py @@ -47,6 +47,7 @@ from tensorflow.python.ops import random_ops from tensorflow.python.ops import script_ops from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import string_ops +from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -728,6 +729,36 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase): dataset, expected_output=[self.evaluate(_check(_sparse(i))) for i in range(10)]) + def testTensorArray(self): + + def _tensor_array(i): + i = math_ops.cast(i, dtypes.int32) + return ( + tensor_array_ops.TensorArray(dtypes.int32, element_shape=(), size=i) + .unstack(math_ops.range(i, dtype=dtypes.int32))) + + dataset = dataset_ops.Dataset.range(10).map(_tensor_array) + self.assertDatasetProduces( + dataset, expected_output=[list(range(i)) for i in range(10)]) + + def testTensorArrayChain(self): + + def _tensor_array(i): + i = math_ops.cast(i, dtypes.int32) + return ( + tensor_array_ops.TensorArray(dtypes.int32, element_shape=(), size=i) + .unstack(math_ops.range(i, dtype=dtypes.int32))) + + def _check(x): + self.assertIsInstance(x, tensor_array_ops.TensorArray) + return x.identity() + + dataset = dataset_ops.Dataset.range(10).map(_tensor_array).map(_check) + + self.assertDatasetProduces( + dataset, + expected_output=[list(range(i)) for i in range(10)]) + @test_util.run_v1_only("b/123904513") def testParallelMapOutOfRangeError(self): def raising_py_func(i): diff --git a/tensorflow/python/data/kernel_tests/test_base.py b/tensorflow/python/data/kernel_tests/test_base.py index 01315e790dc..d18d247e29a 100644 --- a/tensorflow/python/data/kernel_tests/test_base.py +++ b/tensorflow/python/data/kernel_tests/test_base.py @@ -27,6 +27,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops +from tensorflow.python.ops import tensor_array_ops from tensorflow.python.platform import test @@ -63,11 +64,20 @@ class DatasetTestBase(test.TestCase): mode, it should use an initializable iterator to iterate through the dataset (e.g. when it contains stateful nodes). Defaults to False. Returns: - A callable that returns the next element of `dataset`. + A callable that returns the next element of `dataset`. Any `TensorArray` + objects `dataset` outputs are stacked. """ + def ta_wrapper(gn): + def _wrapper(): + r = gn() + if isinstance(r, tensor_array_ops.TensorArray): + return r.stack() + else: + return r + return _wrapper if context.executing_eagerly(): iterator = iter(dataset) - return iterator._next_internal # pylint: disable=protected-access + return ta_wrapper(iterator._next_internal) # pylint: disable=protected-access else: if requires_initialization: iterator = dataset_ops.make_initializable_iterator(dataset) @@ -75,7 +85,7 @@ class DatasetTestBase(test.TestCase): else: iterator = dataset_ops.make_one_shot_iterator(dataset) get_next = iterator.get_next() - return lambda: get_next + return ta_wrapper(lambda: get_next) def _compareOutputToExpected(self, result_values, expected_values, assert_items_equal): @@ -91,7 +101,11 @@ class DatasetTestBase(test.TestCase): if sparse_tensor.is_sparse(result_value): self.assertSparseValuesEqual(result_value, expected_value) else: - self.assertAllEqual(result_value, expected_value) + self.assertAllEqual( + result_value, + expected_value, + msg=("Result value: {}. Expected value: {}" + .format(result_value, expected_value))) def assertDatasetProduces(self, dataset, @@ -168,6 +182,7 @@ class DatasetTestBase(test.TestCase): next1 = self.getNext(dataset1) next2 = self.getNext(dataset2) + while True: try: op1 = self.evaluate(next1()) diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 597e6f34b79..ba27e9f9a16 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -2070,13 +2070,7 @@ class TensorDataset(DatasetSource): def __init__(self, tensors): """See `Dataset.from_tensors()` for details.""" - with ops.name_scope("tensors"): - tensors = nest.pack_sequence_as(tensors, [ - sparse_tensor_lib.SparseTensor.from_value(t) - if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor( - t, name="component_%d" % i) - for i, t in enumerate(nest.flatten(tensors)) - ]) + tensors = structure_lib.normalize_tensors(tensors) self._structure = structure_lib.Structure.from_value(tensors) self._tensors = self._structure._to_tensor_list(tensors) # pylint: disable=protected-access diff --git a/tensorflow/python/data/util/BUILD b/tensorflow/python/data/util/BUILD index c98b1f17293..991d02607c1 100644 --- a/tensorflow/python/data/util/BUILD +++ b/tensorflow/python/data/util/BUILD @@ -73,6 +73,7 @@ py_library( "//tensorflow/python:ops", "//tensorflow/python:sparse_ops", "//tensorflow/python:sparse_tensor", + "//tensorflow/python:tensor_array_ops", "//tensorflow/python:tensor_shape", "//tensorflow/python:tensor_util", "//tensorflow/python:util", @@ -91,6 +92,7 @@ py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:sparse_tensor", + "//tensorflow/python:tensor_array_ops", "//tensorflow/python:tensor_shape", "//tensorflow/python:variables", "//tensorflow/python/data/kernel_tests:test_base", diff --git a/tensorflow/python/data/util/structure.py b/tensorflow/python/data/util/structure.py index 9de0c4da0eb..661a59ce07e 100644 --- a/tensorflow/python/data/util/structure.py +++ b/tensorflow/python/data/util/structure.py @@ -27,7 +27,9 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import list_ops from tensorflow.python.ops import sparse_ops +from tensorflow.python.ops import tensor_array_ops from tensorflow.python.util.tf_export import tf_export @@ -203,6 +205,8 @@ class Structure(object): value, (sparse_tensor_lib.SparseTensor, sparse_tensor_lib.SparseTensorValue)): return SparseTensorStructure.from_value(value) + elif isinstance(value, tensor_array_ops.TensorArray): + return TensorArrayStructure.from_value(value) elif isinstance(value, (tuple, dict)): return NestedStructure.from_value(value) else: @@ -241,6 +245,33 @@ class Structure(object): raise NotImplementedError("Structure._to_legacy_output_classes()") +def normalize_tensors(tensors): + """Converts a nested structure of tensor-like objects to tensors. + + * `SparseTensor`-like inputs are converted to `SparseTensor`. + * `TensorArray` inputs are passed through. + * Everything else is converted to a dense `Tensor`. + + Args: + tensors: A nested structure of tensor-like, list, + `SparseTensor`, `SparseTensorValue`, or `TensorArray` objects. + + Returns: + A nested structure of tensor, `SparseTensor`, or `TensorArray` objects. + """ + flat_tensors = nest.flatten(tensors) + prepared = [] + with ops.name_scope("normalize_tensors"): + for i, t in enumerate(flat_tensors): + if sparse_tensor_lib.is_sparse(t): + prepared.append(sparse_tensor_lib.SparseTensor.from_value(t)) + elif isinstance(t, tensor_array_ops.TensorArray): + prepared.append(t) + else: + prepared.append(ops.convert_to_tensor(t, name="component_%d" % i)) + return nest.pack_sequence_as(tensors, prepared) + + def convert_legacy_structure(output_types, output_shapes, output_classes): """Returns a `Structure` that represents the given legacy structure. @@ -280,12 +311,19 @@ def convert_legacy_structure(output_types, output_shapes, output_classes): flat_ret.append(SparseTensorStructure(flat_type, flat_shape)) elif issubclass(flat_class, ops.Tensor): flat_ret.append(TensorStructure(flat_type, flat_shape)) + elif issubclass(flat_class, tensor_array_ops.TensorArray): + # We sneaked the dynamic_size and infer_shape into the legacy shape. + flat_ret.append( + TensorArrayStructure( + flat_type, flat_shape[2:], + dynamic_size=tensor_shape.dimension_value(flat_shape[0]), + infer_shape=tensor_shape.dimension_value(flat_shape[1]))) else: # NOTE(mrry): Since legacy structures produced by iterators only # comprise Tensors, SparseTensors, and nests, we do not need to # support all structure types here. raise TypeError( - "Could not build a structure for output class %r" % flat_type) + "Could not build a structure for output class %r" % (flat_class,)) ret = nest.pack_sequence_as(output_classes, flat_ret) if isinstance(ret, Structure): @@ -571,3 +609,94 @@ class SparseTensorStructure(Structure): if self._dense_shape.ndims == 0: raise ValueError("Unbatching a tensor is only supported for rank >= 1") return SparseTensorStructure(self._dtype, self._dense_shape[1:]) + + +@tf_export("data.experimental.TensorArrayStructure") +class TensorArrayStructure(Structure): + """Represents structural information about a `tf.TensorArray`.""" + + def __init__(self, dtype, element_shape, dynamic_size, infer_shape): + self._dtype = dtypes.as_dtype(dtype) + self._element_shape = tensor_shape.as_shape(element_shape) + self._dynamic_size = dynamic_size + self._infer_shape = infer_shape + + @property + def _flat_shapes(self): + # A TensorArray is represented via its variant object, which is a scalar. + return [tensor_shape.scalar()] + + @property + def _flat_types(self): + return [dtypes.variant] + + def is_compatible_with(self, other): + return (isinstance(other, TensorArrayStructure) and + self._dtype.is_compatible_with(other._dtype) and + self._element_shape.is_compatible_with(other._element_shape) and + self._dynamic_size == other._dynamic_size) + + def _to_tensor_list(self, value): + if not isinstance(value, tensor_array_ops.TensorArray): + raise TypeError("value must be a TensorArray, but saw: {}" + .format(type(value))) + if value.flow is not None and value.flow.dtype == dtypes.variant: + return [value.flow] + else: + # Convert to a TF2-style TensorArray. + # TODO(ebrevdo): Add an "_as_variant" method to TensorArray class, or + # "implementation / as_variant" arg to TensorArray constructor. + with ops.name_scope("convert_tensor_array"): + flow = list_ops.tensor_list_from_tensor( + tensor=value.stack(), element_shape=value.element_shape) + return [flow] + + def _to_batched_tensor_list(self, value): + raise NotImplementedError("TensorArrayStructure._to_batched_tensor_list") + + def _from_tensor_list(self, flat_value): + if (len(flat_value) != 1 or flat_value[0].dtype != dtypes.variant or + not flat_value[0].shape.is_compatible_with(tensor_shape.scalar())): + raise ValueError("TensorArrayStructure corresponds to a single " + "tf.variant scalar.") + return self._from_compatible_tensor_list(flat_value) + + def _from_compatible_tensor_list(self, flat_value): + # This will return a TF2 Graph-style TensorArray because flat_value[0] is + # a variant object. size == -1 implies unknown size. + ret = tensor_array_ops.TensorArray( + dtype=self._dtype, + flow=flat_value[0], + dynamic_size=self._dynamic_size, + infer_shape=self._infer_shape) + ret._element_shape = [self._element_shape] + return ret + + @staticmethod + def from_value(value): + if not isinstance(value, tensor_array_ops.TensorArray): + raise TypeError("Expected value to be a TensorArray, but saw: {}". + format(type(value))) + + return TensorArrayStructure( + dtype=value.dtype, + element_shape=value.element_shape, + dynamic_size=value.dynamic_size, + infer_shape=value._infer_shape) + + def _to_legacy_output_types(self): + return self._dtype + + def _to_legacy_output_shapes(self): + # Sneak the dynamic_size and infer_shape values into the legacy shape. + return (tensor_shape.matrix(self._dynamic_size, self._infer_shape) + .concatenate(self._element_shape)) + + def _to_legacy_output_classes(self): + return tensor_array_ops.TensorArray + + def _batch(self, batch_size): + raise NotImplementedError("TensorArrayStructure._batch") + + def _unbatch(self): + raise NotImplementedError("TensorArrayStructure._unbatch") diff --git a/tensorflow/python/data/util/structure_test.py b/tensorflow/python/data/util/structure_test.py index 91dcfa6f608..d292e9c22ee 100644 --- a/tensorflow/python/data/util/structure_test.py +++ b/tensorflow/python/data/util/structure_test.py @@ -31,6 +31,7 @@ from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -44,6 +45,9 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase): @parameterized.parameters( (lambda: constant_op.constant(37.0), structure.TensorStructure, [dtypes.float32], [[]]), + (lambda: tensor_array_ops.TensorArray( + dtype=dtypes.float32, element_shape=(3,), size=0), + structure.TensorArrayStructure, [dtypes.variant], [None, 3]), (lambda: sparse_tensor.SparseTensor( indices=[[3, 4]], values=[-1], dense_shape=[4, 5]), structure.SparseTensorStructure, [dtypes.variant], [None]), @@ -79,6 +83,20 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase): variables.Variable(100.0), 42.0, np.array(42.0, dtype=np.float32) ], lambda: [constant_op.constant([1.0, 2.0]), constant_op.constant(37)]), + (lambda: tensor_array_ops.TensorArray( + dtype=dtypes.float32, element_shape=(3,), size=0), + lambda: [ + tensor_array_ops.TensorArray( + dtype=dtypes.float32, element_shape=(3,), size=0), + tensor_array_ops.TensorArray( + dtype=dtypes.float32, element_shape=(3,), size=10) + ], + lambda: [ + tensor_array_ops.TensorArray( + dtype=dtypes.int32, element_shape=(3,), size=0), + tensor_array_ops.TensorArray( + dtype=dtypes.float32, element_shape=(), size=0) + ]), (lambda: sparse_tensor.SparseTensor( indices=[[3, 4]], values=[-1], dense_shape=[4, 5]), lambda: [ @@ -137,6 +155,8 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase): (lambda: constant_op.constant(37.0),), (lambda: sparse_tensor.SparseTensor( indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),), + (lambda: tensor_array_ops.TensorArray( + dtype=dtypes.float32, element_shape=(), size=1).write(0, 7)), (lambda: {"a": constant_op.constant(37.0), "b": constant_op.constant([1, 2, 3])},), (lambda: {"a": constant_op.constant(37.0), @@ -149,8 +169,15 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase): def testRoundTripConversion(self, value_fn): value = value_fn() s = structure.Structure.from_value(value) - before = self.evaluate(value) - after = self.evaluate(s._from_tensor_list(s._to_tensor_list(value))) + def maybe_stack_ta(v): + if isinstance(v, tensor_array_ops.TensorArray): + return v.stack() + else: + return v + + before = self.evaluate(maybe_stack_ta(value)) + after = self.evaluate( + maybe_stack_ta(s._from_tensor_list(s._to_tensor_list(value)))) flat_before = nest.flatten(before) flat_after = nest.flatten(after) @@ -343,6 +370,18 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase): ("SparseTensor", dtypes.int32, tensor_shape.matrix(2, 2), sparse_tensor.SparseTensor, structure.SparseTensorStructure(dtypes.int32, [2, 2])), + ("TensorArray0", dtypes.int32, tensor_shape.as_shape([None, True, 2, 2]), + tensor_array_ops.TensorArray, + structure.TensorArrayStructure( + dtypes.int32, [2, 2], dynamic_size=None, infer_shape=True)), + ("TensorArray1", dtypes.int32, tensor_shape.as_shape([True, None, 2, 2]), + tensor_array_ops.TensorArray, + structure.TensorArrayStructure( + dtypes.int32, [2, 2], dynamic_size=True, infer_shape=None)), + ("TensorArray2", dtypes.int32, tensor_shape.as_shape([True, False, 2, 2]), + tensor_array_ops.TensorArray, + structure.TensorArrayStructure( + dtypes.int32, [2, 2], dynamic_size=True, infer_shape=False)), ("Nest", {"a": dtypes.float32, "b": (dtypes.int32, dtypes.string)}, {"a": tensor_shape.scalar(), diff --git a/tensorflow/python/kernel_tests/tensor_array_ops_test.py b/tensorflow/python/kernel_tests/tensor_array_ops_test.py index 86b98f392a7..5bae6c1ffa7 100644 --- a/tensorflow/python/kernel_tests/tensor_array_ops_test.py +++ b/tensorflow/python/kernel_tests/tensor_array_ops_test.py @@ -437,16 +437,22 @@ class TensorArrayTest(test.TestCase): def testTensorArrayWriteWrongIndexOrDataTypeFails(self): with self.session(use_gpu=True): ta = _make_ta(3, "foo", dtype=dtypes.float32) - # Test writing the wrong datatype - if (control_flow_util.ENABLE_CONTROL_FLOW_V2 and - not context.executing_eagerly()): - error_msg = ("Invalid data types; op elements string but list elements " - "float") - else: - error_msg = ( - "TensorArray dtype is (float|float32) but Op is trying to write " - "dtype string") - with self.assertRaisesOpError(error_msg): + # TODO(b/129870929): Remove the last 2 checks (runtime checks) after + # back back from preferred_dtype= to dtype= in convert_to_tensor. Also + # restrict error check to only TypeError. + error_msg_regex = ( + "(" + "Expected float32, got 'wrong_type_scalar' of type 'str' instead." + "|" + "Cannot convert provided value to EagerTensor. Provided value: " + "wrong_type_scalar Requested dtype: float" + "|" + "TensorArray dtype is float.* but Op is trying to write dtype string" + "|" + "Invalid data types; op elements string but list elements float" + ")") + with self.assertRaisesRegexp( + (TypeError, errors.InvalidArgumentError), error_msg_regex): self.evaluate(ta.write(0, "wrong_type_scalar").flow) if (control_flow_util.ENABLE_CONTROL_FLOW_V2 and diff --git a/tensorflow/python/ops/parallel_for/control_flow_ops.py b/tensorflow/python/ops/parallel_for/control_flow_ops.py index 83bf86a5635..5258d6a721a 100644 --- a/tensorflow/python/ops/parallel_for/control_flow_ops.py +++ b/tensorflow/python/ops/parallel_for/control_flow_ops.py @@ -75,6 +75,7 @@ def for_loop(loop_fn, loop_fn_dtypes, iters, parallel_iterations=None): for out, ta in zip(fn_output, ta_list): # TODO(agarwal): support returning Operation objects from loop_fn. if out is not None: + # out may be a ref tensor, wrap it in identity to get a non-ref tensor. ta = ta.write(i, array_ops.expand_dims(out, 0)) outputs.append(ta) return tuple([i + 1] + outputs) @@ -86,7 +87,7 @@ def for_loop(loop_fn, loop_fn_dtypes, iters, parallel_iterations=None): ta_list = control_flow_ops.while_loop( lambda i, *ta: i < iters, while_body, - [0] + [tensor_array_ops.TensorArray(dtype, iters) + [0] + [tensor_array_ops.TensorArray(dtype.base_dtype, iters) for dtype in flat_loop_fn_dtypes], **extra_args)[1:] diff --git a/tensorflow/python/ops/tensor_array_ops.py b/tensorflow/python/ops/tensor_array_ops.py index 59fee70583e..dd3f9de8899 100644 --- a/tensorflow/python/ops/tensor_array_ops.py +++ b/tensorflow/python/ops/tensor_array_ops.py @@ -20,6 +20,7 @@ from __future__ import division from __future__ import print_function import contextlib +import traceback import weakref from tensorflow.python.eager import context @@ -35,6 +36,7 @@ from tensorflow.python.ops import gen_control_flow_ops from tensorflow.python.ops import gen_data_flow_ops from tensorflow.python.ops import list_ops from tensorflow.python.ops import math_ops +from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import tf_should_use from tensorflow.python.util.tf_export import tf_export @@ -114,10 +116,8 @@ class _GraphTensorArray(object): if clear_after_read is None: clear_after_read = True - self._dynamic_size = None - dynamic_size = dynamic_size or False - - self._dtype = dtype + self._dynamic_size = dynamic_size or False + self._dtype = dtypes.as_dtype(dtype).base_dtype # Used to keep track of what tensors the TensorArray should be # colocated with. We choose to colocate the TensorArray with the @@ -137,7 +137,7 @@ class _GraphTensorArray(object): self._element_shape = [] else: self._infer_shape = True - self._element_shape = [tensor_shape.TensorShape(element_shape)] + self._element_shape = [tensor_shape.as_shape(element_shape)] with ops.name_scope(name, "TensorArray", [handle, size, flow]) as scope: if handle is not None: self._handle = handle @@ -155,7 +155,7 @@ class _GraphTensorArray(object): size=size, element_shape=element_shape, identical_element_shapes=infer_shape, - dynamic_size=dynamic_size, + dynamic_size=self._dynamic_size, clear_after_read=clear_after_read, tensor_array_name=tensor_array_name, name=scope) @@ -177,6 +177,13 @@ class _GraphTensorArray(object): def handle(self): return self._handle + @property + def element_shape(self): + if self._element_shape: + return self._element_shape[0] + else: + return tensor_shape.unknown_shape(None) + def _merge_element_shape(self, shape): """Changes the element shape of the array given a shape to merge with. @@ -229,6 +236,7 @@ class _GraphTensorArray(object): colocate_with_first_write_call=self._colocate_with_first_write_call) ta._element_shape = self._element_shape ta._colocate_with = self._colocate_with + ta._dynamic_size = self._dynamic_size return ta def grad(self, source, flow=None, name=None): @@ -270,7 +278,10 @@ class _GraphTensorArray(object): def write(self, index, value, name=None): """See TensorArray.""" with ops.name_scope(name, "TensorArrayWrite", [self._handle, index, value]): - value = ops.convert_to_tensor(value, name="value") + # TODO(b/129870929): Fix after all callers provide proper init dtype. + value = ops.convert_to_tensor( + value, preferred_dtype=self._dtype, name="value") + _check_dtypes(value, self._dtype) if self._infer_shape: self._merge_element_shape(value.shape) with self._maybe_colocate_with(value): @@ -288,6 +299,7 @@ class _GraphTensorArray(object): ta._infer_shape = self._infer_shape ta._element_shape = self._element_shape ta._colocate_with = self._colocate_with + ta._dynamic_size = self._dynamic_size return ta def stack(self, name=None): @@ -301,7 +313,7 @@ class _GraphTensorArray(object): if self._element_shape: element_shape = self._element_shape[0] else: - element_shape = tensor_shape.TensorShape(None) + element_shape = tensor_shape.unknown_shape(None) value = gen_data_flow_ops.tensor_array_gather_v3( handle=self._handle, indices=indices, @@ -343,7 +355,10 @@ class _GraphTensorArray(object): """See TensorArray.""" with ops.name_scope(name, "TensorArrayScatter", [self._handle, value, indices]): - value = ops.convert_to_tensor(value, name="value") + # TODO(b/129870929): Fix after all callers provide proper init dtype. + value = ops.convert_to_tensor( + value, preferred_dtype=self._dtype, name="value") + _check_dtypes(value, self._dtype) if self._infer_shape and not context.executing_eagerly(): self._merge_element_shape(value.shape[1:]) with self._maybe_colocate_with(value): @@ -361,6 +376,7 @@ class _GraphTensorArray(object): ta._infer_shape = self._infer_shape ta._element_shape = self._element_shape ta._colocate_with = self._colocate_with + ta._dynamic_size = self._dynamic_size return ta @tf_should_use.should_use_result @@ -368,7 +384,7 @@ class _GraphTensorArray(object): """See TensorArray.""" with ops.name_scope(name, "TensorArraySplit", [self._handle, value, lengths]): - value = ops.convert_to_tensor(value, name="value") + value = ops.convert_to_tensor(value, dtype=self._dtype, name="value") with self._maybe_colocate_with(value): lengths_64 = math_ops.cast(lengths, dtypes.int64) if self._infer_shape and not context.executing_eagerly(): @@ -392,6 +408,7 @@ class _GraphTensorArray(object): ta._infer_shape = self._infer_shape ta._element_shape = self._element_shape ta._colocate_with = self._colocate_with + ta._dynamic_size = self._dynamic_size return ta def size(self, name=None): @@ -471,7 +488,7 @@ class _GraphTensorArrayV2(object): raise ValueError("Cannot provide both a flow and element_shape " "at the same time") - self._dtype = dtype + self._dtype = dtypes.as_dtype(dtype).base_dtype # Record the current static shape for the array elements. The element # shape is defined either by `element_shape` or the shape of the tensor @@ -482,7 +499,7 @@ class _GraphTensorArrayV2(object): self._element_shape = [] else: self._infer_shape = True - self._element_shape = [tensor_shape.TensorShape(element_shape)] + self._element_shape = [tensor_shape.as_shape(element_shape)] with ops.name_scope(name, "TensorArrayV2", [size, flow]) as scope: if flow is None: self._flow = list_ops.tensor_list_reserve( @@ -505,6 +522,13 @@ class _GraphTensorArrayV2(object): def dtype(self): return self._dtype + @property + def element_shape(self): + if self._element_shape: + return self._element_shape[0] + else: + return tensor_shape.unknown_shape(None) + @property def handle(self): # We intentionally do not raise an error so that legacy while_loop does not @@ -546,7 +570,7 @@ class _GraphTensorArrayV2(object): if self._element_shape: element_shape = self._element_shape[0] else: - element_shape = tensor_shape.TensorShape(None) + element_shape = tensor_shape.unknown_shape(None) value = list_ops.tensor_list_get_item( input_handle=self._flow, index=index, @@ -561,7 +585,10 @@ class _GraphTensorArrayV2(object): def write(self, index, value, name=None): """See TensorArray.""" with ops.name_scope(name, "TensorArrayV2Write", [self._flow, index, value]): - value = ops.convert_to_tensor(value, name="value") + # TODO(b/129870929): Fix after all callers provide proper init dtype. + value = ops.convert_to_tensor( + value, preferred_dtype=self._dtype, name="value") + _check_dtypes(value, self._dtype) if self._infer_shape: self._merge_element_shape(value.shape) flow_out = list_ops.tensor_list_set_item( @@ -578,7 +605,7 @@ class _GraphTensorArrayV2(object): if self._element_shape: element_shape = self._element_shape[0] else: - element_shape = tensor_shape.TensorShape(None) + element_shape = tensor_shape.unknown_shape(None) value = list_ops.tensor_list_stack( input_handle=self._flow, element_dtype=self._dtype, @@ -592,7 +619,7 @@ class _GraphTensorArrayV2(object): if self._element_shape: element_shape = self._element_shape[0] else: - element_shape = tensor_shape.TensorShape(None) + element_shape = tensor_shape.unknown_shape(None) value = list_ops.tensor_list_gather( input_handle=self._flow, indices=indices, @@ -621,7 +648,10 @@ class _GraphTensorArrayV2(object): def unstack(self, value, name=None): """See TensorArray.""" with ops.name_scope(name, "TensorArrayUnstack", [self._flow, value]): - value = ops.convert_to_tensor(value, name="value") + # TODO(b/129870929): Fix after all callers provide proper init dtype. + value = ops.convert_to_tensor( + value, preferred_dtype=self._dtype, name="value") + _check_dtypes(value, self._dtype) if self._infer_shape and not context.executing_eagerly(): self._merge_element_shape(value.shape[1:]) flow_out = list_ops.tensor_list_from_tensor( @@ -633,7 +663,10 @@ class _GraphTensorArrayV2(object): """See TensorArray.""" with ops.name_scope(name, "TensorArrayScatter", [self._flow, value, indices]): - value = ops.convert_to_tensor(value, name="value") + # TODO(b/129870929): Fix after all callers provide proper init dtype. + value = ops.convert_to_tensor( + value, preferred_dtype=self._dtype, name="value") + _check_dtypes(value, self._dtype) if self._infer_shape and not context.executing_eagerly(): self._merge_element_shape(value.shape[1:]) element_shape = self._element_shape[0] if self._element_shape else None @@ -645,7 +678,10 @@ class _GraphTensorArrayV2(object): def split(self, value, lengths, name=None): """See TensorArray.""" with ops.name_scope(name, "TensorArraySplit", [self._flow, value, lengths]): - value = ops.convert_to_tensor(value, name="value") + # TODO(b/129870929): Fix after all callers provide proper init dtype. + value = ops.convert_to_tensor( + value, preferred_dtype=self._dtype, name="value") + _check_dtypes(value, self._dtype) lengths_64 = math_ops.cast(lengths, dtypes.int64) if self._infer_shape and not context.executing_eagerly(): clengths = tensor_util.constant_value(lengths_64) @@ -730,10 +766,10 @@ class _EagerTensorArray(object): # a Tensor self._flow = constant_op.constant(0, dtype=dtypes.int32) self._infer_shape = infer_shape - self._element_shape = element_shape + self._element_shape = tensor_shape.as_shape(element_shape) self._colocate_with_first_write_call = colocate_with_first_write_call - self._dtype = dtype + self._dtype = dtypes.as_dtype(dtype).base_dtype self._dynamic_size = dynamic_size or False self._clear_after_read = ( True if clear_after_read is None else clear_after_read) @@ -757,6 +793,13 @@ class _EagerTensorArray(object): """For compatibility; handles are not meaningful when eager is enabled.""" return self._handle + @property + def element_shape(self): + if not self._element_shape: + return tensor_shape.unknown_shape(None) + else: + return + def identity(self): """See TensorArray.""" return self.parent() @@ -831,14 +874,17 @@ class _EagerTensorArray(object): self._tensor_array.extend([None for _ in range(index - size + 1)]) if not isinstance(value, ops.EagerTensor): - value = ops.convert_to_tensor(value) + # TODO(b/129870929): Fix after all callers provide proper init dtype. + value = ops.convert_to_tensor( + value, preferred_dtype=self._dtype, name="value") + _check_dtypes(value, self._dtype) if self._infer_shape: - if self._element_shape is None: - self._element_shape = value.shape - elif not self._element_shape.is_compatible_with(value.shape): + if not self._element_shape.is_compatible_with(value.shape): raise ValueError("Incompatible shape for value (%s), expected (%s)" % - (value.shape.as_list(), self._element_shape.as_list())) + (value.shape, self._element_shape)) + else: + self._element_shape = self._element_shape.merge_with(value.shape) if self._dtype != value.dtype: raise errors_impl.InvalidArgumentError( @@ -914,8 +960,10 @@ class _EagerTensorArray(object): def split(self, value, lengths, name=None): """See TensorArray.""" - # error checking to match graph-mode errors - value = ops.convert_to_tensor(value) + # TODO(b/129870929): Fix after all callers provide proper init dtype. + value = ops.convert_to_tensor( + value, preferred_dtype=self._dtype, name="value") + _check_dtypes(value, self._dtype) lengths = ops.convert_to_tensor(lengths) sum_lengths = math_ops.reduce_sum(lengths) if lengths.shape.ndims != 1: @@ -1017,13 +1065,19 @@ class TensorArray(object): ValueError: if both handle and tensor_array_name are provided. TypeError: if handle is provided but is not a Tensor. """ - if context.executing_eagerly(): + if (context.executing_eagerly() and + (flow is None or flow.dtype != dtypes.variant)): + # It is possible to create a Variant-style TensorArray even in eager mode, + # and this is fine but can have performance implications in eager. + # An example of when this happens is if a tf.function returns a + # TensorArray in its output; its flow variant object is returned to Eager. + # This can be wrapped back up in a Variant-style TensorArray. implementation = _EagerTensorArray + elif (flow is not None and flow.dtype == dtypes.variant or + control_flow_util.EnableControlFlowV2(ops.get_default_graph())): + implementation = _GraphTensorArrayV2 else: - if control_flow_util.EnableControlFlowV2(ops.get_default_graph()): - implementation = _GraphTensorArrayV2 - else: - implementation = _GraphTensorArray + implementation = _GraphTensorArray self._implementation = implementation( dtype, size=size, @@ -1054,10 +1108,24 @@ class TensorArray(object): """The reference to the TensorArray.""" return self._implementation.handle + @property + def element_shape(self): + """The `tf.TensorShape` of elements in this TensorArray.""" + return self._implementation.element_shape + + @property + def dynamic_size(self): + """Python bool; if `True` the TensorArray can grow dynamically.""" + return self._implementation._dynamic_size + @property def _dynamic_size(self): return self._implementation._dynamic_size + @_dynamic_size.setter + def _dynamic_size(self, dynamic_size): + self._implementation._dynamic_size = dynamic_size + @property def _infer_shape(self): return self._implementation._infer_shape @@ -1244,6 +1312,21 @@ class TensorArray(object): def build_ta_with_new_flow(old_ta, flow): """Builds a TensorArray with a new `flow` tensor.""" + if not context.executing_eagerly(): + # Sometimes we get old_ta as the implementation, sometimes it's the + # TensorArray wrapper object. + impl = (old_ta._implementation if isinstance(old_ta, TensorArray) + else old_ta) + if (not isinstance(impl, _GraphTensorArrayV2) and + control_flow_util.EnableControlFlowV2(ops.get_default_graph())): + raise NotImplementedError("Attempting to build a graph-mode TF2-style " + "TensorArray from either an eager-mode " + "TensorArray or a TF1-style TensorArray. " + "This is not currently supported. You may be " + "attempting to capture a TensorArray " + "inside a tf.function or tf.data map function. " + "Instead, construct a new TensorArray inside " + "the function.") ta = TensorArray( dtype=old_ta.dtype, dynamic_size=old_ta._dynamic_size, @@ -1256,3 +1339,13 @@ def build_ta_with_new_flow(old_ta, flow): return ta # pylint: enable=protected-access + + +def _check_dtypes(value, dtype): + if value.dtype != dtype: + logging.error( + "Error: Input value {} has dtype {}, but expected dtype {}. " + "This leads to undefined behavior and will be an error " + "in future versions of TensorFlow. Traceback:\n{}".format( + value, str(value.dtype), str(dtype), + "".join(traceback.format_stack()))) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-tensor-array.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-tensor-array.pbtxt index ed088c41ed3..d80e8ac0607 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.-tensor-array.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-tensor-array.pbtxt @@ -6,6 +6,14 @@ tf_class { name: "dtype" mtype: "" } + member { + name: "dynamic_size" + mtype: "" + } + member { + name: "element_shape" + mtype: "" + } member { name: "flow" mtype: "" diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-tensor-array-structure.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-tensor-array-structure.pbtxt new file mode 100644 index 00000000000..0392b4ff9aa --- /dev/null +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-tensor-array-structure.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.data.experimental.TensorArrayStructure" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'dtype\', \'element_shape\', \'dynamic_size\', \'infer_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_value" + argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "is_compatible_with" + argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.pbtxt index 6cec9ac90a2..1442189ca38 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.pbtxt @@ -72,6 +72,10 @@ tf_module { name: "TFRecordWriter" mtype: "" } + member { + name: "TensorArrayStructure" + mtype: "" + } member { name: "TensorStructure" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-tensor-array.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-tensor-array.pbtxt index ed088c41ed3..d80e8ac0607 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.-tensor-array.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.-tensor-array.pbtxt @@ -6,6 +6,14 @@ tf_class { name: "dtype" mtype: "" } + member { + name: "dynamic_size" + mtype: "" + } + member { + name: "element_shape" + mtype: "" + } member { name: "flow" mtype: "" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-tensor-array-structure.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-tensor-array-structure.pbtxt new file mode 100644 index 00000000000..0392b4ff9aa --- /dev/null +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-tensor-array-structure.pbtxt @@ -0,0 +1,18 @@ +path: "tensorflow.data.experimental.TensorArrayStructure" +tf_class { + is_instance: "" + is_instance: "" + is_instance: "" + member_method { + name: "__init__" + argspec: "args=[\'self\', \'dtype\', \'element_shape\', \'dynamic_size\', \'infer_shape\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "from_value" + argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None" + } + member_method { + name: "is_compatible_with" + argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None" + } +} diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.pbtxt index 801ef083886..90ff16902c4 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.pbtxt @@ -72,6 +72,10 @@ tf_module { name: "TFRecordWriter" mtype: "" } + member { + name: "TensorArrayStructure" + mtype: "" + } member { name: "TensorStructure" mtype: ""