[tf.data] Removing NestedStructure in favor of nested structure of (flat) TypeSpecs

PiperOrigin-RevId: 254862137
This commit is contained in:
Jiri Simsa 2019-06-24 16:35:30 -07:00 committed by TensorFlower Gardener
parent fb6f20cfbd
commit 1d29b5c344
7 changed files with 4 additions and 235 deletions

View File

@ -28,7 +28,6 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
@@DatasetStructure @@DatasetStructure
@@DistributeOptions @@DistributeOptions
@@MapVectorizationOptions @@MapVectorizationOptions
@@NestedStructure
@@OptimizationOptions @@OptimizationOptions
@@Optional @@Optional
@@OptionalStructure @@OptionalStructure
@ -137,7 +136,6 @@ from tensorflow.python.data.ops.dataset_ops import to_variant
from tensorflow.python.data.ops.iterator_ops import get_next_as_optional from tensorflow.python.data.ops.iterator_ops import get_next_as_optional
from tensorflow.python.data.ops.optional_ops import Optional from tensorflow.python.data.ops.optional_ops import Optional
from tensorflow.python.data.ops.optional_ops import OptionalStructure from tensorflow.python.data.ops.optional_ops import OptionalStructure
from tensorflow.python.data.util.structure import NestedStructure
from tensorflow.python.data.util.structure import RaggedTensorStructure from tensorflow.python.data.util.structure import RaggedTensorStructure
from tensorflow.python.data.util.structure import SparseTensorStructure from tensorflow.python.data.util.structure import SparseTensorStructure
from tensorflow.python.data.util.structure import Structure from tensorflow.python.data.util.structure import Structure

View File

@ -170,164 +170,6 @@ def convert_legacy_structure(output_types, output_shapes, output_classes):
return nest.pack_sequence_as(output_classes, flat_ret) return nest.pack_sequence_as(output_classes, flat_ret)
# TODO(b/133606651) Once cl/253160704 is part of tf-nightly, remove the use
# of NestedStructure from TFF and remove this class.
@tf_export("data.experimental.NestedStructure")
class NestedStructure(type_spec.BatchableTypeSpec):
"""Represents a nested structure in which each leaf is a `TypeSpec`."""
# NOTE(edloper): This class makes extensive use of non-public TypeSpec
# methods, so we disable the protected-access lint warning once here.
# pylint: disable=protected-access
__slots__ = ["_nested_structure", "_flat_nested_structure",
"__flat_tensor_specs"]
def __init__(self, nested_structure):
self._nested_structure = nested_structure
self._flat_nested_structure = nest.flatten(nested_structure)
self.__flat_tensor_specs = []
for s in self._flat_nested_structure:
if not isinstance(s, type_spec.TypeSpec):
raise TypeError("nested_structure must be a (potentially nested) tuple "
"or dictionary of TypeSpec objects.")
self.__flat_tensor_specs.extend(s._flat_tensor_specs)
value_type = property(lambda self: type(self._nested_structure))
def _serialize(self):
return self._nested_structure
@classmethod
def _deserialize(cls, nested_structure):
return cls(nested_structure)
def most_specific_compatible_type(self, other):
if type(self) is not type(other):
raise ValueError("Incompatible types")
return self._deserialize(
nest.map_structure(lambda a, b: a.most_specific_compatible_type(b),
self._nested_structure, other._nested_structure))
def __eq__(self, other):
if not isinstance(other, NestedStructure):
return False
try:
nest.assert_same_structure(self._nested_structure,
other._nested_structure)
except (ValueError, TypeError):
return False
return (nest.flatten(self._nested_structure) ==
nest.flatten(other._nested_structure))
def __hash__(self):
return hash(tuple(nest.flatten(self._nested_structure)))
def is_compatible_with(self, other):
if not isinstance(other, NestedStructure):
return False
try:
nest.assert_same_structure(self._nested_structure,
other._nested_structure)
except (ValueError, TypeError):
return False
# pylint: disable=g-complex-comprehension
return all(
substructure.is_compatible_with(other_substructure)
for substructure, other_substructure in zip(
nest.flatten(self._nested_structure),
nest.flatten(other._nested_structure)))
_component_specs = property(lambda self: self._nested_structure)
_flat_tensor_specs = property(lambda self: self.__flat_tensor_specs)
def _to_components(self, value):
return nest.map_structure_up_to(
self._nested_structure, lambda t, v: t._to_components(v),
self._nested_structure, value)
def _from_components(self, value):
return nest.map_structure_up_to(
self._nested_structure, lambda t, v: t._from_components(v),
self._nested_structure, value)
def _to_tensor_list(self, value):
return self.__value_to_tensors(
value, lambda struct, val: struct._to_tensor_list(val))
def _to_batched_tensor_list(self, value):
return self.__value_to_tensors(
value, lambda struct, val: struct._to_batched_tensor_list(val))
def __value_to_tensors(self, value, to_tensor_list_fn):
ret = []
try:
flat_value = nest.flatten_up_to(self._nested_structure, value)
except (ValueError, TypeError):
raise ValueError("The value %r is not compatible with the nested "
"structure %r." % (value, self._nested_structure))
for sub_value, structure in zip(flat_value, self._flat_nested_structure):
if not structure.is_compatible_with(
type_spec.type_spec_from_value(sub_value)):
raise ValueError("Component value %r is not compatible with the nested "
"structure %r." % (sub_value, structure))
ret.extend(to_tensor_list_fn(structure, sub_value))
return ret
def _from_tensor_list(self, value):
return self.__tensors_to_value(
value, lambda struct, val: struct._from_tensor_list(val))
def _from_compatible_tensor_list(self, value):
return self.__tensors_to_value(
value, lambda struct, val: struct._from_compatible_tensor_list(val))
def __tensors_to_value(self, flat_value, from_tensor_list_fn):
if len(flat_value) != len(self._flat_tensor_specs):
raise ValueError("Expected %d flat values in NestedStructure but got %d."
% (len(self._flat_tensor_specs), len(flat_value)))
flat_ret = []
i = 0
for structure in self._flat_nested_structure:
num_flat_values = len(structure._flat_tensor_specs)
sub_value = flat_value[i:i + num_flat_values]
flat_ret.append(from_tensor_list_fn(structure, sub_value))
i += num_flat_values
return nest.pack_sequence_as(self._nested_structure, flat_ret)
@staticmethod
def from_value(value):
flat_nested_structure = [
type_spec.type_spec_from_value(sub_value)
for sub_value in nest.flatten(value)
]
return NestedStructure(nest.pack_sequence_as(value, flat_nested_structure))
def _to_legacy_output_types(self):
return nest.map_structure(
lambda s: s._to_legacy_output_types(), self._nested_structure)
def _to_legacy_output_shapes(self):
return nest.map_structure(
lambda s: s._to_legacy_output_shapes(), self._nested_structure)
def _to_legacy_output_classes(self):
return nest.map_structure(
lambda s: s._to_legacy_output_classes(), self._nested_structure)
def _batch(self, batch_size):
return NestedStructure(nest.map_structure(
lambda s: s._batch(batch_size), self._nested_structure))
def _unbatch(self):
return NestedStructure(nest.map_structure(
lambda s: s._unbatch(), self._nested_structure))
def _from_tensor_list_helper(decode_fn, element_spec, tensor_list): def _from_tensor_list_helper(decode_fn, element_spec, tensor_list):
"""Returns an element constructed from the given spec and tensor list. """Returns an element constructed from the given spec and tensor list.

