parent
80f3b787e4
commit
4f92a46fa8
@ -165,7 +165,7 @@ py_library(
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//tensorflow/python/data/util:nest",
|
||||
"//tensorflow/python/data/util:sparse",
|
||||
"//tensorflow/python/data/util:structure",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -21,6 +21,7 @@ import numpy as np
|
||||
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.util import nest
|
||||
from tensorflow.python.data.util import structure
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
@ -448,7 +449,10 @@ class _GroupByWindowDataset(dataset_ops.UnaryDataset):
|
||||
|
||||
def _make_reduce_func(self, reduce_func, input_dataset):
|
||||
"""Make wrapping defun for reduce_func."""
|
||||
nested_dataset = dataset_ops._NestedDatasetComponent(input_dataset) # pylint: disable=protected-access
|
||||
nested_dataset = dataset_ops.DatasetStructure(
|
||||
structure.Structure._from_legacy_structure( # pylint: disable=protected-access
|
||||
input_dataset.output_types, input_dataset.output_shapes,
|
||||
input_dataset.output_classes))
|
||||
wrapped_func = dataset_ops.StructuredFunctionWrapper(
|
||||
reduce_func,
|
||||
self._transformation_name(),
|
||||
@ -456,11 +460,13 @@ class _GroupByWindowDataset(dataset_ops.UnaryDataset):
|
||||
input_shapes=(tensor_shape.scalar(), nested_dataset),
|
||||
input_types=(dtypes.int64, nested_dataset))
|
||||
if not isinstance(
|
||||
wrapped_func.output_classes, dataset_ops._NestedDatasetComponent): # pylint: disable=protected-access
|
||||
wrapped_func.output_structure, dataset_ops.DatasetStructure):
|
||||
raise TypeError("`reduce_func` must return a `Dataset` object.")
|
||||
self._output_classes = wrapped_func.output_classes.output_classes
|
||||
self._output_types = wrapped_func.output_types.output_types
|
||||
self._output_shapes = wrapped_func.output_shapes.output_shapes
|
||||
# pylint: disable=protected-access
|
||||
element_structure = wrapped_func.output_structure._element_structure
|
||||
self._output_classes = element_structure._to_legacy_output_classes()
|
||||
self._output_types = element_structure._to_legacy_output_types()
|
||||
self._output_shapes = element_structure._to_legacy_output_shapes()
|
||||
self._reduce_func = wrapped_func.function
|
||||
|
||||
@property
|
||||
|
@ -117,8 +117,12 @@ tf_py_test(
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
"//third_party/py/numpy",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:sparse_tensor",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//tensorflow/python/data/ops:optional_ops",
|
||||
"//tensorflow/python/data/util:structure",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -24,10 +24,14 @@ import numpy as np
|
||||
from tensorflow.core.framework import graph_pb2
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.ops import optional_ops
|
||||
from tensorflow.python.data.ops import readers
|
||||
from tensorflow.python.data.util import nest
|
||||
from tensorflow.python.data.util import structure
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@ -249,6 +253,63 @@ class DatasetOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
self.assertTrue(ds.options().experimental_autotune)
|
||||
self.assertTrue(ds.options().experimental_filter_fusion)
|
||||
|
||||
# pylint: disable=g-long-lambda
|
||||
@parameterized.named_parameters(
|
||||
("Tensor", lambda: constant_op.constant(37.0),
|
||||
structure.TensorStructure(dtypes.float32, [])),
|
||||
("SparseTensor", lambda: sparse_tensor.SparseTensor(
|
||||
indices=[[0]], values=constant_op.constant([0], dtype=dtypes.int32),
|
||||
dense_shape=[1]),
|
||||
structure.SparseTensorStructure(dtypes.int32, [1])),
|
||||
("Nest", lambda: {
|
||||
"a": constant_op.constant(37.0),
|
||||
"b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))},
|
||||
structure.NestedStructure({
|
||||
"a": structure.TensorStructure(dtypes.float32, []),
|
||||
"b": (structure.TensorStructure(dtypes.string, [1]),
|
||||
structure.TensorStructure(dtypes.string, []))})),
|
||||
("Dataset", lambda: dataset_ops.Dataset.from_tensor_slices(
|
||||
constant_op.constant([1, 2, 3])),
|
||||
dataset_ops.DatasetStructure(
|
||||
structure.TensorStructure(dtypes.int32, []))),
|
||||
("Optional", lambda: optional_ops.Optional.from_value(37.0),
|
||||
optional_ops.OptionalStructure(
|
||||
structure.TensorStructure(dtypes.float32, []))),
|
||||
)
|
||||
def testDatasetStructure(self, tf_value_fn, expected_element_structure):
|
||||
dataset = dataset_ops.Dataset.from_tensors(0).map(lambda _: tf_value_fn())
|
||||
dataset_structure = structure.Structure.from_value(dataset)
|
||||
self.assertIsInstance(dataset_structure, dataset_ops.DatasetStructure)
|
||||
|
||||
# TODO(b/110122868): Add a public API to `tf.data.Dataset` for accessing
|
||||
# the element structure.
|
||||
self.assertTrue(expected_element_structure.is_compatible_with(
|
||||
dataset_structure._element_structure))
|
||||
self.assertTrue(dataset_structure._element_structure.is_compatible_with(
|
||||
expected_element_structure))
|
||||
|
||||
self.assertEqual([dtypes.variant], dataset_structure._flat_types)
|
||||
self.assertEqual([tensor_shape.scalar()], dataset_structure._flat_shapes)
|
||||
|
||||
# Assert that the `Dataset` survives a round-trip via _from_tensor_list()
|
||||
# and _to_tensor_list().
|
||||
round_trip_dataset = dataset_structure._from_tensor_list(
|
||||
dataset_structure._to_tensor_list(dataset))
|
||||
|
||||
value = tf_value_fn()
|
||||
|
||||
if isinstance(value, dataset_ops.Dataset):
|
||||
self.assertDatasetsEqual(value, dataset.flat_map(lambda x: x))
|
||||
elif isinstance(value, optional_ops.Optional):
|
||||
self.assertDatasetProduces(
|
||||
round_trip_dataset.map(lambda opt: opt.get_value()),
|
||||
[self.evaluate(value.get_value())],
|
||||
requires_initialization=True)
|
||||
else:
|
||||
self.assertDatasetProduces(
|
||||
round_trip_dataset, [self.evaluate(tf_value_fn())],
|
||||
requires_initialization=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -31,6 +31,7 @@ from tensorflow.python.data.ops import iterator_ops
|
||||
from tensorflow.python.data.util import nest
|
||||
from tensorflow.python.data.util import random_seed
|
||||
from tensorflow.python.data.util import sparse
|
||||
from tensorflow.python.data.util import structure as structure_lib
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
@ -1868,57 +1869,6 @@ class SparseTensorSliceDataset(DatasetSource):
|
||||
return (dtypes.int64, self._sparse_tensor.dtype, dtypes.int64)
|
||||
|
||||
|
||||
class _NestedDatasetComponent(object):
|
||||
"""The structure of a `Dataset` nested in a component of another `Dataset`.
|
||||
|
||||
A `StructuredFunctionWrapper` around a function that returns a `Dataset` as
|
||||
one of its components will have a `NestedDatasetComponent` in the
|
||||
corresponding position in the `output_classes`, `output_shapes`, and
|
||||
`output_types` properties.
|
||||
|
||||
TODO(b/110122868): Add this class, or something equivalent, to the public API.
|
||||
We are considering revising the public API for accessing Dataset structure
|
||||
(`output_classes` etc.) based on experience with nested datasets and other
|
||||
custom component types.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
dataset=None,
|
||||
output_shapes=None,
|
||||
output_types=None,
|
||||
output_classes=None):
|
||||
if dataset is None:
|
||||
if (output_classes is None or output_shapes is None or
|
||||
output_types is None):
|
||||
raise ValueError(
|
||||
"Either `dataset`, or all of `output_classes`, "
|
||||
"`output_shapes`, and `output_types` must be specified.")
|
||||
self._output_classes = output_classes
|
||||
self._output_shapes = output_shapes
|
||||
self._output_types = output_types
|
||||
else:
|
||||
if not (output_classes is None and output_shapes is None and
|
||||
output_types is None):
|
||||
raise ValueError(
|
||||
"Either `dataset`, or all of `output_classes`, "
|
||||
"`output_shapes`, and `output_types` must be specified.")
|
||||
self._output_classes = dataset.output_classes
|
||||
self._output_shapes = dataset.output_shapes
|
||||
self._output_types = dataset.output_types
|
||||
|
||||
@property
|
||||
def output_classes(self):
|
||||
return self._output_classes
|
||||
|
||||
@property
|
||||
def output_shapes(self):
|
||||
return self._output_shapes
|
||||
|
||||
@property
|
||||
def output_types(self):
|
||||
return self._output_types
|
||||
|
||||
|
||||
class _VariantDataset(DatasetV2):
|
||||
"""A Dataset wrapper around a `tf.variant`-typed function argument."""
|
||||
|
||||
@ -1935,15 +1885,73 @@ class _VariantDataset(DatasetV2):
|
||||
|
||||
@property
|
||||
def output_classes(self):
|
||||
return self._structure.output_classes
|
||||
return self._structure._to_legacy_output_classes() # pylint: disable=protected-access
|
||||
|
||||
@property
|
||||
def output_shapes(self):
|
||||
return self._structure.output_shapes
|
||||
return self._structure._to_legacy_output_shapes() # pylint: disable=protected-access
|
||||
|
||||
@property
|
||||
def output_types(self):
|
||||
return self._structure.output_types
|
||||
return self._structure._to_legacy_output_types() # pylint: disable=protected-access
|
||||
|
||||
|
||||
class DatasetStructure(structure_lib.Structure):
|
||||
"""Represents a `Dataset` of structured values."""
|
||||
|
||||
def __init__(self, element_structure):
|
||||
self._element_structure = element_structure
|
||||
|
||||
@property
|
||||
def _flat_shapes(self):
|
||||
return [tensor_shape.scalar()]
|
||||
|
||||
@property
|
||||
def _flat_types(self):
|
||||
return [dtypes.variant]
|
||||
|
||||
def is_compatible_with(self, other):
|
||||
# pylint: disable=protected-access
|
||||
return (isinstance(other, DatasetStructure) and
|
||||
self._element_structure.is_compatible_with(
|
||||
other._element_structure))
|
||||
|
||||
def _to_tensor_list(self, value):
|
||||
return [value._as_variant_tensor()] # pylint: disable=protected-access
|
||||
|
||||
def _from_tensor_list(self, flat_value):
|
||||
if (len(flat_value) != 1 or flat_value[0].dtype != dtypes.variant or
|
||||
not flat_value[0].shape.is_compatible_with(tensor_shape.scalar())):
|
||||
raise ValueError(
|
||||
"DatasetStructure corresponds to a single tf.variant scalar.")
|
||||
return self._from_compatible_tensor_list(flat_value)
|
||||
|
||||
def _from_compatible_tensor_list(self, flat_value):
|
||||
# pylint: disable=protected-access
|
||||
return _VariantDataset(flat_value[0], self._element_structure)
|
||||
|
||||
@staticmethod
|
||||
def from_value(value):
|
||||
# TODO(b/110122868): We can simplify this when a `Dataset` object has a
|
||||
# `Structure`-valued property.
|
||||
element_structure = structure_lib.Structure._from_legacy_structure(
|
||||
value.output_types, value.output_shapes, value.output_classes)
|
||||
return DatasetStructure(element_structure)
|
||||
|
||||
def _to_legacy_output_types(self):
|
||||
return self
|
||||
|
||||
def _to_legacy_output_shapes(self):
|
||||
return self
|
||||
|
||||
def _to_legacy_output_classes(self):
|
||||
return self
|
||||
|
||||
|
||||
# pylint: disable=protected-access
|
||||
structure_lib.Structure._register_custom_converter(DatasetV2,
|
||||
DatasetStructure.from_value)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
|
||||
class StructuredFunctionWrapper(object):
|
||||
@ -2001,6 +2009,9 @@ class StructuredFunctionWrapper(object):
|
||||
self._input_types = dataset.output_types
|
||||
self._input_classes = dataset.output_classes
|
||||
|
||||
self._input_structure = structure_lib.Structure._from_legacy_structure( # pylint: disable=protected-access
|
||||
self._input_types, self._input_shapes, self._input_classes)
|
||||
|
||||
self._transformation_name = transformation_name
|
||||
readable_transformation_name = transformation_name.replace(
|
||||
".", "_")[:-2] if len(transformation_name) > 2 else ""
|
||||
@ -2008,35 +2019,18 @@ class StructuredFunctionWrapper(object):
|
||||
readable_transformation_name,
|
||||
function_utils.get_func_name(func),
|
||||
str(ops.uid())
|
||||
|
||||
])
|
||||
|
||||
if defun_kwargs is None:
|
||||
defun_kwargs = {}
|
||||
|
||||
@function.Defun(
|
||||
*self._defun_args(), func_name=self._func_name, **defun_kwargs)
|
||||
*self._input_structure._flat_types, func_name=self._func_name, # pylint: disable=protected-access
|
||||
**defun_kwargs)
|
||||
def tf_data_structured_function_wrapper(*args):
|
||||
"""Wrapper for passing nested structures to and from tf.data functions."""
|
||||
flat_args = []
|
||||
for arg, arg_class, arg_shape, arg_type in zip(
|
||||
args,
|
||||
nest.flatten(self._input_classes),
|
||||
nest.flatten(self._input_shapes),
|
||||
nest.flatten(self._input_types)):
|
||||
# TODO(b/110122868): Add a registration mechanism for new component
|
||||
# types.
|
||||
if arg_class is sparse_tensor_lib.SparseTensor:
|
||||
arg = sparse.deserialize_sparse_tensors(
|
||||
arg, arg_type, arg_shape, arg_class)
|
||||
arg.indices.set_shape([None, arg_shape.ndims])
|
||||
arg.dense_shape.set_shape([arg_shape.ndims])
|
||||
elif isinstance(arg_class, _NestedDatasetComponent):
|
||||
arg = _VariantDataset(arg, arg_class)
|
||||
else:
|
||||
arg.set_shape(arg_shape)
|
||||
flat_args.append(arg)
|
||||
nested_args = nest.pack_sequence_as(self._input_classes, flat_args)
|
||||
# pylint: disable=protected-access
|
||||
nested_args = self._input_structure._from_compatible_tensor_list(args)
|
||||
if not _should_unpack_args(nested_args):
|
||||
nested_args = (nested_args,)
|
||||
|
||||
@ -2054,50 +2048,14 @@ class StructuredFunctionWrapper(object):
|
||||
if isinstance(ret, list):
|
||||
ret = tuple(ret)
|
||||
|
||||
# Convert any `SparseTensorValue`s to `SparseTensor`s and all other
|
||||
# values to tensors.
|
||||
flat_ret = []
|
||||
flat_classes = []
|
||||
flat_shapes = []
|
||||
flat_types = []
|
||||
for t in nest.flatten(ret):
|
||||
# TODO(b/110122868): Add a registration mechanism for new component
|
||||
# types.
|
||||
if sparse_tensor_lib.is_sparse(t):
|
||||
t = sparse_tensor_lib.SparseTensor.from_value(t)
|
||||
flat_ret.append(sparse.serialize_sparse_tensors(t))
|
||||
flat_classes.append(sparse_tensor_lib.SparseTensor)
|
||||
flat_shapes.append(t.get_shape())
|
||||
flat_types.append(t.dtype)
|
||||
elif isinstance(t, DatasetV2):
|
||||
flat_ret.append(t._as_variant_tensor()) # pylint: disable=protected-access
|
||||
component = _NestedDatasetComponent(t)
|
||||
flat_classes.append(component)
|
||||
flat_shapes.append(component)
|
||||
flat_types.append(component)
|
||||
if t.options() != Options():
|
||||
warnings.warn("Encountered a nested dataset with non-default "
|
||||
"options. These options will not be propagated to "
|
||||
"the outer dataset.")
|
||||
else:
|
||||
try:
|
||||
t = ops.convert_to_tensor(t)
|
||||
except (ValueError, TypeError):
|
||||
raise TypeError("Unsupported return value from function passed to "
|
||||
"%s: %s." % (transformation_name, t))
|
||||
flat_ret.append(t)
|
||||
flat_classes.append(ops.Tensor)
|
||||
flat_shapes.append(t.get_shape())
|
||||
flat_types.append(t.dtype)
|
||||
|
||||
ret = nest.pack_sequence_as(ret, flat_ret)
|
||||
self._output_classes = nest.pack_sequence_as(ret, flat_classes)
|
||||
self._output_shapes = nest.pack_sequence_as(ret, flat_shapes)
|
||||
self._output_types = nest.pack_sequence_as(ret, flat_types)
|
||||
try:
|
||||
self._output_structure = structure_lib.Structure.from_value(ret)
|
||||
except (ValueError, TypeError):
|
||||
raise TypeError("Unsupported return value from function passed to "
|
||||
"%s: %s." % (transformation_name, ret))
|
||||
|
||||
_warn_if_collections(transformation_name)
|
||||
|
||||
return flat_ret
|
||||
return self._output_structure._to_tensor_list(ret)
|
||||
|
||||
self._function = tf_data_structured_function_wrapper
|
||||
if add_to_graph:
|
||||
@ -2108,32 +2066,21 @@ class StructuredFunctionWrapper(object):
|
||||
# in case (e.g.) we need to rerun the function.
|
||||
self._function._create_definition_if_needed() # pylint: disable=protected-access
|
||||
|
||||
def _defun_args(self):
|
||||
"""Returns a flat list of `tf.DType` for the input element structure."""
|
||||
ret = []
|
||||
for input_type, input_class in zip(nest.flatten(self._input_types),
|
||||
nest.flatten(self._input_classes)):
|
||||
# TODO(b/110122868): Add a registration mechanism for new component types.
|
||||
if input_class is sparse_tensor_lib.SparseTensor:
|
||||
ret.append(dtypes.variant)
|
||||
elif isinstance(input_class, _NestedDatasetComponent):
|
||||
ret.append(dtypes.variant)
|
||||
else:
|
||||
assert isinstance(input_type, dtypes.DType)
|
||||
ret.append(input_type)
|
||||
return ret
|
||||
@property
|
||||
def output_structure(self):
|
||||
return self._output_structure
|
||||
|
||||
@property
|
||||
def output_classes(self):
|
||||
return self._output_classes
|
||||
return self._output_structure._to_legacy_output_classes() # pylint: disable=protected-access
|
||||
|
||||
@property
|
||||
def output_shapes(self):
|
||||
return self._output_shapes
|
||||
return self._output_structure._to_legacy_output_shapes() # pylint: disable=protected-access
|
||||
|
||||
@property
|
||||
def output_types(self):
|
||||
return self._output_types
|
||||
return self._output_structure._to_legacy_output_types() # pylint: disable=protected-access
|
||||
|
||||
@property
|
||||
def function(self):
|
||||
@ -2156,30 +2103,12 @@ def flat_structure(dataset):
|
||||
A dictionary of keyword arguments that can be passed to many Dataset op
|
||||
constructors.
|
||||
"""
|
||||
output_classes = []
|
||||
output_shapes = []
|
||||
output_types = []
|
||||
for output_class, output_shape, output_type in zip(
|
||||
nest.flatten(dataset.output_classes), nest.flatten(dataset.output_shapes),
|
||||
nest.flatten(dataset.output_types)):
|
||||
if isinstance(output_class, _NestedDatasetComponent):
|
||||
output_classes.append(output_class.output_classes)
|
||||
output_shapes.append(output_shape.output_shapes)
|
||||
output_types.append(output_type.output_types)
|
||||
else:
|
||||
output_classes.append(output_class)
|
||||
output_shapes.append(output_shape)
|
||||
output_types.append(output_type)
|
||||
|
||||
output_classes = nest.pack_sequence_as(dataset.output_classes, output_classes)
|
||||
output_shapes = nest.pack_sequence_as(dataset.output_shapes, output_shapes)
|
||||
output_types = nest.pack_sequence_as(dataset.output_types, output_types)
|
||||
|
||||
# pylint: disable=protected-access
|
||||
structure = structure_lib.Structure._from_legacy_structure(
|
||||
dataset.output_types, dataset.output_shapes, dataset.output_classes)
|
||||
return {
|
||||
"output_shapes":
|
||||
nest.flatten(sparse.as_dense_shapes(output_shapes, output_classes)),
|
||||
"output_types":
|
||||
nest.flatten(sparse.as_dense_types(output_types, output_classes)),
|
||||
"output_shapes": structure._flat_shapes,
|
||||
"output_types": structure._flat_types,
|
||||
}
|
||||
|
||||
|
||||
@ -2902,11 +2831,13 @@ class FlatMapDataset(UnaryDataset):
|
||||
|
||||
wrapped_func = StructuredFunctionWrapper(
|
||||
map_func, self._transformation_name(), dataset=input_dataset)
|
||||
if not isinstance(wrapped_func.output_classes, _NestedDatasetComponent):
|
||||
if not isinstance(wrapped_func.output_structure, DatasetStructure):
|
||||
raise TypeError("`map_func` must return a `Dataset` object.")
|
||||
self._output_classes = wrapped_func.output_classes.output_classes
|
||||
self._output_types = wrapped_func.output_types.output_types
|
||||
self._output_shapes = wrapped_func.output_shapes.output_shapes
|
||||
# pylint: disable=protected-access
|
||||
element_structure = wrapped_func.output_structure._element_structure
|
||||
self._output_classes = element_structure._to_legacy_output_classes()
|
||||
self._output_types = element_structure._to_legacy_output_types()
|
||||
self._output_shapes = element_structure._to_legacy_output_shapes()
|
||||
self._map_func = wrapped_func.function
|
||||
|
||||
def _as_variant_tensor(self):
|
||||
@ -3048,10 +2979,9 @@ class WindowDataset(UnaryDataset):
|
||||
self._output_classes = nest.pack_sequence_as(
|
||||
input_dataset.output_classes,
|
||||
[
|
||||
_NestedDatasetComponent( # pylint: disable=protected-access
|
||||
output_classes=output_class,
|
||||
output_shapes=output_shape,
|
||||
output_types=output_type)
|
||||
DatasetStructure(
|
||||
structure_lib.Structure._from_legacy_structure( # pylint: disable=protected-access
|
||||
output_type, output_shape, output_class))
|
||||
for output_class, output_shape, output_type in zip(
|
||||
nest.flatten(input_dataset.output_classes),
|
||||
nest.flatten(input_dataset.output_shapes),
|
||||
|
@ -183,19 +183,13 @@ class OptionalStructure(structure.Structure):
|
||||
return OptionalStructure(value.value_structure)
|
||||
|
||||
def _to_legacy_output_types(self):
|
||||
raise NotImplementedError("The `output_types` property is not supported on "
|
||||
"structured objects containing an `Optional`. "
|
||||
"Use the corresponding `structure` property.")
|
||||
return self
|
||||
|
||||
def _to_legacy_output_shapes(self):
|
||||
raise NotImplementedError("The `output_shapes` property is not supported on"
|
||||
" structured objects containing an `Optional`. "
|
||||
"Use the corresponding `structure` property.")
|
||||
return self
|
||||
|
||||
def _to_legacy_output_classes(self):
|
||||
raise NotImplementedError("The `output_classes` property is not supported "
|
||||
"on structured objects containing an `Optional`. "
|
||||
"Use the corresponding `structure` property.")
|
||||
return self
|
||||
|
||||
|
||||
# pylint: disable=protected-access
|
||||
|
@ -208,14 +208,16 @@ class Structure(object):
|
||||
flat_ret = []
|
||||
for flat_type, flat_shape, flat_class in zip(flat_types, flat_shapes,
|
||||
flat_classes):
|
||||
if issubclass(flat_class, sparse_tensor_lib.SparseTensor):
|
||||
if isinstance(flat_class, Structure):
|
||||
flat_ret.append(flat_class)
|
||||
elif issubclass(flat_class, sparse_tensor_lib.SparseTensor):
|
||||
flat_ret.append(SparseTensorStructure(flat_type, flat_shape))
|
||||
elif issubclass(flat_class, ops.Tensor):
|
||||
flat_ret.append(TensorStructure(flat_type, flat_shape))
|
||||
else:
|
||||
# NOTE(mrry): Since legacy structures produced by iterators only
|
||||
# comprise Tensors, SparseTensors, and nests, we do not need to support
|
||||
# all structure types here.
|
||||
# comprise Tensors, SparseTensors, and nests, we do not need to
|
||||
# support all structure types here.
|
||||
raise TypeError(
|
||||
"Could not build a structure for output class %r" % flat_type)
|
||||
|
||||
@ -381,6 +383,13 @@ class TensorStructure(Structure):
|
||||
return self._from_compatible_tensor_list(flat_value)
|
||||
|
||||
def _from_compatible_tensor_list(self, flat_value):
|
||||
# TODO(b/112266545): It would be cleaner to create a new `ensure_shape()`
|
||||
# op here and return that, instead of mutating the input's shape using
|
||||
# `Tensor.set_shape()`. However, that would add extra ops on the arguments
|
||||
# of each `tf.data` function, which could impact performance. When this
|
||||
# bug is resolved, we should be able to add the `ensure_shape()` ops and
|
||||
# optimize them away using contextual shape information.
|
||||
flat_value[0].set_shape(self._shape)
|
||||
return flat_value[0]
|
||||
|
||||
@staticmethod
|
||||
@ -406,7 +415,11 @@ class SparseTensorStructure(Structure):
|
||||
|
||||
@property
|
||||
def _flat_shapes(self):
|
||||
return [tensor_shape.vector(3)]
|
||||
# NOTE(mrry): The default flat shape of a boxed `SparseTensor` is `(3,)`,
|
||||
# but a `SparseTensorStructure` can also represent a batch of boxed
|
||||
# `SparseTensor` objects with shape `(?, 3)` (and batches of batches, etc.),
|
||||
# so the flat shape must be unknown.
|
||||
return [tensor_shape.unknown_shape(None)]
|
||||
|
||||
@property
|
||||
def _flat_types(self):
|
||||
@ -428,8 +441,11 @@ class SparseTensorStructure(Structure):
|
||||
return self._from_compatible_tensor_list(flat_value)
|
||||
|
||||
def _from_compatible_tensor_list(self, flat_value):
|
||||
return sparse_ops.deserialize_sparse(
|
||||
ret = sparse_ops.deserialize_sparse(
|
||||
flat_value[0], dtype=self._dtype, rank=self._dense_shape.ndims)
|
||||
ret.indices.set_shape([None, self._dense_shape.ndims])
|
||||
ret.dense_shape.set_shape([self._dense_shape.ndims])
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def from_value(value):
|
||||
|
@ -44,7 +44,7 @@ class StructureTest(test.TestCase, parameterized.TestCase):
|
||||
[dtypes.float32], [[]]),
|
||||
(lambda: sparse_tensor.SparseTensor(
|
||||
indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
|
||||
structure.SparseTensorStructure, [dtypes.variant], [[3]]),
|
||||
structure.SparseTensorStructure, [dtypes.variant], [None]),
|
||||
(lambda: (constant_op.constant(37.0), constant_op.constant([1, 2, 3])),
|
||||
structure.NestedStructure, [dtypes.float32, dtypes.int32], [[], [3]]),
|
||||
(lambda: {
|
||||
@ -58,14 +58,17 @@ class StructureTest(test.TestCase, parameterized.TestCase):
|
||||
sparse_tensor.SparseTensor(
|
||||
indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
|
||||
}, structure.NestedStructure,
|
||||
[dtypes.float32, dtypes.variant, dtypes.variant], [[], [3], [3]]))
|
||||
[dtypes.float32, dtypes.variant, dtypes.variant], [[], None, None]))
|
||||
def testFlatStructure(self, value_fn, expected_structure, expected_types,
|
||||
expected_shapes):
|
||||
value = value_fn()
|
||||
s = structure.Structure.from_value(value)
|
||||
self.assertIsInstance(s, expected_structure)
|
||||
self.assertEqual(expected_types, s._flat_types)
|
||||
self.assertEqual(expected_shapes, s._flat_shapes)
|
||||
for expected, actual in zip(expected_shapes, s._flat_shapes):
|
||||
self.assertTrue(actual.is_compatible_with(expected))
|
||||
self.assertTrue(
|
||||
tensor_shape.as_shape(expected).is_compatible_with(actual))
|
||||
|
||||
@parameterized.parameters(
|
||||
(lambda: constant_op.constant(37.0), lambda: [
|
||||
|
Loading…
Reference in New Issue
Block a user