Automated rollback of commit fd46ffb7bd

PiperOrigin-RevId: 222421495
This commit is contained in:
Derek Murray 2018-11-21 09:43:16 -08:00 committed by TensorFlower Gardener
parent 80f3b787e4
commit 4f92a46fa8
8 changed files with 202 additions and 188 deletions

View File

@ -165,7 +165,7 @@ py_library(
"//tensorflow/python:tensor_shape",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/util:nest",
"//tensorflow/python/data/util:sparse",
"//tensorflow/python/data/util:structure",
],
)

View File

@ -21,6 +21,7 @@ import numpy as np
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import structure
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@ -448,7 +449,10 @@ class _GroupByWindowDataset(dataset_ops.UnaryDataset):
def _make_reduce_func(self, reduce_func, input_dataset):
"""Make wrapping defun for reduce_func."""
nested_dataset = dataset_ops._NestedDatasetComponent(input_dataset) # pylint: disable=protected-access
nested_dataset = dataset_ops.DatasetStructure(
structure.Structure._from_legacy_structure( # pylint: disable=protected-access
input_dataset.output_types, input_dataset.output_shapes,
input_dataset.output_classes))
wrapped_func = dataset_ops.StructuredFunctionWrapper(
reduce_func,
self._transformation_name(),
@ -456,11 +460,13 @@ class _GroupByWindowDataset(dataset_ops.UnaryDataset):
input_shapes=(tensor_shape.scalar(), nested_dataset),
input_types=(dtypes.int64, nested_dataset))
if not isinstance(
wrapped_func.output_classes, dataset_ops._NestedDatasetComponent): # pylint: disable=protected-access
wrapped_func.output_structure, dataset_ops.DatasetStructure):
raise TypeError("`reduce_func` must return a `Dataset` object.")
self._output_classes = wrapped_func.output_classes.output_classes
self._output_types = wrapped_func.output_types.output_types
self._output_shapes = wrapped_func.output_shapes.output_shapes
# pylint: disable=protected-access
element_structure = wrapped_func.output_structure._element_structure
self._output_classes = element_structure._to_legacy_output_classes()
self._output_types = element_structure._to_legacy_output_types()
self._output_shapes = element_structure._to_legacy_output_shapes()
self._reduce_func = wrapped_func.function
@property

View File

@ -117,8 +117,12 @@ tf_py_test(
"@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:tensor_shape",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:optional_ops",
"//tensorflow/python/data/util:structure",
],
)

View File

