[tf.data] Removing NestedStructure in favor of nested structure of (flat) TypeSpecs

PiperOrigin-RevId: 254862137
This commit is contained in:
Jiri Simsa 2019-06-24 16:35:30 -07:00 committed by TensorFlower Gardener
parent fb6f20cfbd
commit 1d29b5c344
7 changed files with 4 additions and 235 deletions

View File

@ -28,7 +28,6 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
@@DatasetStructure
@@DistributeOptions
@@MapVectorizationOptions
@@NestedStructure
@@OptimizationOptions
@@Optional
@@OptionalStructure
@ -137,7 +136,6 @@ from tensorflow.python.data.ops.dataset_ops import to_variant
from tensorflow.python.data.ops.iterator_ops import get_next_as_optional
from tensorflow.python.data.ops.optional_ops import Optional
from tensorflow.python.data.ops.optional_ops import OptionalStructure
from tensorflow.python.data.util.structure import NestedStructure
from tensorflow.python.data.util.structure import RaggedTensorStructure
from tensorflow.python.data.util.structure import SparseTensorStructure
from tensorflow.python.data.util.structure import Structure

View File

@ -170,164 +170,6 @@ def convert_legacy_structure(output_types, output_shapes, output_classes):
return nest.pack_sequence_as(output_classes, flat_ret)
# TODO(b/133606651) Once cl/253160704 is part of tf-nightly, remove the use
# of NestedStructure from TFF and remove this class.
@tf_export("data.experimental.NestedStructure")
class NestedStructure(type_spec.BatchableTypeSpec):
"""Represents a nested structure in which each leaf is a `TypeSpec`."""
# NOTE(edloper): This class makes extensive use of non-public TypeSpec
# methods, so we disable the protected-access lint warning once here.
# pylint: disable=protected-access
__slots__ = ["_nested_structure", "_flat_nested_structure",
"__flat_tensor_specs"]
def __init__(self, nested_structure):
self._nested_structure = nested_structure
self._flat_nested_structure = nest.flatten(nested_structure)
self.__flat_tensor_specs = []
for s in self._flat_nested_structure:
if not isinstance(s, type_spec.TypeSpec):
raise TypeError("nested_structure must be a (potentially nested) tuple "
"or dictionary of TypeSpec objects.")
self.__flat_tensor_specs.extend(s._flat_tensor_specs)
value_type = property(lambda self: type(self._nested_structure))
def _serialize(self):
return self._nested_structure
@classmethod
def _deserialize(cls, nested_structure):
return cls(nested_structure)
def most_specific_compatible_type(self, other):
if type(self) is not type(other):
raise ValueError("Incompatible types")
return self._deserialize(
nest.map_structure(lambda a, b: a.most_specific_compatible_type(b),
self._nested_structure, other._nested_structure))
def __eq__(self, other):
if not isinstance(other, NestedStructure):
return False
try:
nest.assert_same_structure(self._nested_structure,
other._nested_structure)
except (ValueError, TypeError):
return False
return (nest.flatten(self._nested_structure) ==
nest.flatten(other._nested_structure))
def __hash__(self):
return hash(tuple(nest.flatten(self._nested_structure)))
def is_compatible_with(self, other):
if not isinstance(other, NestedStructure):
return False
try:
nest.assert_same_structure(self._nested_structure,
other._nested_structure)
except (ValueError, TypeError):
return False
# pylint: disable=g-complex-comprehension
return all(
substructure.is_compatible_with(other_substructure)
for substructure, other_substructure in zip(
nest.flatten(self._nested_structure),
nest.flatten(other._nested_structure)))
_component_specs = property(lambda self: self._nested_structure)
_flat_tensor_specs = property(lambda self: self.__flat_tensor_specs)
def _to_components(self, value):
return nest.map_structure_up_to(
self._nested_structure, lambda t, v: t._to_components(v),
self._nested_structure, value)
def _from_components(self, value):
return nest.map_structure_up_to(
self._nested_structure, lambda t, v: t._from_components(v),
self._nested_structure, value)
def _to_tensor_list(self, value):
return self.__value_to_tensors(
value, lambda struct, val: struct._to_tensor_list(val))
def _to_batched_tensor_list(self, value):
return self.__value_to_tensors(
value, lambda struct, val: struct._to_batched_tensor_list(val))
def __value_to_tensors(self, value, to_tensor_list_fn):
ret = []
try:
flat_value = nest.flatten_up_to(self._nested_structure, value)
except (ValueError, TypeError):
raise ValueError("The value %r is not compatible with the nested "
"structure %r." % (value, self._nested_structure))
for sub_value, structure in zip(flat_value, self._flat_nested_structure):
if not structure.is_compatible_with(
type_spec.type_spec_from_value(sub_value)):
raise ValueError("Component value %r is not compatible with the nested "
"structure %r." % (sub_value, structure))
ret.extend(to_tensor_list_fn(structure, sub_value))
return ret
def _from_tensor_list(self, value):
return self.__tensors_to_value(
value, lambda struct, val: struct._from_tensor_list(val))
def _from_compatible_tensor_list(self, value):
return self.__tensors_to_value(
value, lambda struct, val: struct._from_compatible_tensor_list(val))
def __tensors_to_value(self, flat_value, from_tensor_list_fn):
if len(flat_value) != len(self._flat_tensor_specs):
raise ValueError("Expected %d flat values in NestedStructure but got %d."
% (len(self._flat_tensor_specs), len(flat_value)))
flat_ret = []
i = 0
for structure in self._flat_nested_structure:
num_flat_values = len(structure._flat_tensor_specs)
sub_value = flat_value[i:i + num_flat_values]
flat_ret.append(from_tensor_list_fn(structure, sub_value))
i += num_flat_values
return nest.pack_sequence_as(self._nested_structure, flat_ret)
@staticmethod
def from_value(value):
flat_nested_structure = [
type_spec.type_spec_from_value(sub_value)
for sub_value in nest.flatten(value)
]
return NestedStructure(nest.pack_sequence_as(value, flat_nested_structure))
def _to_legacy_output_types(self):
return nest.map_structure(
lambda s: s._to_legacy_output_types(), self._nested_structure)
def _to_legacy_output_shapes(self):
return nest.map_structure(
lambda s: s._to_legacy_output_shapes(), self._nested_structure)
def _to_legacy_output_classes(self):
return nest.map_structure(
lambda s: s._to_legacy_output_classes(), self._nested_structure)
def _batch(self, batch_size):
return NestedStructure(nest.map_structure(
lambda s: s._batch(batch_size), self._nested_structure))
def _unbatch(self):
return NestedStructure(nest.map_structure(
lambda s: s._unbatch(), self._nested_structure))
def _from_tensor_list_helper(decode_fn, element_spec, tensor_list):
"""Returns an element constructed from the given spec and tensor list.

View File

@ -25,7 +25,6 @@ import numpy as np
from tensorflow.python import tf2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.util import structure
from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.eager import context
@ -2442,10 +2441,6 @@ class Model(network.Network):
# code.
if isinstance(x, dataset_ops.DatasetV2):
x_shapes = dataset_ops.get_structure(x)
# TODO(momernick): Remove this once NestedStructure goes away. Right
# now, Dataset outputs one of these instead of an actual python structure.
if isinstance(x_shapes, structure.NestedStructure):
x_shapes = x_shapes._component_specs # pylint: disable=protected-access
if isinstance(x_shapes, tuple):
# If the output of a Dataset is a tuple, we assume it's either of the
# form (x_data, y_data) or (x_data, y_data, sample_weights). In either
@ -2460,10 +2455,6 @@ class Model(network.Network):
x = nest.pack_sequence_as(x, converted_x, expand_composites=False)
x_shapes = nest.map_structure(type_spec.type_spec_from_value, x)
# If the inputs are still a NestedStructure, then we have a dict-input to
# this model. We can't yet validate this. (It's only relevant for feature
# columns).
if not isinstance(x_shapes, structure.NestedStructure):
flat_inputs = nest.flatten(x_shapes, expand_composites=False)
flat_expected_inputs = nest.flatten(self.inputs, expand_composites=False)
for (a, b) in zip(flat_inputs, flat_expected_inputs):

View File

@ -1,27 +0,0 @@
path: "tensorflow.data.experimental.NestedStructure"
tf_class {
is_instance: "<class \'tensorflow.python.data.util.structure.NestedStructure\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.BatchableTypeSpec\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
is_instance: "<type \'object\'>"
member {
name: "value_type"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'nested_structure\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_value"
argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "is_compatible_with"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "most_specific_compatible_type"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -28,10 +28,6 @@ tf_module {
name: "MapVectorizationOptions"
mtype: "<type \'type\'>"
}
member {
name: "NestedStructure"
mtype: "<type \'type\'>"
}
member {
name: "OptimizationOptions"
mtype: "<type \'type\'>"

View File

@ -1,27 +0,0 @@
path: "tensorflow.data.experimental.NestedStructure"
tf_class {
is_instance: "<class \'tensorflow.python.data.util.structure.NestedStructure\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.BatchableTypeSpec\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
is_instance: "<type \'object\'>"
member {
name: "value_type"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'nested_structure\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_value"
argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "is_compatible_with"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "most_specific_compatible_type"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -28,10 +28,6 @@ tf_module {
name: "MapVectorizationOptions"
mtype: "<type \'type\'>"
}
member {
name: "NestedStructure"
mtype: "<type \'type\'>"
}
member {
name: "OptimizationOptions"
mtype: "<type \'type\'>"