[tf.data] Removing NestedStructure
in favor of nested structure of (flat) TypeSpec
s
PiperOrigin-RevId: 254862137
This commit is contained in:
parent
fb6f20cfbd
commit
1d29b5c344
@ -28,7 +28,6 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
|
|||||||
@@DatasetStructure
|
@@DatasetStructure
|
||||||
@@DistributeOptions
|
@@DistributeOptions
|
||||||
@@MapVectorizationOptions
|
@@MapVectorizationOptions
|
||||||
@@NestedStructure
|
|
||||||
@@OptimizationOptions
|
@@OptimizationOptions
|
||||||
@@Optional
|
@@Optional
|
||||||
@@OptionalStructure
|
@@OptionalStructure
|
||||||
@ -137,7 +136,6 @@ from tensorflow.python.data.ops.dataset_ops import to_variant
|
|||||||
from tensorflow.python.data.ops.iterator_ops import get_next_as_optional
|
from tensorflow.python.data.ops.iterator_ops import get_next_as_optional
|
||||||
from tensorflow.python.data.ops.optional_ops import Optional
|
from tensorflow.python.data.ops.optional_ops import Optional
|
||||||
from tensorflow.python.data.ops.optional_ops import OptionalStructure
|
from tensorflow.python.data.ops.optional_ops import OptionalStructure
|
||||||
from tensorflow.python.data.util.structure import NestedStructure
|
|
||||||
from tensorflow.python.data.util.structure import RaggedTensorStructure
|
from tensorflow.python.data.util.structure import RaggedTensorStructure
|
||||||
from tensorflow.python.data.util.structure import SparseTensorStructure
|
from tensorflow.python.data.util.structure import SparseTensorStructure
|
||||||
from tensorflow.python.data.util.structure import Structure
|
from tensorflow.python.data.util.structure import Structure
|
||||||
|
@ -170,164 +170,6 @@ def convert_legacy_structure(output_types, output_shapes, output_classes):
|
|||||||
return nest.pack_sequence_as(output_classes, flat_ret)
|
return nest.pack_sequence_as(output_classes, flat_ret)
|
||||||
|
|
||||||
|
|
||||||
# TODO(b/133606651) Once cl/253160704 is part of tf-nightly, remove the use
|
|
||||||
# of NestedStructure from TFF and remove this class.
|
|
||||||
@tf_export("data.experimental.NestedStructure")
|
|
||||||
class NestedStructure(type_spec.BatchableTypeSpec):
|
|
||||||
"""Represents a nested structure in which each leaf is a `TypeSpec`."""
|
|
||||||
|
|
||||||
# NOTE(edloper): This class makes extensive use of non-public TypeSpec
|
|
||||||
# methods, so we disable the protected-access lint warning once here.
|
|
||||||
# pylint: disable=protected-access
|
|
||||||
|
|
||||||
__slots__ = ["_nested_structure", "_flat_nested_structure",
|
|
||||||
"__flat_tensor_specs"]
|
|
||||||
|
|
||||||
def __init__(self, nested_structure):
|
|
||||||
self._nested_structure = nested_structure
|
|
||||||
self._flat_nested_structure = nest.flatten(nested_structure)
|
|
||||||
self.__flat_tensor_specs = []
|
|
||||||
for s in self._flat_nested_structure:
|
|
||||||
if not isinstance(s, type_spec.TypeSpec):
|
|
||||||
raise TypeError("nested_structure must be a (potentially nested) tuple "
|
|
||||||
"or dictionary of TypeSpec objects.")
|
|
||||||
self.__flat_tensor_specs.extend(s._flat_tensor_specs)
|
|
||||||
|
|
||||||
value_type = property(lambda self: type(self._nested_structure))
|
|
||||||
|
|
||||||
def _serialize(self):
|
|
||||||
return self._nested_structure
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _deserialize(cls, nested_structure):
|
|
||||||
return cls(nested_structure)
|
|
||||||
|
|
||||||
def most_specific_compatible_type(self, other):
|
|
||||||
if type(self) is not type(other):
|
|
||||||
raise ValueError("Incompatible types")
|
|
||||||
return self._deserialize(
|
|
||||||
nest.map_structure(lambda a, b: a.most_specific_compatible_type(b),
|
|
||||||
self._nested_structure, other._nested_structure))
|
|
||||||
|
|
||||||
def __eq__(self, other):
|
|
||||||
if not isinstance(other, NestedStructure):
|
|
||||||
return False
|
|
||||||
try:
|
|
||||||
nest.assert_same_structure(self._nested_structure,
|
|
||||||
other._nested_structure)
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
return False
|
|
||||||
return (nest.flatten(self._nested_structure) ==
|
|
||||||
nest.flatten(other._nested_structure))
|
|
||||||
|
|
||||||
def __hash__(self):
|
|
||||||
return hash(tuple(nest.flatten(self._nested_structure)))
|
|
||||||
|
|
||||||
def is_compatible_with(self, other):
|
|
||||||
if not isinstance(other, NestedStructure):
|
|
||||||
return False
|
|
||||||
try:
|
|
||||||
nest.assert_same_structure(self._nested_structure,
|
|
||||||
other._nested_structure)
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
return False
|
|
||||||
|
|
||||||
# pylint: disable=g-complex-comprehension
|
|
||||||
return all(
|
|
||||||
substructure.is_compatible_with(other_substructure)
|
|
||||||
for substructure, other_substructure in zip(
|
|
||||||
nest.flatten(self._nested_structure),
|
|
||||||
nest.flatten(other._nested_structure)))
|
|
||||||
|
|
||||||
_component_specs = property(lambda self: self._nested_structure)
|
|
||||||
_flat_tensor_specs = property(lambda self: self.__flat_tensor_specs)
|
|
||||||
|
|
||||||
def _to_components(self, value):
|
|
||||||
return nest.map_structure_up_to(
|
|
||||||
self._nested_structure, lambda t, v: t._to_components(v),
|
|
||||||
self._nested_structure, value)
|
|
||||||
|
|
||||||
def _from_components(self, value):
|
|
||||||
return nest.map_structure_up_to(
|
|
||||||
self._nested_structure, lambda t, v: t._from_components(v),
|
|
||||||
self._nested_structure, value)
|
|
||||||
|
|
||||||
def _to_tensor_list(self, value):
|
|
||||||
return self.__value_to_tensors(
|
|
||||||
value, lambda struct, val: struct._to_tensor_list(val))
|
|
||||||
|
|
||||||
def _to_batched_tensor_list(self, value):
|
|
||||||
return self.__value_to_tensors(
|
|
||||||
value, lambda struct, val: struct._to_batched_tensor_list(val))
|
|
||||||
|
|
||||||
def __value_to_tensors(self, value, to_tensor_list_fn):
|
|
||||||
ret = []
|
|
||||||
|
|
||||||
try:
|
|
||||||
flat_value = nest.flatten_up_to(self._nested_structure, value)
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
raise ValueError("The value %r is not compatible with the nested "
|
|
||||||
"structure %r." % (value, self._nested_structure))
|
|
||||||
|
|
||||||
for sub_value, structure in zip(flat_value, self._flat_nested_structure):
|
|
||||||
if not structure.is_compatible_with(
|
|
||||||
type_spec.type_spec_from_value(sub_value)):
|
|
||||||
raise ValueError("Component value %r is not compatible with the nested "
|
|
||||||
"structure %r." % (sub_value, structure))
|
|
||||||
ret.extend(to_tensor_list_fn(structure, sub_value))
|
|
||||||
return ret
|
|
||||||
|
|
||||||
def _from_tensor_list(self, value):
|
|
||||||
return self.__tensors_to_value(
|
|
||||||
value, lambda struct, val: struct._from_tensor_list(val))
|
|
||||||
|
|
||||||
def _from_compatible_tensor_list(self, value):
|
|
||||||
return self.__tensors_to_value(
|
|
||||||
value, lambda struct, val: struct._from_compatible_tensor_list(val))
|
|
||||||
|
|
||||||
def __tensors_to_value(self, flat_value, from_tensor_list_fn):
|
|
||||||
if len(flat_value) != len(self._flat_tensor_specs):
|
|
||||||
raise ValueError("Expected %d flat values in NestedStructure but got %d."
|
|
||||||
% (len(self._flat_tensor_specs), len(flat_value)))
|
|
||||||
flat_ret = []
|
|
||||||
i = 0
|
|
||||||
for structure in self._flat_nested_structure:
|
|
||||||
num_flat_values = len(structure._flat_tensor_specs)
|
|
||||||
sub_value = flat_value[i:i + num_flat_values]
|
|
||||||
flat_ret.append(from_tensor_list_fn(structure, sub_value))
|
|
||||||
i += num_flat_values
|
|
||||||
|
|
||||||
return nest.pack_sequence_as(self._nested_structure, flat_ret)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_value(value):
|
|
||||||
flat_nested_structure = [
|
|
||||||
type_spec.type_spec_from_value(sub_value)
|
|
||||||
for sub_value in nest.flatten(value)
|
|
||||||
]
|
|
||||||
return NestedStructure(nest.pack_sequence_as(value, flat_nested_structure))
|
|
||||||
|
|
||||||
def _to_legacy_output_types(self):
|
|
||||||
return nest.map_structure(
|
|
||||||
lambda s: s._to_legacy_output_types(), self._nested_structure)
|
|
||||||
|
|
||||||
def _to_legacy_output_shapes(self):
|
|
||||||
return nest.map_structure(
|
|
||||||
lambda s: s._to_legacy_output_shapes(), self._nested_structure)
|
|
||||||
|
|
||||||
def _to_legacy_output_classes(self):
|
|
||||||
return nest.map_structure(
|
|
||||||
lambda s: s._to_legacy_output_classes(), self._nested_structure)
|
|
||||||
|
|
||||||
def _batch(self, batch_size):
|
|
||||||
return NestedStructure(nest.map_structure(
|
|
||||||
lambda s: s._batch(batch_size), self._nested_structure))
|
|
||||||
|
|
||||||
def _unbatch(self):
|
|
||||||
return NestedStructure(nest.map_structure(
|
|
||||||
lambda s: s._unbatch(), self._nested_structure))
|
|
||||||
|
|
||||||
|
|
||||||
def _from_tensor_list_helper(decode_fn, element_spec, tensor_list):
|
def _from_tensor_list_helper(decode_fn, element_spec, tensor_list):
|
||||||
"""Returns an element constructed from the given spec and tensor list.
|
"""Returns an element constructed from the given spec and tensor list.
|
||||||
|
|
||||||
|
@ -25,7 +25,6 @@ import numpy as np
|
|||||||
from tensorflow.python import tf2
|
from tensorflow.python import tf2
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.data.ops import iterator_ops
|
from tensorflow.python.data.ops import iterator_ops
|
||||||
from tensorflow.python.data.util import structure
|
|
||||||
from tensorflow.python.distribute import distribution_strategy_context
|
from tensorflow.python.distribute import distribution_strategy_context
|
||||||
from tensorflow.python.distribute import multi_worker_util
|
from tensorflow.python.distribute import multi_worker_util
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
@ -2442,10 +2441,6 @@ class Model(network.Network):
|
|||||||
# code.
|
# code.
|
||||||
if isinstance(x, dataset_ops.DatasetV2):
|
if isinstance(x, dataset_ops.DatasetV2):
|
||||||
x_shapes = dataset_ops.get_structure(x)
|
x_shapes = dataset_ops.get_structure(x)
|
||||||
# TODO(momernick): Remove this once NestedStructure goes away. Right
|
|
||||||
# now, Dataset outputs one of these instead of an actual python structure.
|
|
||||||
if isinstance(x_shapes, structure.NestedStructure):
|
|
||||||
x_shapes = x_shapes._component_specs # pylint: disable=protected-access
|
|
||||||
if isinstance(x_shapes, tuple):
|
if isinstance(x_shapes, tuple):
|
||||||
# If the output of a Dataset is a tuple, we assume it's either of the
|
# If the output of a Dataset is a tuple, we assume it's either of the
|
||||||
# form (x_data, y_data) or (x_data, y_data, sample_weights). In either
|
# form (x_data, y_data) or (x_data, y_data, sample_weights). In either
|
||||||
@ -2460,10 +2455,6 @@ class Model(network.Network):
|
|||||||
x = nest.pack_sequence_as(x, converted_x, expand_composites=False)
|
x = nest.pack_sequence_as(x, converted_x, expand_composites=False)
|
||||||
x_shapes = nest.map_structure(type_spec.type_spec_from_value, x)
|
x_shapes = nest.map_structure(type_spec.type_spec_from_value, x)
|
||||||
|
|
||||||
# If the inputs are still a NestedStructure, then we have a dict-input to
|
|
||||||
# this model. We can't yet validate this. (It's only relevant for feature
|
|
||||||
# columns).
|
|
||||||
if not isinstance(x_shapes, structure.NestedStructure):
|
|
||||||
flat_inputs = nest.flatten(x_shapes, expand_composites=False)
|
flat_inputs = nest.flatten(x_shapes, expand_composites=False)
|
||||||
flat_expected_inputs = nest.flatten(self.inputs, expand_composites=False)
|
flat_expected_inputs = nest.flatten(self.inputs, expand_composites=False)
|
||||||
for (a, b) in zip(flat_inputs, flat_expected_inputs):
|
for (a, b) in zip(flat_inputs, flat_expected_inputs):
|
||||||
|
@ -1,27 +0,0 @@
|
|||||||
path: "tensorflow.data.experimental.NestedStructure"
|
|
||||||
tf_class {
|
|
||||||
is_instance: "<class \'tensorflow.python.data.util.structure.NestedStructure\'>"
|
|
||||||
is_instance: "<class \'tensorflow.python.framework.type_spec.BatchableTypeSpec\'>"
|
|
||||||
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
|
|
||||||
is_instance: "<type \'object\'>"
|
|
||||||
member {
|
|
||||||
name: "value_type"
|
|
||||||
mtype: "<type \'property\'>"
|
|
||||||
}
|
|
||||||
member_method {
|
|
||||||
name: "__init__"
|
|
||||||
argspec: "args=[\'self\', \'nested_structure\'], varargs=None, keywords=None, defaults=None"
|
|
||||||
}
|
|
||||||
member_method {
|
|
||||||
name: "from_value"
|
|
||||||
argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
|
|
||||||
}
|
|
||||||
member_method {
|
|
||||||
name: "is_compatible_with"
|
|
||||||
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
|
|
||||||
}
|
|
||||||
member_method {
|
|
||||||
name: "most_specific_compatible_type"
|
|
||||||
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
|
|
||||||
}
|
|
||||||
}
|
|
@ -28,10 +28,6 @@ tf_module {
|
|||||||
name: "MapVectorizationOptions"
|
name: "MapVectorizationOptions"
|
||||||
mtype: "<type \'type\'>"
|
mtype: "<type \'type\'>"
|
||||||
}
|
}
|
||||||
member {
|
|
||||||
name: "NestedStructure"
|
|
||||||
mtype: "<type \'type\'>"
|
|
||||||
}
|
|
||||||
member {
|
member {
|
||||||
name: "OptimizationOptions"
|
name: "OptimizationOptions"
|
||||||
mtype: "<type \'type\'>"
|
mtype: "<type \'type\'>"
|
||||||
|
@ -1,27 +0,0 @@
|
|||||||
path: "tensorflow.data.experimental.NestedStructure"
|
|
||||||
tf_class {
|
|
||||||
is_instance: "<class \'tensorflow.python.data.util.structure.NestedStructure\'>"
|
|
||||||
is_instance: "<class \'tensorflow.python.framework.type_spec.BatchableTypeSpec\'>"
|
|
||||||
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
|
|
||||||
is_instance: "<type \'object\'>"
|
|
||||||
member {
|
|
||||||
name: "value_type"
|
|
||||||
mtype: "<type \'property\'>"
|
|
||||||
}
|
|
||||||
member_method {
|
|
||||||
name: "__init__"
|
|
||||||
argspec: "args=[\'self\', \'nested_structure\'], varargs=None, keywords=None, defaults=None"
|
|
||||||
}
|
|
||||||
member_method {
|
|
||||||
name: "from_value"
|
|
||||||
argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
|
|
||||||
}
|
|
||||||
member_method {
|
|
||||||
name: "is_compatible_with"
|
|
||||||
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
|
|
||||||
}
|
|
||||||
member_method {
|
|
||||||
name: "most_specific_compatible_type"
|
|
||||||
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
|
|
||||||
}
|
|
||||||
}
|
|
@ -28,10 +28,6 @@ tf_module {
|
|||||||
name: "MapVectorizationOptions"
|
name: "MapVectorizationOptions"
|
||||||
mtype: "<type \'type\'>"
|
mtype: "<type \'type\'>"
|
||||||
}
|
}
|
||||||
member {
|
|
||||||
name: "NestedStructure"
|
|
||||||
mtype: "<type \'type\'>"
|
|
||||||
}
|
|
||||||
member {
|
member {
|
||||||
name: "OptimizationOptions"
|
name: "OptimizationOptions"
|
||||||
mtype: "<type \'type\'>"
|
mtype: "<type \'type\'>"
|
||||||
|
Loading…
x
Reference in New Issue
Block a user