@ -24,10 +24,14 @@ import numpy as np
from tensorflow.core.framework import graph_pb2
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import optional_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import structure
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import test
@ -249,6 +253,63 @@ class DatasetOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertTrue(ds.options().experimental_autotune)
self.assertTrue(ds.options().experimental_filter_fusion)
# pylint: disable=g-long-lambda
@parameterized.named_parameters(
    ("Tensor", lambda: constant_op.constant(37.0),
     structure.TensorStructure(dtypes.float32, [])),
    ("SparseTensor", lambda: sparse_tensor.SparseTensor(
        indices=[[0]], values=constant_op.constant([0], dtype=dtypes.int32),
        dense_shape=[1]),
     structure.SparseTensorStructure(dtypes.int32, [1])),
    ("Nest", lambda: {
        "a": constant_op.constant(37.0),
        "b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))},
     structure.NestedStructure({
         "a": structure.TensorStructure(dtypes.float32, []),
         "b": (structure.TensorStructure(dtypes.string, [1]),
               structure.TensorStructure(dtypes.string, []))})),
    ("Dataset", lambda: dataset_ops.Dataset.from_tensor_slices(
        constant_op.constant([1, 2, 3])),
     dataset_ops.DatasetStructure(
         structure.TensorStructure(dtypes.int32, []))),
    ("Optional", lambda: optional_ops.Optional.from_value(37.0),
     optional_ops.OptionalStructure(
         structure.TensorStructure(dtypes.float32, []))),
)
def testDatasetStructure(self, tf_value_fn, expected_element_structure):
  """Checks the `DatasetStructure` inferred for a dataset of `tf_value_fn()`.

  Args:
    tf_value_fn: Zero-argument callable producing the value that becomes the
      single element of the dataset under test.
    expected_element_structure: The `Structure` that the dataset's element
      structure must be mutually compatible with.
  """
  dataset = dataset_ops.Dataset.from_tensors(0).map(lambda _: tf_value_fn())
  dataset_structure = structure.Structure.from_value(dataset)
  self.assertIsInstance(dataset_structure, dataset_ops.DatasetStructure)

  # TODO(b/110122868): Add a public API to `tf.data.Dataset` for accessing
  # the element structure.
  actual_element_structure = dataset_structure._element_structure
  self.assertTrue(
      expected_element_structure.is_compatible_with(actual_element_structure))
  self.assertTrue(
      actual_element_structure.is_compatible_with(expected_element_structure))

  # A dataset always flattens to exactly one scalar variant tensor.
  self.assertEqual([dtypes.variant], dataset_structure._flat_types)
  self.assertEqual([tensor_shape.scalar()], dataset_structure._flat_shapes)

  # Assert that the `Dataset` survives a round-trip via _from_tensor_list()
  # and _to_tensor_list().
  flat_tensors = dataset_structure._to_tensor_list(dataset)
  round_trip_dataset = dataset_structure._from_tensor_list(flat_tensors)

  value = tf_value_fn()
  if isinstance(value, dataset_ops.Dataset):
    # NOTE(review): this branch compares against `dataset` rather than
    # `round_trip_dataset` -- presumably because `assertDatasetsEqual` cannot
    # request iterator initialization; confirm before changing.
    self.assertDatasetsEqual(value, dataset.flat_map(lambda x: x))
    return

  if isinstance(value, optional_ops.Optional):
    # Optionals are unwrapped on both sides before comparing.
    produced = round_trip_dataset.map(lambda opt: opt.get_value())
    expected_elements = [self.evaluate(value.get_value())]
  else:
    produced = round_trip_dataset
    expected_elements = [self.evaluate(tf_value_fn())]
  self.assertDatasetProduces(
      produced, expected_elements, requires_initialization=True)
if __name__ == "__main__":
test.main()

View File

@ -31,6 +31,7 @@ from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import random_seed
from tensorflow.python.data.util import sparse
from tensorflow.python.data.util import structure as structure_lib
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@ -1868,57 +1869,6 @@ class SparseTensorSliceDataset(DatasetSource):
return (dtypes.int64, self._sparse_tensor.dtype, dtypes.int64)
class _NestedDatasetComponent(object):
  """Structure descriptor for a `Dataset` nested inside another `Dataset`.

  When a function wrapped by `StructuredFunctionWrapper` returns a `Dataset` in
  one of its components, an instance of this class occupies the matching
  position of the wrapper's `output_classes`, `output_shapes`, and
  `output_types` properties.

  TODO(b/110122868): Add this class, or something equivalent, to the public API.
  We are considering revising the public API for accessing Dataset structure
  (`output_classes` etc.) based on experience with nested datasets and other
  custom component types.
  """

  def __init__(self,
               dataset=None,
               output_shapes=None,
               output_types=None,
               output_classes=None):
    """Builds the descriptor from a dataset or from explicit properties.

    Args:
      dataset: (Optional.) An object exposing `output_classes`,
        `output_shapes`, and `output_types`. Mutually exclusive with the other
        arguments.
      output_shapes: (Optional.) Shapes of the nested dataset's elements.
      output_types: (Optional.) Types of the nested dataset's elements.
      output_classes: (Optional.) Classes of the nested dataset's elements.

    Raises:
      ValueError: If `dataset` is omitted and any of the three properties is
        missing, or if `dataset` is given together with any of them.
    """
    message = ("Either `dataset`, or all of `output_classes`, "
               "`output_shapes`, and `output_types` must be specified.")
    explicit_parts = (output_classes, output_shapes, output_types)
    if dataset is None:
      # Without a dataset, every property has to be supplied by the caller.
      if any(part is None for part in explicit_parts):
        raise ValueError(message)
    else:
      # With a dataset, the properties are read from it and may not be
      # overridden.
      if any(part is not None for part in explicit_parts):
        raise ValueError(message)
      output_classes = dataset.output_classes
      output_shapes = dataset.output_shapes
      output_types = dataset.output_types
    self._output_classes = output_classes
    self._output_shapes = output_shapes
    self._output_types = output_types

  @property
  def output_classes(self):
    """The classes of the nested dataset's elements."""
    return self._output_classes

  @property
  def output_shapes(self):
    """The shapes of the nested dataset's elements."""
    return self._output_shapes

  @property
  def output_types(self):
    """The types of the nested dataset's elements."""
    return self._output_types
class _VariantDataset(DatasetV2):
"""A Dataset wrapper around a `tf.variant`-typed function argument."""
@ -1935,15 +1885,73 @@ class _VariantDataset(DatasetV2):
@property
def output_classes(self):
return self._structure.output_classes
return self._structure._to_legacy_output_classes() # pylint: disable=protected-access
@property
def output_shapes(self):
return self._structure.output_shapes
return self._structure._to_legacy_output_shapes() # pylint: disable=protected-access
@property
def output_types(self):
return self._structure.output_types
return self._structure._to_legacy_output_types() # pylint: disable=protected-access
class DatasetStructure(structure_lib.Structure):
  """The `Structure` of a `Dataset`, described by its element structure."""

  def __init__(self, element_structure):
    """Creates a structure for datasets whose elements have the given structure.

    Args:
      element_structure: The structure object describing each element of the
        represented dataset.
    """
    self._element_structure = element_structure

  @property
  def _flat_shapes(self):
    """A dataset flattens to one tensor whose shape is scalar."""
    return [tensor_shape.scalar()]

  @property
  def _flat_types(self):
    """A dataset flattens to one tensor of dtype `variant`."""
    return [dtypes.variant]

  def is_compatible_with(self, other):
    """Returns whether `other` is a dataset structure with compatible elements."""
    if not isinstance(other, DatasetStructure):
      return False
    # pylint: disable=protected-access
    return self._element_structure.is_compatible_with(
        other._element_structure)

  def _to_tensor_list(self, value):
    """Flattens `value` into a one-element list holding its variant tensor."""
    variant_tensor = value._as_variant_tensor()  # pylint: disable=protected-access
    return [variant_tensor]

  def _from_tensor_list(self, flat_value):
    """Validates `flat_value` and rebuilds a dataset from it.

    Raises:
      ValueError: If `flat_value` is not exactly one `variant` tensor whose
        shape is compatible with a scalar.
    """
    is_variant_scalar = (
        len(flat_value) == 1 and
        flat_value[0].dtype == dtypes.variant and
        flat_value[0].shape.is_compatible_with(tensor_shape.scalar()))
    if not is_variant_scalar:
      raise ValueError(
          "DatasetStructure corresponds to a single tf.variant scalar.")
    return self._from_compatible_tensor_list(flat_value)

  def _from_compatible_tensor_list(self, flat_value):
    """Wraps the first tensor of `flat_value` in a `_VariantDataset`."""
    variant_tensor = flat_value[0]
    # pylint: disable=protected-access
    return _VariantDataset(variant_tensor, self._element_structure)

  @staticmethod
  def from_value(value):
    """Builds a `DatasetStructure` from the legacy properties of `value`."""
    # TODO(b/110122868): We can simplify this when a `Dataset` object has a
    # `Structure`-valued property.
    legacy_to_structure = structure_lib.Structure._from_legacy_structure
    return DatasetStructure(
        legacy_to_structure(
            value.output_types, value.output_shapes, value.output_classes))

  # In the legacy `output_*` nests a nested dataset is represented by this
  # structure object itself, so all three conversions return `self`.
  def _to_legacy_output_types(self):
    return self

  def _to_legacy_output_shapes(self):
    return self

  def _to_legacy_output_classes(self):
    return self
# pylint: disable=protected-access
structure_lib.Structure._register_custom_converter(DatasetV2,
DatasetStructure.from_value)
# pylint: enable=protected-access
class StructuredFunctionWrapper(object):
@ -2001,6 +2009,9 @@ class StructuredFunctionWrapper(object):
self._input_types = dataset.output_types
self._input_classes = dataset.output_classes
self._input_structure = structure_lib.Structure._from_legacy_structure( # pylint: disable=protected-access
self._input_types, self._input_shapes, self._input_classes)
self._transformation_name = transformation_name
readable_transformation_name = transformation_name.replace(
".", "_")[:-2] if len(transformation_name) > 2 else ""
@ -2008,35 +2019,18 @@ class StructuredFunctionWrapper(object):
readable_transformation_name,
function_utils.get_func_name(func),
str(ops.uid())
])
if defun_kwargs is None:
defun_kwargs = {}
@function.Defun(
*self._defun_args(), func_name=self._func_name, **defun_kwargs)
*self._input_structure._flat_types, func_name=self._func_name, # pylint: disable=protected-access
**defun_kwargs)
def tf_data_structured_function_wrapper(*args):
"""Wrapper for passing nested structures to and from tf.data functions."""
flat_args = []
for arg, arg_class, arg_shape, arg_type in zip(
args,
nest.flatten(self._input_classes),
nest.flatten(self._input_shapes),
nest.flatten(self._input_types)):
# TODO(b/110122868): Add a registration mechanism for new component
# types.
if arg_class is sparse_tensor_lib.SparseTensor:
arg = sparse.deserialize_sparse_tensors(
arg, arg_type, arg_shape, arg_class)
arg.indices.set_shape([None, arg_shape.ndims])
arg.dense_shape.set_shape([arg_shape.ndims])
elif isinstance(arg_class, _NestedDatasetComponent):
arg = _VariantDataset(arg, arg_class)
else:
arg.set_shape(arg_shape)
flat_args.append(arg)
nested_args = nest.pack_sequence_as(self._input_classes, flat_args)
# pylint: disable=protected-access
nested_args = self._input_structure._from_compatible_tensor_list(args)
if not _should_unpack_args(nested_args):
nested_args = (nested_args,)
@ -2054,50 +2048,14 @@ class StructuredFunctionWrapper(object):
if isinstance(ret, list):
ret = tuple(ret)
# Convert any `SparseTensorValue`s to `SparseTensor`s and all other
# values to tensors.
flat_ret = []
flat_classes = []
flat_shapes = []
flat_types = []
for t in nest.flatten(ret):
# TODO(b/110122868): Add a registration mechanism for new component
# types.
if sparse_tensor_lib.is_sparse(t):
t = sparse_tensor_lib.SparseTensor.from_value(t)
flat_ret.append(sparse.serialize_sparse_tensors(t))
flat_classes.append(sparse_tensor_lib.SparseTensor)
flat_shapes.append(t.get_shape())
flat_types.append(t.dtype)
elif isinstance(t, DatasetV2):
flat_ret.append(t._as_variant_tensor()) # pylint: disable=protected-access
component = _NestedDatasetComponent(t)
flat_classes.append(component)
flat_shapes.append(component)
flat_types.append(component)
if t.options() != Options():
warnings.warn("Encountered a nested dataset with non-default "
"options. These options will not be propagated to "
"the outer dataset.")
else:
try:
t = ops.convert_to_tensor(t)
except (ValueError, TypeError):
raise TypeError("Unsupported return value from function passed to "
"%s: %s." % (transformation_name, t))
flat_ret.append(t)
flat_classes.append(ops.Tensor)
flat_shapes.append(t.get_shape())
flat_types.append(t.dtype)
ret = nest.pack_sequence_as(ret, flat_ret)
self._output_classes = nest.pack_sequence_as(ret, flat_classes)
self._output_shapes = nest.pack_sequence_as(ret, flat_shapes)
self._output_types = nest.pack_sequence_as(ret, flat_types)
try:
self._output_structure = structure_lib.Structure.from_value(ret)
except (ValueError, TypeError):
raise TypeError("Unsupported return value from function passed to "
"%s: %s." % (transformation_name, ret))
_warn_if_collections(transformation_name)
return flat_ret
return self._output_structure._to_tensor_list(ret)
self._function = tf_data_structured_function_wrapper
if add_to_graph:
@ -2108,32 +2066,21 @@ class StructuredFunctionWrapper(object):
# in case (e.g.) we need to rerun the function.
self._function._create_definition_if_needed() # pylint: disable=protected-access
def _defun_args(self):
  """Computes the flattened argument dtypes for the wrapped function.

  Returns:
    A list with one `tf.DType` per flattened input component. Components whose
    class is `SparseTensor` or a `_NestedDatasetComponent` are represented as
    `variant`; every other component keeps its own input type.
  """
  flat_dtypes = []
  type_class_pairs = zip(nest.flatten(self._input_types),
                         nest.flatten(self._input_classes))
  for component_type, component_class in type_class_pairs:
    # TODO(b/110122868): Add a registration mechanism for new component types.
    passed_as_variant = (
        component_class is sparse_tensor_lib.SparseTensor or
        isinstance(component_class, _NestedDatasetComponent))
    if passed_as_variant:
      flat_dtypes.append(dtypes.variant)
      continue
    assert isinstance(component_type, dtypes.DType)
    flat_dtypes.append(component_type)
  return flat_dtypes
