[tf.data] Removing NestedStructure
in favor of nested structure of (flat) TypeSpec
s
PiperOrigin-RevId: 254862137
This commit is contained in:
parent
fb6f20cfbd
commit
1d29b5c344
@ -28,7 +28,6 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
|
||||
@@DatasetStructure
|
||||
@@DistributeOptions
|
||||
@@MapVectorizationOptions
|
||||
@@NestedStructure
|
||||
@@OptimizationOptions
|
||||
@@Optional
|
||||
@@OptionalStructure
|
||||
@ -137,7 +136,6 @@ from tensorflow.python.data.ops.dataset_ops import to_variant
|
||||
from tensorflow.python.data.ops.iterator_ops import get_next_as_optional
|
||||
from tensorflow.python.data.ops.optional_ops import Optional
|
||||
from tensorflow.python.data.ops.optional_ops import OptionalStructure
|
||||
from tensorflow.python.data.util.structure import NestedStructure
|
||||
from tensorflow.python.data.util.structure import RaggedTensorStructure
|
||||
from tensorflow.python.data.util.structure import SparseTensorStructure
|
||||
from tensorflow.python.data.util.structure import Structure
|
||||
|
@ -170,164 +170,6 @@ def convert_legacy_structure(output_types, output_shapes, output_classes):
|
||||
return nest.pack_sequence_as(output_classes, flat_ret)
|
||||
|
||||
|
||||
# TODO(b/133606651) Once cl/253160704 is part of tf-nightly, remove the use
|
||||
# of NestedStructure from TFF and remove this class.
|
||||
@tf_export("data.experimental.NestedStructure")
|
||||
class NestedStructure(type_spec.BatchableTypeSpec):
|
||||
"""Represents a nested structure in which each leaf is a `TypeSpec`."""
|
||||
|
||||
# NOTE(edloper): This class makes extensive use of non-public TypeSpec
|
||||
# methods, so we disable the protected-access lint warning once here.
|
||||
# pylint: disable=protected-access
|
||||
|
||||
__slots__ = ["_nested_structure", "_flat_nested_structure",
|
||||
"__flat_tensor_specs"]
|
||||
|
||||
def __init__(self, nested_structure):
|
||||
self._nested_structure = nested_structure
|
||||
self._flat_nested_structure = nest.flatten(nested_structure)
|
||||
self.__flat_tensor_specs = []
|
||||
for s in self._flat_nested_structure:
|
||||
if not isinstance(s, type_spec.TypeSpec):
|
||||
raise TypeError("nested_structure must be a (potentially nested) tuple "
|
||||
"or dictionary of TypeSpec objects.")
|
||||
self.__flat_tensor_specs.extend(s._flat_tensor_specs)
|
||||
|
||||
value_type = property(lambda self: type(self._nested_structure))
|
||||
|
||||
def _serialize(self):
|
||||
return self._nested_structure
|
||||
|
||||
@classmethod
|
||||
def _deserialize(cls, nested_structure):
|
||||
return cls(nested_structure)
|
||||
|
||||
def most_specific_compatible_type(self, other):
|
||||
if type(self) is not type(other):
|
||||
raise ValueError("Incompatible types")
|
||||
return self._deserialize(
|
||||
nest.map_structure(lambda a, b: a.most_specific_compatible_type(b),
|
||||
self._nested_structure, other._nested_structure))
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, NestedStructure):
|
||||
return False
|
||||
try:
|
||||
nest.assert_same_structure(self._nested_structure,
|
||||
other._nested_structure)
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
return (nest.flatten(self._nested_structure) ==
|
||||
nest.flatten(other._nested_structure))
|
||||
|
||||
def __hash__(self):
|
||||
return hash(tuple(nest.flatten(self._nested_structure)))
|
||||
|
||||
def is_compatible_with(self, other):
|
||||
if not isinstance(other, NestedStructure):
|
||||
return False
|
||||
try:
|
||||
nest.assert_same_structure(self._nested_structure,
|
||||
other._nested_structure)
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
|
||||
# pylint: disable=g-complex-comprehension
|
||||
return all(
|
||||
substructure.is_compatible_with(other_substructure)
|
||||
for substructure, other_substructure in zip(
|
||||
nest.flatten(self._nested_structure),
|
||||
nest.flatten(other._nested_structure)))
|
||||
|
||||
_component_specs = property(lambda self: self._nested_structure)
|
||||
_flat_tensor_specs = property(lambda self: self.__flat_tensor_specs)
|
||||
|
||||
def _to_components(self, value):
|
||||
return nest.map_structure_up_to(
|
||||
self._nested_structure, lambda t, v: t._to_components(v),
|
||||
self._nested_structure, value)
|
||||
|
||||
def _from_components(self, value):
|
||||
return nest.map_structure_up_to(
|
||||
self._nested_structure, lambda t, v: t._from_components(v),
|
||||
self._nested_structure, value)
|
||||
|
||||
def _to_tensor_list(self, value):
|
||||
return self.__value_to_tensors(
|
||||
value, lambda struct, val: struct._to_tensor_list(val))
|
||||
|
||||
def _to_batched_tensor_list(self, value):
|
||||
return self.__value_to_tensors(
|
||||
value, lambda struct, val: struct._to_batched_tensor_list(val))
|
||||
|
||||
def __value_to_tensors(self, value, to_tensor_list_fn):
|
||||
ret = []
|
||||
|
||||
try:
|
||||
flat_value = nest.flatten_up_to(self._nested_structure, value)
|
||||
except (ValueError, TypeError):
|
||||
raise ValueError("The value %r is not compatible with the nested "
|
||||
"structure %r." % (value, self._nested_structure))
|
||||
|
||||
for sub_value, structure in zip(flat_value, self._flat_nested_structure):
|
||||
if not structure.is_compatible_with(
|
||||
type_spec.type_spec_from_value(sub_value)):
|
||||
raise ValueError("Component value %r is not compatible with the nested "
|
||||
"structure %r." % (sub_value, structure))
|
||||
ret.extend(to_tensor_list_fn(structure, sub_value))
|
||||
return ret
|
||||
|
||||
def _from_tensor_list(self, value):
|
||||
return self.__tensors_to_value(
|
||||
value, lambda struct, val: struct._from_tensor_list(val))
|
||||
|
||||
def _from_compatible_tensor_list(self, value):
|
||||
return self.__tensors_to_value(
|
||||
value, lambda struct, val: struct._from_compatible_tensor_list(val))
|
||||
|
||||
def __tensors_to_value(self, flat_value, from_tensor_list_fn):
|
||||
if len(flat_value) != len(self._flat_tensor_specs):
|
||||
raise ValueError("Expected %d flat values in NestedStructure but got %d."
|
||||
% (len(self._flat_tensor_specs), len(flat_value)))
|
||||
flat_ret = []
|
||||
i = 0
|
||||
for structure in self._flat_nested_structure:
|
||||
num_flat_values = len(structure._flat_tensor_specs)
|
||||
sub_value = flat_value[i:i + num_flat_values]
|
||||
flat_ret.append(from_tensor_list_fn(structure, sub_value))
|
||||
i += num_flat_values
|
||||
|
||||
return nest.pack_sequence_as(self._nested_structure, flat_ret)
|
||||
|
||||
@staticmethod
|
||||
def from_value(value):
|
||||
flat_nested_structure = [
|
||||
type_spec.type_spec_from_value(sub_value)
|
||||
for sub_value in nest.flatten(value)
|
||||
]
|
||||
return NestedStructure(nest.pack_sequence_as(value, flat_nested_structure))
|
||||
|
||||
def _to_legacy_output_types(self):
|
||||
return nest.map_structure(
|
||||
lambda s: s._to_legacy_output_types(), self._nested_structure)
|
||||
|
||||
def _to_legacy_output_shapes(self):
|
||||
return nest.map_structure(
|
||||
lambda s: s._to_legacy_output_shapes(), self._nested_structure)
|
||||
|
||||
def _to_legacy_output_classes(self):
|
||||
return nest.map_structure(
|
||||
lambda s: s._to_legacy_output_classes(), self._nested_structure)
|
||||
|
||||
def _batch(self, batch_size):
|
||||
return NestedStructure(nest.map_structure(
|
||||
lambda s: s._batch(batch_size), self._nested_structure))
|
||||
|
||||
def _unbatch(self):
|
||||
return NestedStructure(nest.map_structure(
|
||||
lambda s: s._unbatch(), self._nested_structure))
|
||||
|
||||
|
||||
def _from_tensor_list_helper(decode_fn, element_spec, tensor_list):
|
||||
"""Returns an element constructed from the given spec and tensor list.
|
||||
|
||||
|
@ -25,7 +25,6 @@ import numpy as np
|
||||
from tensorflow.python import tf2
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.ops import iterator_ops
|
||||
from tensorflow.python.data.util import structure
|
||||
from tensorflow.python.distribute import distribution_strategy_context
|
||||
from tensorflow.python.distribute import multi_worker_util
|
||||
from tensorflow.python.eager import context
|
||||
@ -2442,10 +2441,6 @@ class Model(network.Network):
|
||||
# code.
|
||||
if isinstance(x, dataset_ops.DatasetV2):
|
||||
x_shapes = dataset_ops.get_structure(x)
|
||||
# TODO(momernick): Remove this once NestedStructure goes away. Right
|
||||
# now, Dataset outputs one of these instead of an actual python structure.
|
||||
if isinstance(x_shapes, structure.NestedStructure):
|
||||
x_shapes = x_shapes._component_specs # pylint: disable=protected-access
|
||||
if isinstance(x_shapes, tuple):
|
||||
# If the output of a Dataset is a tuple, we assume it's either of the
|
||||
# form (x_data, y_data) or (x_data, y_data, sample_weights). In either
|
||||
@ -2460,10 +2455,6 @@ class Model(network.Network):
|
||||
x = nest.pack_sequence_as(x, converted_x, expand_composites=False)
|
||||
x_shapes = nest.map_structure(type_spec.type_spec_from_value, x)
|
||||
|
||||
# If the inputs are still a NestedStructure, then we have a dict-input to
|
||||
# this model. We can't yet validate this. (It's only relevant for feature
|
||||
# columns).
|
||||
if not isinstance(x_shapes, structure.NestedStructure):
|
||||
flat_inputs = nest.flatten(x_shapes, expand_composites=False)
|
||||
flat_expected_inputs = nest.flatten(self.inputs, expand_composites=False)
|
||||
for (a, b) in zip(flat_inputs, flat_expected_inputs):
|
||||
|
@ -1,27 +0,0 @@
|
||||
path: "tensorflow.data.experimental.NestedStructure"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.data.util.structure.NestedStructure\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.type_spec.BatchableTypeSpec\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "value_type"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'nested_structure\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "from_value"
|
||||
argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "is_compatible_with"
|
||||
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "most_specific_compatible_type"
|
||||
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
@ -28,10 +28,6 @@ tf_module {
|
||||
name: "MapVectorizationOptions"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "NestedStructure"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "OptimizationOptions"
|
||||
mtype: "<type \'type\'>"
|
||||
|
@ -1,27 +0,0 @@
|
||||
path: "tensorflow.data.experimental.NestedStructure"
|
||||
tf_class {
|
||||
is_instance: "<class \'tensorflow.python.data.util.structure.NestedStructure\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.type_spec.BatchableTypeSpec\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "value_type"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'nested_structure\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "from_value"
|
||||
argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "is_compatible_with"
|
||||
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "most_specific_compatible_type"
|
||||
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
}
|
@ -28,10 +28,6 @@ tf_module {
|
||||
name: "MapVectorizationOptions"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "NestedStructure"
|
||||
mtype: "<type \'type\'>"
|
||||
}
|
||||
member {
|
||||
name: "OptimizationOptions"
|
||||
mtype: "<type \'type\'>"
|
||||
|
Loading…
Reference in New Issue
Block a user