[TF] Add support for TensorArrays to tf.data Dataset.
This includes adding support for TensorArray-based aggregation in tf.data.experimental.scan. Also fixes some bugs/nits in TensorArray.
PiperOrigin-RevId: 242746755
parent d1e4b00bf6
commit 3a3e2833e4
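The headline feature is that tf.data pipelines can now carry `TensorArray` values, including as the accumulator state of `tf.data.experimental.scan`. A minimal sketch of the new usage, distilled from the `testTensorArraySimple` case added below (the names `start` and `scan_fn` are just illustrative):

import tensorflow as tf

# State: a dynamically sized TensorArray seeded with a single element, -1.
start = tf.TensorArray(dtype=tf.int64, size=0, element_shape=[], dynamic_size=True)
start = start.write(0, -1)

def scan_fn(ta, x):
  # The new state appends x; the emitted element is the contents prior to the write.
  return ta.write(ta.size(), x), ta.stack()

ds = tf.data.Dataset.range(5).apply(tf.data.experimental.scan(start, scan_fn))
# Produces: [-1], [-1, 0], [-1, 0, 1], [-1, 0, 1, 2], [-1, 0, 1, 2, 3]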
@@ -387,11 +387,18 @@ class TensorArrayTest(xla_test.XLATestCase):
def fn():
ta = tensor_array_ops.TensorArray(
dtype=dtypes.float32, tensor_array_name="foo", size=3)
return ta.write(-1, np.int32(7)).flow
return ta.write(-1, constant_op.constant(7)).flow

# Test writing the wrong datatype.
with self.assertRaisesOpError(
"TensorArray dtype is float but op has dtype int32"):
# TODO(b/129870929): Remove InvalidArgumentError/second regexp after all
# callers provide proper init dtype.
with self.assertRaisesRegexp(
(ValueError, errors.InvalidArgumentError),
r"("
r"conversion requested dtype float32 for Tensor with dtype int32"
r"|"
r"TensorArray dtype is float but op has dtype int32"
r")"):
xla.compile(fn)[0].eval()

@test_util.disable_control_flow_v2("b/124334096 verify dtype")

@@ -39,6 +39,7 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
@@StatsOptions
@@Structure
@@TFRecordWriter
@@TensorArrayStructure
@@TensorStructure
@@ThreadingOptions

@@ -137,6 +138,7 @@ from tensorflow.python.data.ops.optional_ops import OptionalStructure
from tensorflow.python.data.util.structure import NestedStructure
from tensorflow.python.data.util.structure import SparseTensorStructure
from tensorflow.python.data.util.structure import Structure
from tensorflow.python.data.util.structure import TensorArrayStructure
from tensorflow.python.data.util.structure import TensorStructure
# pylint: enable=unused-import

@@ -560,6 +560,7 @@ py_test(
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:script_ops",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:tensor_array_ops",
"//tensorflow/python/data/experimental/ops:scan_ops",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",

@@ -30,7 +30,10 @@ from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.platform import test

@@ -95,6 +98,83 @@ class ScanTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element())

def testTensorArraySimple(self):

def scan_fn(ta, x):
return (ta.write(ta.size(), x), ta.stack())

start = tensor_array_ops.TensorArray(
size=0,
element_shape=[],
dtype=dtypes.int64,
dynamic_size=True)
start = start.write(0, -1)

ds = dataset_ops.Dataset.range(5).apply(scan_ops.scan(start, scan_fn))

self.assertDatasetProduces(
ds,
expected_output=[
[-1],
[-1, 0],
[-1, 0, 1],
[-1, 0, 1, 2],
[-1, 0, 1, 2, 3],
],
requires_initialization=True,
num_test_iterations=2)

def testTensorArrayWithCondReset(self):

def empty():
return tensor_array_ops.TensorArray(
size=0, element_shape=[], dtype=dtypes.int64, dynamic_size=True)

def scan_fn(ta, x):
updated = ta.write(ta.size(), x)
next_iter = control_flow_ops.cond(
math_ops.equal(x % 3, 0), empty, lambda: updated)
return (next_iter, updated.stack())

start = empty()
start = start.write(0, -1)

ds = dataset_ops.Dataset.range(6).apply(scan_ops.scan(start, scan_fn))

self.assertDatasetProduces(
ds,
expected_output=[
[-1, 0],
[1],
[1, 2],
[1, 2, 3],
[4],
[4, 5],
],
requires_initialization=True,
num_test_iterations=2)

def testTensorArrayWithCondResetByExternalCaptureBreaks(self):

empty_ta = tensor_array_ops.TensorArray(
size=0, element_shape=[], dtype=dtypes.int64, dynamic_size=True)

def scan_fn(ta, x):
updated = ta.write(ta.size(), x)
# Here, capture empty_ta from outside the function. However, it may be
# either a TF1-style TensorArray or an Eager-style TensorArray.
next_iter = control_flow_ops.cond(
math_ops.equal(x % 3, 0), lambda: empty_ta, lambda: updated)
return (next_iter, updated.stack())

start = empty_ta
start = start.write(0, -1)

with self.assertRaisesRegexp(
NotImplementedError,
r"construct a new TensorArray inside the function"):
dataset_ops.Dataset.range(6).apply(scan_ops.scan(start, scan_fn))

def testChangingStateShape(self):
# Test the fixed-point shape invariant calculations: start with
# initial values with known shapes, and use a scan function that
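The last test above documents a deliberate limitation: a `TensorArray` captured from outside the scan function cannot be rebuilt with a new flow inside it, so resetting state means constructing a fresh array inside the function. A hedged sketch of the supported pattern (mirroring `testTensorArrayWithCondReset`, with `fresh` as an illustrative helper name):

import tensorflow as tf

def scan_fn(ta, x):
  def fresh():
    # Build the replacement TensorArray *inside* the traced function.
    return tf.TensorArray(dtype=tf.int64, size=0, element_shape=[], dynamic_size=True)
  updated = ta.write(ta.size(), x)
  next_state = tf.cond(tf.equal(x % 3, 0), fresh, lambda: updated)
  return next_state, updated.stack()

# Capturing a TensorArray built outside scan_fn instead raises
# NotImplementedError("... construct a new TensorArray inside the function").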
@@ -302,6 +302,7 @@ py_library(
"//tensorflow/python:experimental_dataset_ops_gen",
"//tensorflow/python:framework_ops",
"//tensorflow/python:function",
"//tensorflow/python:tensor_array_ops",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/util:nest",
"//tensorflow/python/data/util:sparse",

@@ -23,7 +23,6 @@ from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import structure
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import gen_experimental_dataset_ops
from tensorflow.python.util.tf_export import tf_export

@@ -36,14 +35,7 @@ class _ScanDataset(dataset_ops.UnaryDataset):
self._input_dataset = input_dataset

with ops.name_scope("initial_state"):
# Convert any `SparseTensorValue`s to `SparseTensor`s and all other
# values to tensors.
self._initial_state = nest.pack_sequence_as(initial_state, [
sparse_tensor.SparseTensor.from_value(t)
if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(
t, name="component_%d" % i)
for i, t in enumerate(nest.flatten(initial_state))
])
self._initial_state = structure.normalize_tensors(initial_state)

# Compute initial values for the state classes, shapes and types based on
# the initial state. The shapes may be refined by running `tf_scan_func` one
@@ -176,6 +176,7 @@ tf_py_test(
"//tensorflow/python:session",
"//tensorflow/python:sparse_ops",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:tensor_array_ops",
"//tensorflow/python:training",
"//tensorflow/python/data/ops:dataset_ops",
],
@@ -195,6 +196,7 @@ tf_py_test(
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:script_ops",
"//tensorflow/python:tensor_array_ops",
"//tensorflow/python:session",
],
)
@@ -415,6 +417,7 @@ tf_py_test(
"//tensorflow/python:sparse_ops",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:string_ops",
"//tensorflow/python:tensor_array_ops",
"//tensorflow/python:tensor_util",
"//tensorflow/python:variable_scope",
],

@@ -24,10 +24,13 @@ import numpy as np
from tensorflow.python.client import session
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.platform import test
from tensorflow.python.training import server_lib

@@ -121,6 +124,27 @@ class FlatMapTest(test_base.DatasetTestBase):
expected_output.append([i, 0] if j % 2 == 0 else [0, -i])
self.assertDatasetProduces(dataset, expected_output=expected_output)

def testTensorArray(self):
def _map_fn(i):
i = math_ops.cast(i, dtypes.int32)
return (
tensor_array_ops.TensorArray(
dtype=dtypes.int32, element_shape=(), size=i)
.unstack(math_ops.range(i)))

def _flat_map_fn(x):
self.assertIsInstance(x, tensor_array_ops.TensorArray)
return dataset_ops.Dataset.from_tensor_slices(x.stack())

dataset = dataset_ops.Dataset.range(10).map(_map_fn).flat_map(_flat_map_fn)

expected_output = []
for i in range(10):
for j in range(i):
expected_output.append(j)

self.assertDatasetProduces(dataset, expected_output=expected_output)


if __name__ == "__main__":
test.main()

@@ -33,6 +33,7 @@ from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.platform import test

@@ -51,6 +52,17 @@ class FromTensorsTest(test_base.DatasetTestBase):

self.assertDatasetProduces(dataset, expected_output=[components])

def testFromTensorsTensorArray(self):
"""Test a dataset that represents a TensorArray."""
components = (
tensor_array_ops.TensorArray(dtypes.float32, element_shape=(), size=2)
.unstack([1.0, 2.0]))

dataset = dataset_ops.Dataset.from_tensors(components)

self.assertDatasetProduces(
dataset, expected_output=[[1.0, 2.0]], requires_initialization=True)

def testFromTensorsSparse(self):
"""Test a dataset that represents a single tuple of tensors."""
components = (sparse_tensor.SparseTensorValue(

@@ -47,6 +47,7 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -728,6 +729,36 @@ class MapTest(test_base.DatasetTestBase, parameterized.TestCase):
dataset,
expected_output=[self.evaluate(_check(_sparse(i))) for i in range(10)])

def testTensorArray(self):

def _tensor_array(i):
i = math_ops.cast(i, dtypes.int32)
return (
tensor_array_ops.TensorArray(dtypes.int32, element_shape=(), size=i)
.unstack(math_ops.range(i, dtype=dtypes.int32)))

dataset = dataset_ops.Dataset.range(10).map(_tensor_array)
self.assertDatasetProduces(
dataset, expected_output=[list(range(i)) for i in range(10)])

def testTensorArrayChain(self):

def _tensor_array(i):
i = math_ops.cast(i, dtypes.int32)
return (
tensor_array_ops.TensorArray(dtypes.int32, element_shape=(), size=i)
.unstack(math_ops.range(i, dtype=dtypes.int32)))

def _check(x):
self.assertIsInstance(x, tensor_array_ops.TensorArray)
return x.identity()

dataset = dataset_ops.Dataset.range(10).map(_tensor_array).map(_check)

self.assertDatasetProduces(
dataset,
expected_output=[list(range(i)) for i in range(10)])

@test_util.run_v1_only("b/123904513")
def testParallelMapOutOfRangeError(self):
def raising_py_func(i):
@@ -27,6 +27,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.platform import test

@@ -63,11 +64,20 @@ class DatasetTestBase(test.TestCase):
mode, it should use an initializable iterator to iterate through the
dataset (e.g. when it contains stateful nodes). Defaults to False.
Returns:
A callable that returns the next element of `dataset`.
A callable that returns the next element of `dataset`. Any `TensorArray`
objects `dataset` outputs are stacked.
"""
def ta_wrapper(gn):
def _wrapper():
r = gn()
if isinstance(r, tensor_array_ops.TensorArray):
return r.stack()
else:
return r
return _wrapper
if context.executing_eagerly():
iterator = iter(dataset)
return iterator._next_internal  # pylint: disable=protected-access
return ta_wrapper(iterator._next_internal)  # pylint: disable=protected-access
else:
if requires_initialization:
iterator = dataset_ops.make_initializable_iterator(dataset)
@@ -75,7 +85,7 @@ class DatasetTestBase(test.TestCase):
else:
iterator = dataset_ops.make_one_shot_iterator(dataset)
get_next = iterator.get_next()
return lambda: get_next
return ta_wrapper(lambda: get_next)

def _compareOutputToExpected(self, result_values, expected_values,
assert_items_equal):
@@ -91,7 +101,11 @@ class DatasetTestBase(test.TestCase):
if sparse_tensor.is_sparse(result_value):
self.assertSparseValuesEqual(result_value, expected_value)
else:
self.assertAllEqual(result_value, expected_value)
self.assertAllEqual(
result_value,
expected_value,
msg=("Result value: {}. Expected value: {}"
.format(result_value, expected_value)))

def assertDatasetProduces(self,
dataset,
@@ -168,6 +182,7 @@ class DatasetTestBase(test.TestCase):

next1 = self.getNext(dataset1)
next2 = self.getNext(dataset2)

while True:
try:
op1 = self.evaluate(next1())
@@ -2070,13 +2070,7 @@ class TensorDataset(DatasetSource):

def __init__(self, tensors):
"""See `Dataset.from_tensors()` for details."""
with ops.name_scope("tensors"):
tensors = nest.pack_sequence_as(tensors, [
sparse_tensor_lib.SparseTensor.from_value(t)
if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(
t, name="component_%d" % i)
for i, t in enumerate(nest.flatten(tensors))
])
tensors = structure_lib.normalize_tensors(tensors)
self._structure = structure_lib.Structure.from_value(tensors)
self._tensors = self._structure._to_tensor_list(tensors)  # pylint: disable=protected-access

@@ -73,6 +73,7 @@ py_library(
"//tensorflow/python:ops",
"//tensorflow/python:sparse_ops",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:tensor_array_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:tensor_util",
"//tensorflow/python:util",
@@ -91,6 +92,7 @@ py_test(
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:tensor_array_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:variables",
"//tensorflow/python/data/kernel_tests:test_base",

@@ -27,7 +27,9 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.util.tf_export import tf_export

@@ -203,6 +205,8 @@ class Structure(object):
value,
(sparse_tensor_lib.SparseTensor, sparse_tensor_lib.SparseTensorValue)):
return SparseTensorStructure.from_value(value)
elif isinstance(value, tensor_array_ops.TensorArray):
return TensorArrayStructure.from_value(value)
elif isinstance(value, (tuple, dict)):
return NestedStructure.from_value(value)
else:
@@ -241,6 +245,33 @@ class Structure(object):
raise NotImplementedError("Structure._to_legacy_output_classes()")


def normalize_tensors(tensors):
"""Converts a nested structure of tensor-like objects to tensors.

* `SparseTensor`-like inputs are converted to `SparseTensor`.
* `TensorArray` inputs are passed through.
* Everything else is converted to a dense `Tensor`.

Args:
tensors: A nested structure of tensor-like, list,
`SparseTensor`, `SparseTensorValue`, or `TensorArray` objects.

Returns:
A nested structure of tensor, `SparseTensor`, or `TensorArray` objects.
"""
flat_tensors = nest.flatten(tensors)
prepared = []
with ops.name_scope("normalize_tensors"):
for i, t in enumerate(flat_tensors):
if sparse_tensor_lib.is_sparse(t):
prepared.append(sparse_tensor_lib.SparseTensor.from_value(t))
elif isinstance(t, tensor_array_ops.TensorArray):
prepared.append(t)
else:
prepared.append(ops.convert_to_tensor(t, name="component_%d" % i))
return nest.pack_sequence_as(tensors, prepared)
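`normalize_tensors` is what lets `Dataset.from_tensors` and `_ScanDataset` (above) accept a mix of element kinds without per-call-site special casing. A small sketch of its behavior, using the internal `tensorflow.python.data.util.structure` module rather than public API:

from tensorflow.python.data.util import structure
from tensorflow.python.framework import dtypes, sparse_tensor
from tensorflow.python.ops import tensor_array_ops

ta = tensor_array_ops.TensorArray(dtypes.float32, element_shape=(), size=2).unstack([1.0, 2.0])
value = {
    "dense": [1, 2, 3],  # converted to a dense Tensor
    "sparse": sparse_tensor.SparseTensorValue(
        indices=[[0]], values=[1], dense_shape=[3]),  # converted to a SparseTensor
    "ta": ta,  # passed through untouched
}
normalized = structure.normalize_tensors(value)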
def convert_legacy_structure(output_types, output_shapes, output_classes):
"""Returns a `Structure` that represents the given legacy structure.

@@ -280,12 +311,19 @@ def convert_legacy_structure(output_types, output_shapes, output_classes):
flat_ret.append(SparseTensorStructure(flat_type, flat_shape))
elif issubclass(flat_class, ops.Tensor):
flat_ret.append(TensorStructure(flat_type, flat_shape))
elif issubclass(flat_class, tensor_array_ops.TensorArray):
# We sneaked the dynamic_size and infer_shape into the legacy shape.
flat_ret.append(
TensorArrayStructure(
flat_type, flat_shape[2:],
dynamic_size=tensor_shape.dimension_value(flat_shape[0]),
infer_shape=tensor_shape.dimension_value(flat_shape[1])))
else:
# NOTE(mrry): Since legacy structures produced by iterators only
# comprise Tensors, SparseTensors, and nests, we do not need to
# support all structure types here.
raise TypeError(
"Could not build a structure for output class %r" % flat_type)
"Could not build a structure for output class %r" % (flat_class,))

ret = nest.pack_sequence_as(output_classes, flat_ret)
if isinstance(ret, Structure):
@@ -571,3 +609,94 @@ class SparseTensorStructure(Structure):
if self._dense_shape.ndims == 0:
raise ValueError("Unbatching a tensor is only supported for rank >= 1")
return SparseTensorStructure(self._dtype, self._dense_shape[1:])


@tf_export("data.experimental.TensorArrayStructure")
class TensorArrayStructure(Structure):
"""Represents structural information about a `tf.TensorArray`."""

def __init__(self, dtype, element_shape, dynamic_size, infer_shape):
self._dtype = dtypes.as_dtype(dtype)
self._element_shape = tensor_shape.as_shape(element_shape)
self._dynamic_size = dynamic_size
self._infer_shape = infer_shape

@property
def _flat_shapes(self):
# A TensorArray is represented via its variant object, which is a scalar.
return [tensor_shape.scalar()]

@property
def _flat_types(self):
return [dtypes.variant]

def is_compatible_with(self, other):
return (isinstance(other, TensorArrayStructure) and
self._dtype.is_compatible_with(other._dtype) and
self._element_shape.is_compatible_with(other._element_shape) and
self._dynamic_size == other._dynamic_size)

def _to_tensor_list(self, value):
if not isinstance(value, tensor_array_ops.TensorArray):
raise TypeError("value must be a TensorArray, but saw: {}"
.format(type(value)))
if value.flow is not None and value.flow.dtype == dtypes.variant:
return [value.flow]
else:
# Convert to a TF2-style TensorArray.
# TODO(ebrevdo): Add an "_as_variant" method to TensorArray class, or
# "implementation / as_variant" arg to TensorArray constructor.
with ops.name_scope("convert_tensor_array"):
flow = list_ops.tensor_list_from_tensor(
tensor=value.stack(), element_shape=value.element_shape)
return [flow]

def _to_batched_tensor_list(self, value):
raise NotImplementedError("TensorArrayStructure._to_batched_tensor_list")

def _from_tensor_list(self, flat_value):
if (len(flat_value) != 1 or flat_value[0].dtype != dtypes.variant or
not flat_value[0].shape.is_compatible_with(tensor_shape.scalar())):
raise ValueError("TensorArrayStructure corresponds to a single "
"tf.variant scalar.")
return self._from_compatible_tensor_list(flat_value)

def _from_compatible_tensor_list(self, flat_value):
# This will return a TF2 Graph-style TensorArray because flat_value[0] is
# a variant object. size == -1 implies unknown size.
ret = tensor_array_ops.TensorArray(
dtype=self._dtype,
flow=flat_value[0],
dynamic_size=self._dynamic_size,
infer_shape=self._infer_shape)
ret._element_shape = [self._element_shape]
return ret

@staticmethod
def from_value(value):
if not isinstance(value, tensor_array_ops.TensorArray):
raise TypeError("Expected value to be a TensorArray, but saw: {}".
format(type(value)))

return TensorArrayStructure(
dtype=value.dtype,
element_shape=value.element_shape,
dynamic_size=value.dynamic_size,
infer_shape=value._infer_shape)

def _to_legacy_output_types(self):
return self._dtype

def _to_legacy_output_shapes(self):
# Sneak the dynamic_size and infer_shape values into the legacy shape.
return (tensor_shape.matrix(self._dynamic_size, self._infer_shape)
.concatenate(self._element_shape))

def _to_legacy_output_classes(self):
return tensor_array_ops.TensorArray

def _batch(self, batch_size):
raise NotImplementedError("TensorArrayStructure._batch")

def _unbatch(self):
raise NotImplementedError("TensorArrayStructure._unbatch")
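Taken together, `TensorArrayStructure` lets tf.data flatten a `TensorArray` to a single `tf.variant` scalar and rebuild it on the other side of a function boundary. A rough sketch of the round trip using the private conversion methods above (the same path `testRoundTripConversion` below exercises):

from tensorflow.python.data.util import structure
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import tensor_array_ops

ta = tensor_array_ops.TensorArray(dtypes.float32, element_shape=(3,), size=0, dynamic_size=True)
ta = ta.write(0, [1.0, 2.0, 3.0])

s = structure.Structure.from_value(ta)        # -> TensorArrayStructure
flat = s._to_tensor_list(ta)                  # a single scalar variant tensor
ta_restored = s._from_compatible_tensor_list(flat)
# ta_restored.stack() evaluates to [[1.0, 2.0, 3.0]]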
@@ -31,6 +31,7 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test

@@ -44,6 +45,9 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.parameters(
(lambda: constant_op.constant(37.0), structure.TensorStructure,
[dtypes.float32], [[]]),
(lambda: tensor_array_ops.TensorArray(
dtype=dtypes.float32, element_shape=(3,), size=0),
structure.TensorArrayStructure, [dtypes.variant], [None, 3]),
(lambda: sparse_tensor.SparseTensor(
indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
structure.SparseTensorStructure, [dtypes.variant], [None]),
@@ -79,6 +83,20 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase):
variables.Variable(100.0), 42.0,
np.array(42.0, dtype=np.float32)
], lambda: [constant_op.constant([1.0, 2.0]), constant_op.constant(37)]),
(lambda: tensor_array_ops.TensorArray(
dtype=dtypes.float32, element_shape=(3,), size=0),
lambda: [
tensor_array_ops.TensorArray(
dtype=dtypes.float32, element_shape=(3,), size=0),
tensor_array_ops.TensorArray(
dtype=dtypes.float32, element_shape=(3,), size=10)
],
lambda: [
tensor_array_ops.TensorArray(
dtype=dtypes.int32, element_shape=(3,), size=0),
tensor_array_ops.TensorArray(
dtype=dtypes.float32, element_shape=(), size=0)
]),
(lambda: sparse_tensor.SparseTensor(
indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
lambda: [
@@ -137,6 +155,8 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase):
(lambda: constant_op.constant(37.0),),
(lambda: sparse_tensor.SparseTensor(
indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),),
(lambda: tensor_array_ops.TensorArray(
dtype=dtypes.float32, element_shape=(), size=1).write(0, 7)),
(lambda: {"a": constant_op.constant(37.0),
"b": constant_op.constant([1, 2, 3])},),
(lambda: {"a": constant_op.constant(37.0),
@@ -149,8 +169,15 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase):
def testRoundTripConversion(self, value_fn):
value = value_fn()
s = structure.Structure.from_value(value)
before = self.evaluate(value)
after = self.evaluate(s._from_tensor_list(s._to_tensor_list(value)))
def maybe_stack_ta(v):
if isinstance(v, tensor_array_ops.TensorArray):
return v.stack()
else:
return v

before = self.evaluate(maybe_stack_ta(value))
after = self.evaluate(
maybe_stack_ta(s._from_tensor_list(s._to_tensor_list(value))))

flat_before = nest.flatten(before)
flat_after = nest.flatten(after)
@@ -343,6 +370,18 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase):
("SparseTensor", dtypes.int32, tensor_shape.matrix(2, 2),
sparse_tensor.SparseTensor,
structure.SparseTensorStructure(dtypes.int32, [2, 2])),
("TensorArray0", dtypes.int32, tensor_shape.as_shape([None, True, 2, 2]),
tensor_array_ops.TensorArray,
structure.TensorArrayStructure(
dtypes.int32, [2, 2], dynamic_size=None, infer_shape=True)),
("TensorArray1", dtypes.int32, tensor_shape.as_shape([True, None, 2, 2]),
tensor_array_ops.TensorArray,
structure.TensorArrayStructure(
dtypes.int32, [2, 2], dynamic_size=True, infer_shape=None)),
("TensorArray2", dtypes.int32, tensor_shape.as_shape([True, False, 2, 2]),
tensor_array_ops.TensorArray,
structure.TensorArrayStructure(
dtypes.int32, [2, 2], dynamic_size=True, infer_shape=False)),
("Nest",
{"a": dtypes.float32, "b": (dtypes.int32, dtypes.string)},
{"a": tensor_shape.scalar(),
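The "TensorArray0/1/2" cases above feed back the trick described in `_to_legacy_output_shapes`: a `TensorArray`'s legacy `output_shapes` entry is its element shape prefixed with two extra dimensions that smuggle `dynamic_size` and `infer_shape` through APIs that only carry dtypes, shapes and classes. Schematically (a sketch of the convention, not a public contract):

# Encode, in TensorArrayStructure._to_legacy_output_shapes:
#   legacy_shape = [dynamic_size, infer_shape] + element_shape
#   (the booleans ride along as dimensions of size 1/0, or None when unset)
# Decode, in convert_legacy_structure:
#   dynamic_size  = tensor_shape.dimension_value(legacy_shape[0])
#   infer_shape   = tensor_shape.dimension_value(legacy_shape[1])
#   element_shape = legacy_shape[2:]
# So [True, False, 2, 2] above round-trips to
#   TensorArrayStructure(dtypes.int32, [2, 2], dynamic_size=True, infer_shape=False)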
@@ -437,16 +437,22 @@ class TensorArrayTest(test.TestCase):
def testTensorArrayWriteWrongIndexOrDataTypeFails(self):
with self.session(use_gpu=True):
ta = _make_ta(3, "foo", dtype=dtypes.float32)
# Test writing the wrong datatype
if (control_flow_util.ENABLE_CONTROL_FLOW_V2 and
not context.executing_eagerly()):
error_msg = ("Invalid data types; op elements string but list elements "
"float")
else:
error_msg = (
"TensorArray dtype is (float|float32) but Op is trying to write "
"dtype string")
with self.assertRaisesOpError(error_msg):
# TODO(b/129870929): Remove the last 2 checks (runtime checks) after
# moving back from preferred_dtype= to dtype= in convert_to_tensor. Also
# restrict error check to only TypeError.
error_msg_regex = (
"("
"Expected float32, got 'wrong_type_scalar' of type 'str' instead."
"|"
"Cannot convert provided value to EagerTensor. Provided value: "
"wrong_type_scalar Requested dtype: float"
"|"
"TensorArray dtype is float.* but Op is trying to write dtype string"
"|"
"Invalid data types; op elements string but list elements float"
")")
with self.assertRaisesRegexp(
(TypeError, errors.InvalidArgumentError), error_msg_regex):
self.evaluate(ta.write(0, "wrong_type_scalar").flow)

if (control_flow_util.ENABLE_CONTROL_FLOW_V2 and

@@ -75,6 +75,7 @@ def for_loop(loop_fn, loop_fn_dtypes, iters, parallel_iterations=None):
for out, ta in zip(fn_output, ta_list):
# TODO(agarwal): support returning Operation objects from loop_fn.
if out is not None:
# out may be a ref tensor, wrap it in identity to get a non-ref tensor.
ta = ta.write(i, array_ops.expand_dims(out, 0))
outputs.append(ta)
return tuple([i + 1] + outputs)
@@ -86,7 +87,7 @@ def for_loop(loop_fn, loop_fn_dtypes, iters, parallel_iterations=None):
ta_list = control_flow_ops.while_loop(
lambda i, *ta: i < iters,
while_body,
[0] + [tensor_array_ops.TensorArray(dtype, iters)
[0] + [tensor_array_ops.TensorArray(dtype.base_dtype, iters)
for dtype in flat_loop_fn_dtypes],
**extra_args)[1:]
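The `dtype.base_dtype` fix above (and the matching `dtypes.as_dtype(dtype).base_dtype` changes in `tensor_array_ops.py` below) keeps reference dtypes out of `TensorArray` metadata: a loop body that reads a `tf.Variable` may hand back a `float32_ref` tensor, but the array should still be typed `float32`. A small illustration of the property being relied on:

from tensorflow.python.framework import dtypes

# A reference dtype resolves to its underlying value dtype.
assert dtypes.float32_ref.base_dtype == dtypes.float32
# base_dtype is a no-op for ordinary dtypes.
assert dtypes.as_dtype("float32").base_dtype == dtypes.float32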
@@ -20,6 +20,7 @@ from __future__ import division
from __future__ import print_function

import contextlib
import traceback
import weakref

from tensorflow.python.eager import context
@@ -35,6 +36,7 @@ from tensorflow.python.ops import gen_control_flow_ops
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_should_use
from tensorflow.python.util.tf_export import tf_export

@@ -114,10 +116,8 @@ class _GraphTensorArray(object):

if clear_after_read is None:
clear_after_read = True
self._dynamic_size = None
dynamic_size = dynamic_size or False

self._dtype = dtype
self._dynamic_size = dynamic_size or False
self._dtype = dtypes.as_dtype(dtype).base_dtype

# Used to keep track of what tensors the TensorArray should be
# colocated with. We choose to colocate the TensorArray with the
@@ -137,7 +137,7 @@ class _GraphTensorArray(object):
self._element_shape = []
else:
self._infer_shape = True
self._element_shape = [tensor_shape.TensorShape(element_shape)]
self._element_shape = [tensor_shape.as_shape(element_shape)]
with ops.name_scope(name, "TensorArray", [handle, size, flow]) as scope:
if handle is not None:
self._handle = handle
@@ -155,7 +155,7 @@ class _GraphTensorArray(object):
size=size,
element_shape=element_shape,
identical_element_shapes=infer_shape,
dynamic_size=dynamic_size,
dynamic_size=self._dynamic_size,
clear_after_read=clear_after_read,
tensor_array_name=tensor_array_name,
name=scope)
@@ -177,6 +177,13 @@ class _GraphTensorArray(object):
def handle(self):
return self._handle

@property
def element_shape(self):
if self._element_shape:
return self._element_shape[0]
else:
return tensor_shape.unknown_shape(None)

def _merge_element_shape(self, shape):
"""Changes the element shape of the array given a shape to merge with.

@@ -229,6 +236,7 @@ class _GraphTensorArray(object):
colocate_with_first_write_call=self._colocate_with_first_write_call)
ta._element_shape = self._element_shape
ta._colocate_with = self._colocate_with
ta._dynamic_size = self._dynamic_size
return ta

def grad(self, source, flow=None, name=None):
@@ -270,7 +278,10 @@ class _GraphTensorArray(object):
def write(self, index, value, name=None):
"""See TensorArray."""
with ops.name_scope(name, "TensorArrayWrite", [self._handle, index, value]):
value = ops.convert_to_tensor(value, name="value")
# TODO(b/129870929): Fix after all callers provide proper init dtype.
value = ops.convert_to_tensor(
value, preferred_dtype=self._dtype, name="value")
_check_dtypes(value, self._dtype)
if self._infer_shape:
self._merge_element_shape(value.shape)
with self._maybe_colocate_with(value):
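The switch to `preferred_dtype=self._dtype` here (and in `scatter`, `unstack`, and `split` below) is deliberately soft while b/129870929 is being rolled out: `preferred_dtype` tries the array's dtype first but falls back to the value's own dtype instead of raising, and `_check_dtypes` then only logs the mismatch. A sketch of the conversion semantics being relied on (not code from this file):

from tensorflow.python.framework import constant_op, dtypes, ops

x = constant_op.constant(7, dtype=dtypes.int32)

# dtype= is strict: converting an existing int32 tensor to float32 raises.
# ops.convert_to_tensor(x, dtype=dtypes.float32)  # ValueError

# preferred_dtype= is best effort: the request is dropped when it cannot be
# honored, and the original int32 tensor comes back unchanged.
y = ops.convert_to_tensor(x, preferred_dtype=dtypes.float32)
assert y.dtype == dtypes.int32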
@@ -288,6 +299,7 @@ class _GraphTensorArray(object):
ta._infer_shape = self._infer_shape
ta._element_shape = self._element_shape
ta._colocate_with = self._colocate_with
ta._dynamic_size = self._dynamic_size
return ta

def stack(self, name=None):
@@ -301,7 +313,7 @@ class _GraphTensorArray(object):
if self._element_shape:
element_shape = self._element_shape[0]
else:
element_shape = tensor_shape.TensorShape(None)
element_shape = tensor_shape.unknown_shape(None)
value = gen_data_flow_ops.tensor_array_gather_v3(
handle=self._handle,
indices=indices,
@@ -343,7 +355,10 @@ class _GraphTensorArray(object):
"""See TensorArray."""
with ops.name_scope(name, "TensorArrayScatter",
[self._handle, value, indices]):
value = ops.convert_to_tensor(value, name="value")
# TODO(b/129870929): Fix after all callers provide proper init dtype.
value = ops.convert_to_tensor(
value, preferred_dtype=self._dtype, name="value")
_check_dtypes(value, self._dtype)
if self._infer_shape and not context.executing_eagerly():
self._merge_element_shape(value.shape[1:])
with self._maybe_colocate_with(value):
@@ -361,6 +376,7 @@ class _GraphTensorArray(object):
ta._infer_shape = self._infer_shape
ta._element_shape = self._element_shape
ta._colocate_with = self._colocate_with
ta._dynamic_size = self._dynamic_size
return ta

@tf_should_use.should_use_result
@@ -368,7 +384,7 @@ class _GraphTensorArray(object):
"""See TensorArray."""
with ops.name_scope(name, "TensorArraySplit",
[self._handle, value, lengths]):
value = ops.convert_to_tensor(value, name="value")
value = ops.convert_to_tensor(value, dtype=self._dtype, name="value")
with self._maybe_colocate_with(value):
lengths_64 = math_ops.cast(lengths, dtypes.int64)
if self._infer_shape and not context.executing_eagerly():
@@ -392,6 +408,7 @@ class _GraphTensorArray(object):
ta._infer_shape = self._infer_shape
ta._element_shape = self._element_shape
ta._colocate_with = self._colocate_with
ta._dynamic_size = self._dynamic_size
return ta

def size(self, name=None):
@@ -471,7 +488,7 @@ class _GraphTensorArrayV2(object):
raise ValueError("Cannot provide both a flow and element_shape "
"at the same time")

self._dtype = dtype
self._dtype = dtypes.as_dtype(dtype).base_dtype

# Record the current static shape for the array elements. The element
# shape is defined either by `element_shape` or the shape of the tensor
@@ -482,7 +499,7 @@ class _GraphTensorArrayV2(object):
self._element_shape = []
else:
self._infer_shape = True
self._element_shape = [tensor_shape.TensorShape(element_shape)]
self._element_shape = [tensor_shape.as_shape(element_shape)]
with ops.name_scope(name, "TensorArrayV2", [size, flow]) as scope:
if flow is None:
self._flow = list_ops.tensor_list_reserve(
@@ -505,6 +522,13 @@ class _GraphTensorArrayV2(object):
def dtype(self):
return self._dtype

@property
def element_shape(self):
if self._element_shape:
return self._element_shape[0]
else:
return tensor_shape.unknown_shape(None)

@property
def handle(self):
# We intentionally do not raise an error so that legacy while_loop does not
@@ -546,7 +570,7 @@ class _GraphTensorArrayV2(object):
if self._element_shape:
element_shape = self._element_shape[0]
else:
element_shape = tensor_shape.TensorShape(None)
element_shape = tensor_shape.unknown_shape(None)
value = list_ops.tensor_list_get_item(
input_handle=self._flow,
index=index,
@@ -561,7 +585,10 @@ class _GraphTensorArrayV2(object):
def write(self, index, value, name=None):
"""See TensorArray."""
with ops.name_scope(name, "TensorArrayV2Write", [self._flow, index, value]):
value = ops.convert_to_tensor(value, name="value")
# TODO(b/129870929): Fix after all callers provide proper init dtype.
value = ops.convert_to_tensor(
value, preferred_dtype=self._dtype, name="value")
_check_dtypes(value, self._dtype)
if self._infer_shape:
self._merge_element_shape(value.shape)
flow_out = list_ops.tensor_list_set_item(
@@ -578,7 +605,7 @@ class _GraphTensorArrayV2(object):
if self._element_shape:
element_shape = self._element_shape[0]
else:
element_shape = tensor_shape.TensorShape(None)
element_shape = tensor_shape.unknown_shape(None)
value = list_ops.tensor_list_stack(
input_handle=self._flow,
element_dtype=self._dtype,
@@ -592,7 +619,7 @@ class _GraphTensorArrayV2(object):
if self._element_shape:
element_shape = self._element_shape[0]
else:
element_shape = tensor_shape.TensorShape(None)
element_shape = tensor_shape.unknown_shape(None)
value = list_ops.tensor_list_gather(
input_handle=self._flow,
indices=indices,
@@ -621,7 +648,10 @@ class _GraphTensorArrayV2(object):
def unstack(self, value, name=None):
"""See TensorArray."""
with ops.name_scope(name, "TensorArrayUnstack", [self._flow, value]):
value = ops.convert_to_tensor(value, name="value")
# TODO(b/129870929): Fix after all callers provide proper init dtype.
value = ops.convert_to_tensor(
value, preferred_dtype=self._dtype, name="value")
_check_dtypes(value, self._dtype)
if self._infer_shape and not context.executing_eagerly():
self._merge_element_shape(value.shape[1:])
flow_out = list_ops.tensor_list_from_tensor(
@@ -633,7 +663,10 @@ class _GraphTensorArrayV2(object):
"""See TensorArray."""
with ops.name_scope(name, "TensorArrayScatter",
[self._flow, value, indices]):
value = ops.convert_to_tensor(value, name="value")
# TODO(b/129870929): Fix after all callers provide proper init dtype.
value = ops.convert_to_tensor(
value, preferred_dtype=self._dtype, name="value")
_check_dtypes(value, self._dtype)
if self._infer_shape and not context.executing_eagerly():
self._merge_element_shape(value.shape[1:])
element_shape = self._element_shape[0] if self._element_shape else None
@@ -645,7 +678,10 @@ class _GraphTensorArrayV2(object):
def split(self, value, lengths, name=None):
"""See TensorArray."""
with ops.name_scope(name, "TensorArraySplit", [self._flow, value, lengths]):
value = ops.convert_to_tensor(value, name="value")
# TODO(b/129870929): Fix after all callers provide proper init dtype.
value = ops.convert_to_tensor(
value, preferred_dtype=self._dtype, name="value")
_check_dtypes(value, self._dtype)
lengths_64 = math_ops.cast(lengths, dtypes.int64)
if self._infer_shape and not context.executing_eagerly():
clengths = tensor_util.constant_value(lengths_64)
@@ -730,10 +766,10 @@ class _EagerTensorArray(object):
# a Tensor
self._flow = constant_op.constant(0, dtype=dtypes.int32)
self._infer_shape = infer_shape
self._element_shape = element_shape
self._element_shape = tensor_shape.as_shape(element_shape)
self._colocate_with_first_write_call = colocate_with_first_write_call

self._dtype = dtype
self._dtype = dtypes.as_dtype(dtype).base_dtype
self._dynamic_size = dynamic_size or False
self._clear_after_read = (
True if clear_after_read is None else clear_after_read)
@@ -757,6 +793,13 @@ class _EagerTensorArray(object):
"""For compatibility; handles are not meaningful when eager is enabled."""
return self._handle

@property
def element_shape(self):
if not self._element_shape:
return tensor_shape.unknown_shape(None)
else:
return self._element_shape

def identity(self):
"""See TensorArray."""
return self.parent()
@@ -831,14 +874,17 @@ class _EagerTensorArray(object):
self._tensor_array.extend([None for _ in range(index - size + 1)])

if not isinstance(value, ops.EagerTensor):
value = ops.convert_to_tensor(value)
# TODO(b/129870929): Fix after all callers provide proper init dtype.
value = ops.convert_to_tensor(
value, preferred_dtype=self._dtype, name="value")
_check_dtypes(value, self._dtype)

if self._infer_shape:
if self._element_shape is None:
self._element_shape = value.shape
elif not self._element_shape.is_compatible_with(value.shape):
if not self._element_shape.is_compatible_with(value.shape):
raise ValueError("Incompatible shape for value (%s), expected (%s)" %
(value.shape.as_list(), self._element_shape.as_list()))
(value.shape, self._element_shape))
else:
self._element_shape = self._element_shape.merge_with(value.shape)

if self._dtype != value.dtype:
raise errors_impl.InvalidArgumentError(
@@ -914,8 +960,10 @@ class _EagerTensorArray(object):

def split(self, value, lengths, name=None):
"""See TensorArray."""
# error checking to match graph-mode errors
value = ops.convert_to_tensor(value)
# TODO(b/129870929): Fix after all callers provide proper init dtype.
value = ops.convert_to_tensor(
value, preferred_dtype=self._dtype, name="value")
_check_dtypes(value, self._dtype)
lengths = ops.convert_to_tensor(lengths)
sum_lengths = math_ops.reduce_sum(lengths)
if lengths.shape.ndims != 1:
@@ -1017,13 +1065,19 @@ class TensorArray(object):
ValueError: if both handle and tensor_array_name are provided.
TypeError: if handle is provided but is not a Tensor.
"""
if context.executing_eagerly():
if (context.executing_eagerly() and
(flow is None or flow.dtype != dtypes.variant)):
# It is possible to create a Variant-style TensorArray even in eager mode,
# and this is fine but can have performance implications in eager.
# An example of when this happens is if a tf.function returns a
# TensorArray in its output; its flow variant object is returned to Eager.
# This can be wrapped back up in a Variant-style TensorArray.
implementation = _EagerTensorArray
elif (flow is not None and flow.dtype == dtypes.variant or
control_flow_util.EnableControlFlowV2(ops.get_default_graph())):
implementation = _GraphTensorArrayV2
else:
if control_flow_util.EnableControlFlowV2(ops.get_default_graph()):
implementation = _GraphTensorArrayV2
else:
implementation = _GraphTensorArray
implementation = _GraphTensorArray
self._implementation = implementation(
dtype,
size=size,
@@ -1054,10 +1108,24 @@ class TensorArray(object):
"""The reference to the TensorArray."""
return self._implementation.handle

@property
def element_shape(self):
"""The `tf.TensorShape` of elements in this TensorArray."""
return self._implementation.element_shape

@property
def dynamic_size(self):
"""Python bool; if `True` the TensorArray can grow dynamically."""
return self._implementation._dynamic_size

@property
def _dynamic_size(self):
return self._implementation._dynamic_size

@_dynamic_size.setter
def _dynamic_size(self, dynamic_size):
self._implementation._dynamic_size = dynamic_size

@property
def _infer_shape(self):
return self._implementation._infer_shape
@@ -1244,6 +1312,21 @@ class TensorArray(object):

def build_ta_with_new_flow(old_ta, flow):
"""Builds a TensorArray with a new `flow` tensor."""
if not context.executing_eagerly():
# Sometimes we get old_ta as the implementation, sometimes it's the
# TensorArray wrapper object.
impl = (old_ta._implementation if isinstance(old_ta, TensorArray)
else old_ta)
if (not isinstance(impl, _GraphTensorArrayV2) and
control_flow_util.EnableControlFlowV2(ops.get_default_graph())):
raise NotImplementedError("Attempting to build a graph-mode TF2-style "
"TensorArray from either an eager-mode "
"TensorArray or a TF1-style TensorArray. "
"This is not currently supported. You may be "
"attempting to capture a TensorArray "
"inside a tf.function or tf.data map function. "
"Instead, construct a new TensorArray inside "
"the function.")
ta = TensorArray(
dtype=old_ta.dtype,
dynamic_size=old_ta._dynamic_size,
@@ -1256,3 +1339,13 @@ def build_ta_with_new_flow(old_ta, flow):
return ta

# pylint: enable=protected-access


def _check_dtypes(value, dtype):
if value.dtype != dtype:
logging.error(
"Error: Input value {} has dtype {}, but expected dtype {}. "
"This leads to undefined behavior and will be an error "
"in future versions of TensorFlow. Traceback:\n{}".format(
value, str(value.dtype), str(dtype),
"".join(traceback.format_stack())))
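The API golden updates below record the two attributes this change promotes to the public `tf.TensorArray` surface, `element_shape` and `dynamic_size`. A small usage sketch:

import tensorflow as tf

ta = tf.TensorArray(tf.float32, size=3, element_shape=(2,))
print(ta.element_shape)  # the per-element TensorShape; unknown if none was given or inferred
print(ta.dynamic_size)   # False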
@@ -6,6 +6,14 @@ tf_class {
name: "dtype"
mtype: "<type \'property\'>"
}
member {
name: "dynamic_size"
mtype: "<type \'property\'>"
}
member {
name: "element_shape"
mtype: "<type \'property\'>"
}
member {
name: "flow"
mtype: "<type \'property\'>"

@@ -0,0 +1,18 @@
path: "tensorflow.data.experimental.TensorArrayStructure"
tf_class {
is_instance: "<class \'tensorflow.python.data.util.structure.TensorArrayStructure\'>"
is_instance: "<class \'tensorflow.python.data.util.structure.Structure\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'dtype\', \'element_shape\', \'dynamic_size\', \'infer_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_value"
argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "is_compatible_with"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}

@@ -72,6 +72,10 @@ tf_module {
name: "TFRecordWriter"
mtype: "<type \'type\'>"
}
member {
name: "TensorArrayStructure"
mtype: "<type \'type\'>"
}
member {
name: "TensorStructure"
mtype: "<type \'type\'>"

@@ -6,6 +6,14 @@ tf_class {
name: "dtype"
mtype: "<type \'property\'>"
}
member {
name: "dynamic_size"
mtype: "<type \'property\'>"
}
member {
name: "element_shape"
mtype: "<type \'property\'>"
}
member {
name: "flow"
mtype: "<type \'property\'>"

@@ -0,0 +1,18 @@
path: "tensorflow.data.experimental.TensorArrayStructure"
tf_class {
is_instance: "<class \'tensorflow.python.data.util.structure.TensorArrayStructure\'>"
is_instance: "<class \'tensorflow.python.data.util.structure.Structure\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'dtype\', \'element_shape\', \'dynamic_size\', \'infer_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_value"
argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "is_compatible_with"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}

@@ -72,6 +72,10 @@ tf_module {
name: "TFRecordWriter"
mtype: "<type \'type\'>"
}
member {
name: "TensorArrayStructure"
mtype: "<type \'type\'>"
}
member {
name: "TensorStructure"
mtype: "<type \'type\'>"