View File

@ -25,7 +25,6 @@ import numpy as np
from tensorflow.python import tf2 from tensorflow.python import tf2
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.util import structure
from tensorflow.python.distribute import distribution_strategy_context from tensorflow.python.distribute import distribution_strategy_context
from tensorflow.python.distribute import multi_worker_util from tensorflow.python.distribute import multi_worker_util
from tensorflow.python.eager import context from tensorflow.python.eager import context
@ -2442,10 +2441,6 @@ class Model(network.Network):
# code. # code.
if isinstance(x, dataset_ops.DatasetV2): if isinstance(x, dataset_ops.DatasetV2):
x_shapes = dataset_ops.get_structure(x) x_shapes = dataset_ops.get_structure(x)
# TODO(momernick): Remove this once NestedStructure goes away. Right
# now, Dataset outputs one of these instead of an actual python structure.
if isinstance(x_shapes, structure.NestedStructure):
x_shapes = x_shapes._component_specs # pylint: disable=protected-access
if isinstance(x_shapes, tuple): if isinstance(x_shapes, tuple):
# If the output of a Dataset is a tuple, we assume it's either of the # If the output of a Dataset is a tuple, we assume it's either of the
# form (x_data, y_data) or (x_data, y_data, sample_weights). In either # form (x_data, y_data) or (x_data, y_data, sample_weights). In either
@ -2460,10 +2455,6 @@ class Model(network.Network):
x = nest.pack_sequence_as(x, converted_x, expand_composites=False) x = nest.pack_sequence_as(x, converted_x, expand_composites=False)
x_shapes = nest.map_structure(type_spec.type_spec_from_value, x) x_shapes = nest.map_structure(type_spec.type_spec_from_value, x)
# If the inputs are still a NestedStructure, then we have a dict-input to
# this model. We can't yet validate this. (It's only relevant for feature
# columns).
if not isinstance(x_shapes, structure.NestedStructure):
flat_inputs = nest.flatten(x_shapes, expand_composites=False) flat_inputs = nest.flatten(x_shapes, expand_composites=False)
flat_expected_inputs = nest.flatten(self.inputs, expand_composites=False) flat_expected_inputs = nest.flatten(self.inputs, expand_composites=False)
for (a, b) in zip(flat_inputs, flat_expected_inputs): for (a, b) in zip(flat_inputs, flat_expected_inputs):

View File

@ -1,27 +0,0 @@
path: "tensorflow.data.experimental.NestedStructure"
tf_class {
is_instance: "<class \'tensorflow.python.data.util.structure.NestedStructure\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.BatchableTypeSpec\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
is_instance: "<type \'object\'>"
member {
name: "value_type"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'nested_structure\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_value"
argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "is_compatible_with"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "most_specific_compatible_type"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -28,10 +28,6 @@ tf_module {
name: "MapVectorizationOptions" name: "MapVectorizationOptions"
mtype: "<type \'type\'>" mtype: "<type \'type\'>"
} }
member {
name: "NestedStructure"
mtype: "<type \'type\'>"
}
member { member {
name: "OptimizationOptions" name: "OptimizationOptions"
mtype: "<type \'type\'>" mtype: "<type \'type\'>"

View File

@ -1,27 +0,0 @@
path: "tensorflow.data.experimental.NestedStructure"
tf_class {
is_instance: "<class \'tensorflow.python.data.util.structure.NestedStructure\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.BatchableTypeSpec\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
is_instance: "<type \'object\'>"
member {
name: "value_type"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'nested_structure\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_value"
argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "is_compatible_with"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "most_specific_compatible_type"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -28,10 +28,6 @@ tf_module {
name: "MapVectorizationOptions" name: "MapVectorizationOptions"
mtype: "<type \'type\'>" mtype: "<type \'type\'>"
} }
member {
name: "NestedStructure"
mtype: "<type \'type\'>"
}
member { member {
name: "OptimizationOptions" name: "OptimizationOptions"
mtype: "<type \'type\'>" mtype: "<type \'type\'>"