@property
def output_structure(self):
  """The structure computed from the wrapped function's return value."""
  return self._output_structure
@property
def output_classes(self):
return self._output_classes
return self._output_structure._to_legacy_output_classes() # pylint: disable=protected-access
@property
def output_shapes(self):
return self._output_shapes
return self._output_structure._to_legacy_output_shapes() # pylint: disable=protected-access
@property
def output_types(self):
return self._output_types
return self._output_structure._to_legacy_output_types() # pylint: disable=protected-access
@property
def function(self):
@ -2156,30 +2103,12 @@ def flat_structure(dataset):
A dictionary of keyword arguments that can be passed to many Dataset op
constructors.
"""
output_classes = []
output_shapes = []
output_types = []
for output_class, output_shape, output_type in zip(
nest.flatten(dataset.output_classes), nest.flatten(dataset.output_shapes),
nest.flatten(dataset.output_types)):
if isinstance(output_class, _NestedDatasetComponent):
output_classes.append(output_class.output_classes)
output_shapes.append(output_shape.output_shapes)
output_types.append(output_type.output_types)
else:
output_classes.append(output_class)
output_shapes.append(output_shape)
output_types.append(output_type)
output_classes = nest.pack_sequence_as(dataset.output_classes, output_classes)
output_shapes = nest.pack_sequence_as(dataset.output_shapes, output_shapes)
output_types = nest.pack_sequence_as(dataset.output_types, output_types)
# pylint: disable=protected-access
structure = structure_lib.Structure._from_legacy_structure(
dataset.output_types, dataset.output_shapes, dataset.output_classes)
return {
"output_shapes":
nest.flatten(sparse.as_dense_shapes(output_shapes, output_classes)),
"output_types":
nest.flatten(sparse.as_dense_types(output_types, output_classes)),
"output_shapes": structure._flat_shapes,
"output_types": structure._flat_types,
}
@ -2902,11 +2831,13 @@ class FlatMapDataset(UnaryDataset):
wrapped_func = StructuredFunctionWrapper(
map_func, self._transformation_name(), dataset=input_dataset)
if not isinstance(wrapped_func.output_classes, _NestedDatasetComponent):
if not isinstance(wrapped_func.output_structure, DatasetStructure):
raise TypeError("`map_func` must return a `Dataset` object.")
self._output_classes = wrapped_func.output_classes.output_classes
self._output_types = wrapped_func.output_types.output_types
self._output_shapes = wrapped_func.output_shapes.output_shapes
# pylint: disable=protected-access
element_structure = wrapped_func.output_structure._element_structure
self._output_classes = element_structure._to_legacy_output_classes()
self._output_types = element_structure._to_legacy_output_types()
self._output_shapes = element_structure._to_legacy_output_shapes()
self._map_func = wrapped_func.function
def _as_variant_tensor(self):
@ -3048,10 +2979,9 @@ class WindowDataset(UnaryDataset):
self._output_classes = nest.pack_sequence_as(
input_dataset.output_classes,
[
_NestedDatasetComponent( # pylint: disable=protected-access
output_classes=output_class,
output_shapes=output_shape,
output_types=output_type)
DatasetStructure(
structure_lib.Structure._from_legacy_structure( # pylint: disable=protected-access
output_type, output_shape, output_class))
for output_class, output_shape, output_type in zip(
nest.flatten(input_dataset.output_classes),
nest.flatten(input_dataset.output_shapes),

View File

@ -183,19 +183,13 @@ class OptionalStructure(structure.Structure):
return OptionalStructure(value.value_structure)
def _to_legacy_output_types(self):
raise NotImplementedError("The `output_types` property is not supported on "
"structured objects containing an `Optional`. "
"Use the corresponding `structure` property.")
return self
def _to_legacy_output_shapes(self):
raise NotImplementedError("The `output_shapes` property is not supported on"
" structured objects containing an `Optional`. "
"Use the corresponding `structure` property.")
return self
def _to_legacy_output_classes(self):
raise NotImplementedError("The `output_classes` property is not supported "
"on structured objects containing an `Optional`. "
"Use the corresponding `structure` property.")
return self
# pylint: disable=protected-access

View File

@ -208,14 +208,16 @@ class Structure(object):
flat_ret = []
for flat_type, flat_shape, flat_class in zip(flat_types, flat_shapes,
flat_classes):
if issubclass(flat_class, sparse_tensor_lib.SparseTensor):
if isinstance(flat_class, Structure):
flat_ret.append(flat_class)
elif issubclass(flat_class, sparse_tensor_lib.SparseTensor):
flat_ret.append(SparseTensorStructure(flat_type, flat_shape))
elif issubclass(flat_class, ops.Tensor):
flat_ret.append(TensorStructure(flat_type, flat_shape))
else:
# NOTE(mrry): Since legacy structures produced by iterators only
# comprise Tensors, SparseTensors, and nests, we do not need to support
# all structure types here.
# comprise Tensors, SparseTensors, and nests, we do not need to
# support all structure types here.
raise TypeError(
"Could not build a structure for output class %r" % flat_type)
@ -381,6 +383,13 @@ class TensorStructure(Structure):
return self._from_compatible_tensor_list(flat_value)
def _from_compatible_tensor_list(self, flat_value):
  """Returns the first tensor of `flat_value` after applying this shape.

  The tensor's static shape is updated in place with `self._shape`; no new op
  is created.
  """
  # TODO(b/112266545): It would be cleaner to create a new `ensure_shape()`
  # op here and return that, instead of mutating the input's shape using
  # `Tensor.set_shape()`. However, that would add extra ops on the arguments
  # of each `tf.data` function, which could impact performance. When this
  # bug is resolved, we should be able to add the `ensure_shape()` ops and
  # optimize them away using contextual shape information.
  tensor = flat_value[0]
  tensor.set_shape(self._shape)
  return tensor
@staticmethod
@ -406,7 +415,11 @@ class SparseTensorStructure(Structure):
@property
def _flat_shapes(self):
return [tensor_shape.vector(3)]
# NOTE(mrry): The default flat shape of a boxed `SparseTensor` is `(3,)`,
# but a `SparseTensorStructure` can also represent a batch of boxed
# `SparseTensor` objects with shape `(?, 3)` (and batches of batches, etc.),
# so the flat shape must be unknown.
return [tensor_shape.unknown_shape(None)]
@property
def _flat_types(self):
@ -428,8 +441,11 @@ class SparseTensorStructure(Structure):
return self._from_compatible_tensor_list(flat_value)
def _from_compatible_tensor_list(self, flat_value):
return sparse_ops.deserialize_sparse(
ret = sparse_ops.deserialize_sparse(
flat_value[0], dtype=self._dtype, rank=self._dense_shape.ndims)
ret.indices.set_shape([None, self._dense_shape.ndims])
ret.dense_shape.set_shape([self._dense_shape.ndims])
return ret
@staticmethod
def from_value(value):

View File

@ -44,7 +44,7 @@ class StructureTest(test.TestCase, parameterized.TestCase):
[dtypes.float32], [[]]),
(lambda: sparse_tensor.SparseTensor(
indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
structure.SparseTensorStructure, [dtypes.variant], [[3]]),
structure.SparseTensorStructure, [dtypes.variant], [None]),
(lambda: (constant_op.constant(37.0), constant_op.constant([1, 2, 3])),
structure.NestedStructure, [dtypes.float32, dtypes.int32], [[], [3]]),
(lambda: {
@ -58,14 +58,17 @@ class StructureTest(test.TestCase, parameterized.TestCase):
sparse_tensor.SparseTensor(
indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
}, structure.NestedStructure,
[dtypes.float32, dtypes.variant, dtypes.variant], [[], [3], [3]]))
[dtypes.float32, dtypes.variant, dtypes.variant], [[], None, None]))
def testFlatStructure(self, value_fn, expected_structure, expected_types,
expected_shapes):
value = value_fn()
s = structure.Structure.from_value(value)
self.assertIsInstance(s, expected_structure)
self.assertEqual(expected_types, s._flat_types)
self.assertEqual(expected_shapes, s._flat_shapes)
for expected, actual in zip(expected_shapes, s._flat_shapes):
self.assertTrue(actual.is_compatible_with(expected))
self.assertTrue(
tensor_shape.as_shape(expected).is_compatible_with(actual))
@parameterized.parameters(
(lambda: constant_op.constant(37.0), lambda: [