Merge pull request #29670 from goldiegadde/ggadde-cp5

Cherrypick important fixes to r2.0 branch.
This commit is contained in:
Goldie Gadde 2019-06-12 16:37:56 -07:00 committed by GitHub
commit d08e899087
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
39 changed files with 1181 additions and 1177 deletions

View File

@ -140,6 +140,8 @@ Status ImplementationSelector::MaybeOptimizeFunctionCall(
for (const auto& attr_name : function_attribute_names) {
string function_name = node_def->attr().at(attr_name).func().name();
// Skip the function if its already optimized by function optimizer.
if (::absl::StrContains(function_name, "_specialized_for_")) continue;
std::vector<string> equiv_func_names;
TF_RETURN_IF_ERROR(lib_info_->GetEquivalentImplementations(
function_name, &equiv_func_names));
@ -153,7 +155,8 @@ Status ImplementationSelector::MaybeOptimizeFunctionCall(
}
}
if (lib_info_->GetApiInfo(node_def->op()) != nullptr) {
if (lib_info_->GetApiInfo(node_def->op()) != nullptr &&
!::absl::StrContains(node_def->op(), "_specialized_for_")) {
std::vector<string> equiv_func_names;
TF_RETURN_IF_ERROR(lib_info_->GetEquivalentImplementations(
node_def->op(), &equiv_func_names));

View File

@ -328,6 +328,33 @@ class FromSavedModelTest(TestModels):
self.assertIn('This converter can only convert a single ConcreteFunction',
str(error.exception))
def testKerasSequentialModel(self):
"""Test a simple sequential tf.Keras model."""
self.skipTest('b/134660903')
input_data = constant_op.constant(1., shape=[1, 1])
x = np.array([[1.], [2.]])
y = np.array([[2.], [4.]])
model = keras.models.Sequential([
keras.layers.Dropout(0.2),
keras.layers.Dense(1),
])
model.compile(optimizer='sgd', loss='mean_squared_error')
model.fit(x, y, epochs=1)
save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
save(model, save_dir)
# Convert model and ensure model is not None.
converter = lite.TFLiteConverterV2.from_saved_model(save_dir)
tflite_model = converter.convert()
# Check values from converted model.
expected_value = model.predict(input_data)
actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
self.assertEqual(expected_value, actual_value)
class FromKerasModelTest(TestModels):
@ -337,11 +364,13 @@ class FromKerasModelTest(TestModels):
input_data = constant_op.constant(1., shape=[1, 1])
# Create a simple Keras model.
x = [-1, 0, 1, 2, 3, 4]
y = [-3, -1, 1, 3, 5, 7]
x = np.array([[1.], [2.]])
y = np.array([[2.], [4.]])
model = keras.models.Sequential(
[keras.layers.Dense(units=1, input_shape=[1])])
model = keras.models.Sequential([
keras.layers.Dropout(0.2),
keras.layers.Dense(units=1, input_shape=[1])
])
model.compile(optimizer='sgd', loss='mean_squared_error')
model.fit(x, y, epochs=1)

View File

@ -958,6 +958,7 @@ py_library(
":tensor_shape",
":tf2",
":traceable_stack",
":type_spec",
":util",
":versions",
"//tensorflow/core:protos_all_py",
@ -1138,6 +1139,7 @@ py_library(
":framework_ops",
":tensor_like",
":tensor_util",
":type_spec",
],
)
@ -1244,6 +1246,7 @@ py_library(
":common_shapes",
":dtypes",
":tensor_shape",
":type_spec",
":util",
"//third_party/py/numpy",
],

View File

@ -23,6 +23,7 @@ from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import structure
from tensorflow.python.framework import ops
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import gen_experimental_dataset_ops
from tensorflow.python.util.tf_export import tf_export
@ -40,7 +41,7 @@ class _ScanDataset(dataset_ops.UnaryDataset):
# Compute initial values for the state classes, shapes and types based on
# the initial state. The shapes may be refined by running `tf_scan_func` one
# or more times below.
self._state_structure = structure.Structure.from_value(self._initial_state)
self._state_structure = type_spec.type_spec_from_value(self._initial_state)
# Iteratively rerun the scan function until reaching a fixed point on
# `self._state_shapes`.

View File

@ -38,6 +38,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
@ -290,7 +291,7 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
)
def testDatasetStructure(self, tf_value_fn, expected_element_structure):
dataset = dataset_ops.Dataset.from_tensors(0).map(lambda _: tf_value_fn())
dataset_structure = structure.Structure.from_value(dataset)
dataset_structure = type_spec.type_spec_from_value(dataset)
self.assertIsInstance(dataset_structure, dataset_ops.DatasetStructure)
# TODO(b/110122868): Add a public API to `tf.data.Dataset` for accessing

View File

@ -34,6 +34,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
@ -279,7 +280,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertTrue(
opt.value_structure.is_compatible_with(expected_value_structure))
opt_structure = structure.Structure.from_value(opt)
opt_structure = type_spec.type_spec_from_value(opt)
self.assertIsInstance(opt_structure, optional_ops.OptionalStructure)
self.assertTrue(opt_structure.is_compatible_with(opt_structure))
self.assertTrue(opt_structure._value_structure.is_compatible_with(
@ -289,7 +290,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
# All OptionalStructure objects are not compatible with a non-optional
# value.
non_optional_structure = structure.Structure.from_value(
non_optional_structure = type_spec.type_spec_from_value(
constant_op.constant(42.0))
self.assertFalse(opt_structure.is_compatible_with(non_optional_structure))
@ -340,7 +341,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertIsInstance(next_elem, optional_ops.Optional)
self.assertTrue(
next_elem.value_structure.is_compatible_with(
structure.Structure.from_value(tf_value_fn())))
type_spec.type_spec_from_value(tf_value_fn())))
self.assertTrue(next_elem.has_value())
self._assertElementValueEqual(np_value, next_elem.get_value())
# After exhausting the iterator, `next_elem.has_value()` will evaluate to
@ -356,7 +357,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertIsInstance(next_elem, optional_ops.Optional)
self.assertTrue(
next_elem.value_structure.is_compatible_with(
structure.Structure.from_value(tf_value_fn())))
type_spec.type_spec_from_value(tf_value_fn())))
# Before initializing the iterator, evaluating the optional fails with
# a FailedPreconditionError. This is only relevant in graph mode.
elem_has_value_t = next_elem.has_value()

View File

@ -54,6 +54,7 @@ from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_dataset_ops
@ -1391,7 +1392,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
with ops.name_scope("initial_state"):
initial_state = structure_lib.normalize_tensors(initial_state)
state_structure = structure_lib.Structure.from_value(initial_state)
state_structure = type_spec.type_spec_from_value(initial_state)
# Iteratively rerun the reduce function until reaching a fixed point on
# `state_structure`.
@ -2264,7 +2265,7 @@ class TensorDataset(DatasetSource):
def __init__(self, tensors):
"""See `Dataset.from_tensors()` for details."""
tensors = structure_lib.normalize_tensors(tensors)
self._structure = structure_lib.Structure.from_value(tensors)
self._structure = type_spec.type_spec_from_value(tensors)
self._tensors = self._structure._to_tensor_list(tensors) # pylint: disable=protected-access
variant_tensor = gen_dataset_ops.tensor_dataset(
@ -2284,7 +2285,7 @@ class TensorSliceDataset(DatasetSource):
with ops.name_scope("tensors"):
tensors = structure_lib.normalize_tensors(tensors)
batched_structure = structure_lib.Structure.from_value(tensors)
batched_structure = type_spec.type_spec_from_value(tensors)
# pylint: disable=protected-access
self._tensors = batched_structure._to_batched_tensor_list(tensors)
self._structure = batched_structure._unbatch()
@ -2377,51 +2378,33 @@ def to_variant(dataset):
return dataset._variant_tensor # pylint: disable=protected-access
# TODO(b/133606651) Rename this class to DatasetSpec
@tf_export("data.experimental.DatasetStructure")
class DatasetStructure(structure_lib.Structure):
"""Represents a `Dataset` of structured values."""
class DatasetStructure(type_spec.TypeSpec):
"""Type specification for `tf.data.Dataset`."""
def __init__(self, element_structure):
self._element_structure = element_structure
__slots__ = ["_element_structure"]
def __eq__(self, other):
# pylint: disable=protected-access
return (isinstance(other, DatasetStructure) and
self._element_structure == other._element_structure)
def __hash__(self):
return hash(self._element_structure)
def __init__(self, element_spec):
self._element_structure = element_spec
@property
def _flat_shapes(self):
return [tensor_shape.scalar()]
def value_type(self):
return _VariantDataset
def _serialize(self):
return (self._element_structure,)
@property
def _flat_types(self):
return [dtypes.variant]
def _component_specs(self):
return tensor_spec.TensorSpec([], dtypes.variant)
def is_compatible_with(self, other):
def _to_components(self, value):
return value._variant_tensor # pylint: disable=protected-access
def _from_components(self, components):
# pylint: disable=protected-access
return (isinstance(other, DatasetStructure) and
self._element_structure.is_compatible_with(
other._element_structure))
def _to_tensor_list(self, value):
return [value._variant_tensor] # pylint: disable=protected-access
def _to_batched_tensor_list(self, value):
raise NotImplementedError("Unbatching for `tf.data.Dataset` objects.")
def _from_tensor_list(self, flat_value):
if (len(flat_value) != 1 or flat_value[0].dtype != dtypes.variant or
not flat_value[0].shape.is_compatible_with(tensor_shape.scalar())):
raise ValueError(
"DatasetStructure corresponds to a single tf.variant scalar.")
return self._from_compatible_tensor_list(flat_value)
def _from_compatible_tensor_list(self, flat_value):
# pylint: disable=protected-access
return _VariantDataset(flat_value[0], self._element_structure)
return _VariantDataset(components, self._element_structure)
@staticmethod
def from_value(value):
@ -2436,17 +2419,11 @@ class DatasetStructure(structure_lib.Structure):
def _to_legacy_output_classes(self):
return self
def _batch(self, batch_size):
raise NotImplementedError("Batching for `tf.data.Dataset` objects.")
def _unbatch(self):
raise NotImplementedError("Unbatching for `tf.data.Dataset` objects.")
# pylint: disable=protected-access
structure_lib.Structure._register_custom_converter(DatasetV2,
DatasetStructure.from_value)
# pylint: enable=protected-access
# TODO(b/133606651) Delete this registration when CompositeTensor is updated
# to define a _type_spec field (since registration will be automatic).
type_spec.register_type_spec_from_value_converter(
DatasetV2, DatasetStructure.from_value, allow_subclass=True)
class StructuredFunctionWrapper(object):
@ -2565,7 +2542,7 @@ class StructuredFunctionWrapper(object):
ret = tuple(ret)
try:
self._output_structure = structure_lib.Structure.from_value(ret)
self._output_structure = type_spec.type_spec_from_value(ret)
except (ValueError, TypeError):
raise TypeError("Unsupported return value from function passed to "
"%s: %s." % (transformation_name, ret))
@ -2601,12 +2578,7 @@ class StructuredFunctionWrapper(object):
# TODO(b/124254153): Enable autograph once the overhead is low enough.
# TODO(mdan): Make sure autograph recurses into _wrapper_helper when on.
@eager_function.defun_with_attributes(
input_signature=[
tensor_spec.TensorSpec(input_shape, input_type) # pylint: disable=g-complex-comprehension
for input_shape, input_type in zip(
self._input_structure._flat_shapes,
self._input_structure._flat_types)
],
input_signature=self._input_structure._flat_tensor_specs,
autograph=False,
attributes=defun_kwargs)
def wrapper_fn(*args): # pylint: disable=missing-docstring
@ -2697,7 +2669,7 @@ class _GeneratorDataset(DatasetSource):
"""
self._init_args = init_args
self._init_structure = structure_lib.Structure.from_value(init_args)
self._init_structure = type_spec.type_spec_from_value(init_args)
self._init_func = StructuredFunctionWrapper(
init_func,

View File

@ -21,11 +21,12 @@ import abc
import six
from tensorflow.python.data.util import structure
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.util.tf_export import tf_export
@ -93,7 +94,7 @@ class Optional(composite_tensor.CompositeTensor):
"""
with ops.name_scope("optional") as scope:
with ops.name_scope("value"):
value_structure = structure.Structure.from_value(value)
value_structure = type_spec.type_spec_from_value(value)
encoded_value = value_structure._to_tensor_list(value) # pylint: disable=protected-access
return _OptionalImpl(
@ -167,49 +168,31 @@ class _OptionalImpl(Optional):
return hasattr(self._variant_tensor, "graph")
# TODO(b/133606651) Rename this class to OptionalSpec
@tf_export("data.experimental.OptionalStructure")
class OptionalStructure(structure.Structure):
class OptionalStructure(type_spec.TypeSpec):
"""Represents an optional potentially containing a structured value."""
__slots__ = ["_value_structure"]
def __init__(self, value_structure):
self._value_structure = value_structure
def __eq__(self, other):
# pylint: disable=protected-access
return (isinstance(other, OptionalStructure) and
self._value_structure == other._value_structure)
@property
def value_type(self):
return _OptionalImpl
def __hash__(self):
return hash(self._value_structure)
def _serialize(self):
return (self._value_structure,)
@property
def _flat_shapes(self):
return [tensor_shape.scalar()]
def _component_specs(self):
return [tensor_spec.TensorSpec((), dtypes.variant)]
@property
def _flat_types(self):
return [dtypes.variant]
def is_compatible_with(self, other):
# pylint: disable=protected-access
return (isinstance(other, OptionalStructure) and
self._value_structure.is_compatible_with(other._value_structure))
def _to_tensor_list(self, value):
def _to_components(self, value):
return [value._variant_tensor] # pylint: disable=protected-access
def _to_batched_tensor_list(self, value):
raise NotImplementedError(
"Unbatching for `tf.data.experimental.Optional` objects.")
def _from_tensor_list(self, flat_value):
if (len(flat_value) != 1 or flat_value[0].dtype != dtypes.variant or
not flat_value[0].shape.is_compatible_with(tensor_shape.scalar())):
raise ValueError(
"OptionalStructure corresponds to a single tf.variant scalar.")
return self._from_compatible_tensor_list(flat_value)
def _from_compatible_tensor_list(self, flat_value):
def _from_components(self, flat_value):
# pylint: disable=protected-access
return _OptionalImpl(flat_value[0], self._value_structure)
@ -226,16 +209,8 @@ class OptionalStructure(structure.Structure):
def _to_legacy_output_classes(self):
return self
def _batch(self, batch_size):
raise NotImplementedError(
"Batching for `tf.data.experimental.Optional` objects.")
def _unbatch(self):
raise NotImplementedError(
"Unbatching for `tf.data.experimental.Optional` objects.")
# pylint: disable=protected-access
structure.Structure._register_custom_converter(Optional,
OptionalStructure.from_value)
# pylint: enable=protected-access
type_spec.register_type_spec_from_value_converter(Optional,
OptionalStructure.from_value)
type_spec.register_type_spec_from_value_converter(_OptionalImpl,
OptionalStructure.from_value)

View File

@ -17,263 +17,55 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import six
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_tensor_value
from tensorflow.python.util.tf_export import tf_export
_STRUCTURE_CONVERSION_FUNCTION_REGISTRY = {}
@tf_export("data.experimental.Structure")
@six.add_metaclass(abc.ABCMeta)
class Structure(object):
"""Represents structural information, such as type and shape, about a value.
# Define backwards-compatiblity wrappers for using TypeSpec and its subclasses
# to replace Structure and its subclasses. Note that the constructor argument
# order is different in many cases -- in particular, TypeSpec follows TensorSpec
# and uses the order (shape, dtype); but most Structure subclasses use the
# order (dtype, shape).
#
# TODO(b/133606651) Update tf.data to use TypeSpec directly, and then remove
# these compatibility wrappers.
A `Structure` generalizes the `tf.Tensor.dtype` and `tf.Tensor.shape`
properties, so that we can define generic containers of objects including:
* `tf.Tensor`
* `tf.SparseTensor`
* Nested structures of the above.
Structure = type_spec.TypeSpec
TODO(b/110122868): In the future, a single `Structure` will replace the
`tf.data.Dataset.output_types`, `tf.data.Dataset.output_shapes`,
and `tf.data.Dataset.output_classes`, and similar properties and arguments in
the `tf.compat.v1.data.Iterator` and `Optional` classes.
"""
@abc.abstractmethod
def __eq__(self, other):
"""Returns the this structure and the input structure are equal.
# pylint: disable=invalid-name
Args:
other: the structure to use for equality check
Returns:
`True` if this and the input structure are equal and `False` otherwise.
"""
raise NotImplementedError("Structure.__eq__()")
@tf_export("data.experimental.TensorStructure")
def TensorStructure(dtype, shape):
return tensor_spec.TensorSpec(shape, dtype)
def __ne__(self, other):
return not self == other
@abc.abstractmethod
def __hash__(self):
"""Returns the hash of this structure.
@tf_export("data.experimental.SparseTensorStructure")
def SparseTensorStructure(dtype, shape):
return sparse_tensor_lib.SparseTensorSpec(shape, dtype)
Returns:
The hash of this structure.
"""
raise NotImplementedError("Structure.__hash__()")
@abc.abstractproperty
def _flat_shapes(self):
"""A list of shapes matching the shapes of `self._to_tensor_list()`.
@tf_export("data.experimental.TensorArrayStructure")
def TensorArrayStructure(dtype, element_shape, dynamic_size, infer_shape):
return tensor_array_ops.TensorArraySpec(element_shape, dtype,
dynamic_size, infer_shape)
Returns:
A list of `tf.TensorShape` objects.
"""
raise NotImplementedError("Structure._flat_shapes")
@abc.abstractproperty
def _flat_types(self):
"""A list of types matching the types of `self._to_tensor_list()`.
Returns:
A list of `tf.DType` objects.
"""
raise NotImplementedError("Structure._flat_shapes")
@abc.abstractmethod
def is_compatible_with(self, other):
"""Returns `True` if `other` is compatible with this structure.
A structure `t` is a "subtype" of `s` if:
* `s` and `t` are instances of the same `Structure` subclass.
* The nested structures (if any) of `s` and `t` are the same, according to
`tf.nest.assert_same_structure`, and each nested
structure of `t` is a "subtype" of the corresponding nested structure of
`s`.
* Any `tf.DType` components of `t` are the same as the corresponding
components in `s`.
* Any `tf.TensorShape` components of `t` are compatible with the
corresponding components in `s`, according to
`tf.TensorShape.is_compatible_with`.
Args:
other: A `Structure`.
Returns:
`True` if `other` is a subtype of this structure, otherwise `False`.
"""
raise NotImplementedError("Structure.is_compatible_with()")
@abc.abstractmethod
def _to_tensor_list(self, value):
"""Returns a flat list of `tf.Tensor` representing `value`.
This method can be used, along with `self._flat_shapes` and
`self._flat_types` to represent structured values in lower level APIs
(such as plain TensorFlow operations) that do not understand structure.
Requires: `self.is_compatible_with(Structure.from_value(value))`.
Args:
value: A value with compatible structure.
Returns:
A flat list of `tf.Tensor` representing `value`.
"""
raise NotImplementedError("Structure._to_tensor_list()")
@abc.abstractmethod
def _to_batched_tensor_list(self, value):
"""Returns a flat list of rank >= 1 `tf.Tensor` representing `value`.
This method can be used, along with `self._flat_shapes` and
`self._flat_types` to represent structured values in lower level APIs
(such as plain TensorFlow operations) that do not understand structure,
*and* that require that the plain tensors have a rank of at least one
(e.g. for the purpose of slicing the tensors).
Requires: `self.is_compatible_with(Structure.from_value(value))`.
Args:
value: A value with compatible structure.
Returns:
A flat list of `tf.Tensor` representing `value`.
"""
raise NotImplementedError("Structure._to_batched_tensor_list()")
@abc.abstractmethod
def _from_tensor_list(self, flat_value):
"""Builds a flat list of `tf.Tensor` into a value matching this structure.
Args:
flat_value: A list of `tf.Tensor` with compatible flat structure.
Returns:
A structured object matching this structure.
Raises:
ValueError: If the shapes and types of the tensors in `flat_value` are not
compatible with `self._flat_shapes` and `self._flat_types` respectively.
"""
raise NotImplementedError("Structure._from_tensor_list()")
def _from_compatible_tensor_list(self, flat_value):
"""A version of `_from_tensor_list()` that may avoid performing checks.
NOTE: This method should be used to avoid checks for performance reasons,
when the validity of `flat_value` has been validated by other means.
The shapes and types of the tensors in `flat_value` must be compatible with
`self._flat_shapes` and `self._flat_types` respectively. The behavior is
undefined if this requirement is not met.
Args:
flat_value: A list of `tf.Tensor` with compatible flat structure.
Returns:
A structured object matching this structure.
"""
return self._from_tensor_list(flat_value)
@abc.abstractmethod
def _batch(self, batch_size):
"""Returns a structure representing a batch of objects with this structure.
Args:
batch_size: An `int` representing the number of elements in a batch,
or `None` if the batch size may vary.
Returns:
A `Structure` representing a batch of objects with this structure.
"""
raise NotImplementedError("Structure._batch()")
@abc.abstractmethod
def _unbatch(self):
raise NotImplementedError("Structure._unbatch()")
@staticmethod
def from_value(value):
"""Returns a `Structure` that represents the given `value`.
Args:
value: A potentially structured value.
Returns:
A `Structure` that is compatible with `value`.
Raises:
TypeError: If a structure cannot be built for `value`, because its type
or one of its component types is not supported.
"""
# TODO(b/110122868): Add support for custom types and Dataset to this
# method.
if isinstance(
value,
(sparse_tensor_lib.SparseTensor, sparse_tensor_lib.SparseTensorValue)):
return SparseTensorStructure.from_value(value)
elif isinstance(value, tensor_array_ops.TensorArray):
return TensorArrayStructure.from_value(value)
elif isinstance(
value,
(ragged_tensor.RaggedTensor, ragged_tensor_value.RaggedTensorValue)):
return RaggedTensorStructure.from_value(value)
elif isinstance(value, (tuple, dict)):
return NestedStructure.from_value(value)
else:
for converter_type, converter_fn in (
_STRUCTURE_CONVERSION_FUNCTION_REGISTRY.items()):
if isinstance(value, converter_type):
return converter_fn(value)
try:
tensor = ops.convert_to_tensor(value)
except (ValueError, TypeError):
raise TypeError("Could not build a structure for %r" % value)
return TensorStructure.from_value(tensor)
@staticmethod
def _register_custom_converter(type_object, converter_fn):
"""Registers `converter_fn` for converting values of the given type.
Args:
type_object: A Python `type` object representing the type of values
accepted by `converter_fn`.
converter_fn: A function that takes one argument (an instance of the
type represented by `type_object`) and returns a `Structure`.
"""
_STRUCTURE_CONVERSION_FUNCTION_REGISTRY[type_object] = converter_fn
@abc.abstractmethod
def _to_legacy_output_types(self):
raise NotImplementedError("Structure._to_legacy_output_types()")
@abc.abstractmethod
def _to_legacy_output_shapes(self):
raise NotImplementedError("Structure._to_legacy_output_shapes()")
@abc.abstractmethod
def _to_legacy_output_classes(self):
raise NotImplementedError("Structure._to_legacy_output_classes()")
@tf_export("data.experimental.RaggedTensorStructure")
def RaggedTensorStructure(dtype, shape, ragged_rank):
return ragged_tensor.RaggedTensorSpec(shape, dtype, ragged_rank)
def normalize_tensors(tensors):
@ -340,7 +132,7 @@ def convert_legacy_structure(output_types, output_shapes, output_classes):
flat_ret = []
for flat_type, flat_shape, flat_class in zip(flat_types, flat_shapes,
flat_classes):
if isinstance(flat_class, Structure):
if isinstance(flat_class, type_spec.TypeSpec):
flat_ret.append(flat_class)
elif issubclass(flat_class, sparse_tensor_lib.SparseTensor):
flat_ret.append(SparseTensorStructure(flat_type, flat_shape))
@ -361,88 +153,102 @@ def convert_legacy_structure(output_types, output_shapes, output_classes):
"Could not build a structure for output class %r" % (flat_class,))
ret = nest.pack_sequence_as(output_classes, flat_ret)
if isinstance(ret, Structure):
if isinstance(ret, type_spec.TypeSpec):
return ret
else:
return NestedStructure(ret)
# NOTE(mrry): The following classes make extensive use of non-public methods of
# their base class, so we disable the protected-access lint warning once here.
# pylint: disable=protected-access
# TODO(b/133606651) Update the tf.data code to use nests of TypeSpec rather
# than NestedStructure; and then delete this class.
@tf_export("data.experimental.NestedStructure")
class NestedStructure(Structure):
"""Represents a nested structure in which each leaf is a `Structure`."""
class NestedStructure(type_spec.BatchableTypeSpec):
"""Represents a nested structure in which each leaf is a `TypeSpec`."""
# NOTE(edloper): This class makes extensive use of non-public TypeSpec
# methods, so we disable the protected-access lint warning once here.
# pylint: disable=protected-access
__slots__ = ["_nested_structure", "_flat_nested_structure",
"__flat_tensor_specs"]
def __init__(self, nested_structure):
self._nested_structure = nested_structure
self._flat_nested_structure = nest.flatten(nested_structure)
self._flat_shapes_list = []
self._flat_types_list = []
for s in nest.flatten(nested_structure):
if not isinstance(s, Structure):
self.__flat_tensor_specs = []
for s in self._flat_nested_structure:
if not isinstance(s, type_spec.TypeSpec):
raise TypeError("nested_structure must be a (potentially nested) tuple "
"or dictionary of Structure objects.")
self._flat_shapes_list.extend(s._flat_shapes)
self._flat_types_list.extend(s._flat_types)
"or dictionary of TypeSpec objects.")
self.__flat_tensor_specs.extend(s._flat_tensor_specs)
value_type = property(lambda self: type(self._nested_structure))
def _serialize(self):
return self._nested_structure
@classmethod
def _deserialize(cls, nested_structure):
return cls(nested_structure)
def most_specific_compatible_type(self, other):
if type(self) is not type(other):
raise ValueError("Incompatible types")
return nest.map_structure(lambda a, b: a.most_specific_compatible_type(b),
self._nested_structure, other._nested_structure)
def __eq__(self, other):
if not isinstance(other, NestedStructure):
return False
try:
# pylint: disable=protected-access
nest.assert_same_structure(self._nested_structure,
other._nested_structure)
except (ValueError, TypeError):
return False
return nest.flatten(self._nested_structure) == nest.flatten(
other._nested_structure)
return (nest.flatten(self._nested_structure) ==
nest.flatten(other._nested_structure))
def __hash__(self):
return hash(tuple(nest.flatten(self._nested_structure)))
@property
def _flat_shapes(self):
return self._flat_shapes_list
@property
def _flat_types(self):
return self._flat_types_list
def is_compatible_with(self, other):
if not isinstance(other, NestedStructure):
return False
try:
# pylint: disable=protected-access
nest.assert_same_structure(self._nested_structure,
other._nested_structure)
except (ValueError, TypeError):
return False
# pylint: disable=g-complex-comprehension
return all(
substructure.is_compatible_with(other_substructure)
for substructure, other_substructure in zip(
nest.flatten(self._nested_structure),
nest.flatten(other._nested_structure)))
_component_specs = property(lambda self: self._nested_structure)
_flat_tensor_specs = property(lambda self: self.__flat_tensor_specs)
def _to_components(self, value):
return nest.map_structure_up_to(
self._nested_structure, lambda t, v: t._to_components(v),
self._nested_structure, value)
def _from_components(self, value):
return nest.map_structure_up_to(
self._nested_structure, lambda t, v: t._from_components(v),
self._nested_structure, value)
def _to_tensor_list(self, value):
ret = []
try:
flat_value = nest.flatten_up_to(self._nested_structure, value)
except (ValueError, TypeError):
raise ValueError("The value %r is not compatible with the nested "
"structure %r." % (value, self._nested_structure))
for sub_value, structure in zip(flat_value, self._flat_nested_structure):
if not structure.is_compatible_with(Structure.from_value(sub_value)):
raise ValueError("Component value %r is not compatible with the nested "
"structure %r." % (sub_value, structure))
ret.extend(structure._to_tensor_list(sub_value))
return ret
return self.__value_to_tensors(
value, lambda struct, val: struct._to_tensor_list(val))
def _to_batched_tensor_list(self, value):
return self.__value_to_tensors(
value, lambda struct, val: struct._to_batched_tensor_list(val))
def __value_to_tensors(self, value, to_tensor_list_fn):
ret = []
try:
@ -452,34 +258,31 @@ class NestedStructure(Structure):
"structure %r." % (value, self._nested_structure))
for sub_value, structure in zip(flat_value, self._flat_nested_structure):
if not structure.is_compatible_with(Structure.from_value(sub_value)):
if not structure.is_compatible_with(
type_spec.type_spec_from_value(sub_value)):
raise ValueError("Component value %r is not compatible with the nested "
"structure %r." % (sub_value, structure))
ret.extend(structure._to_batched_tensor_list(sub_value))
ret.extend(to_tensor_list_fn(structure, sub_value))
return ret
def _from_tensor_list(self, flat_value):
if len(flat_value) != len(self._flat_types):
def _from_tensor_list(self, value):
return self.__tensors_to_value(
value, lambda struct, val: struct._from_tensor_list(val))
def _from_compatible_tensor_list(self, value):
return self.__tensors_to_value(
value, lambda struct, val: struct._from_compatible_tensor_list(val))
def __tensors_to_value(self, flat_value, from_tensor_list_fn):
if len(flat_value) != len(self._flat_tensor_specs):
raise ValueError("Expected %d flat values in NestedStructure but got %d."
% (len(self._flat_types), len(flat_value)))
% (len(self._flat_tensor_specs), len(flat_value)))
flat_ret = []
i = 0
for structure in self._flat_nested_structure:
num_flat_values = len(structure._flat_types)
num_flat_values = len(structure._flat_tensor_specs)
sub_value = flat_value[i:i + num_flat_values]
flat_ret.append(structure._from_tensor_list(sub_value))
i += num_flat_values
return nest.pack_sequence_as(self._nested_structure, flat_ret)
def _from_compatible_tensor_list(self, flat_value):
flat_ret = []
i = 0
for structure in self._flat_nested_structure:
num_flat_values = len(structure._flat_types)
sub_value = flat_value[i:i + num_flat_values]
flat_ret.append(structure._from_compatible_tensor_list(sub_value))
flat_ret.append(from_tensor_list_fn(structure, sub_value))
i += num_flat_values
return nest.pack_sequence_as(self._nested_structure, flat_ret)
@ -487,7 +290,8 @@ class NestedStructure(Structure):
@staticmethod
def from_value(value):
flat_nested_structure = [
Structure.from_value(sub_value) for sub_value in nest.flatten(value)
type_spec.type_spec_from_value(sub_value)
for sub_value in nest.flatten(value)
]
return NestedStructure(nest.pack_sequence_as(value, flat_nested_structure))
@ -512,362 +316,14 @@ class NestedStructure(Structure):
lambda s: s._unbatch(), self._nested_structure))
@tf_export("data.experimental.TensorStructure")
class TensorStructure(Structure):
"""Represents structural information about a `tf.Tensor`."""
type_spec.register_type_spec_from_value_converter(
tuple, NestedStructure.from_value, allow_subclass=True)
type_spec.register_type_spec_from_value_converter(
dict, NestedStructure.from_value, allow_subclass=True)
def __init__(self, dtype, shape):
self._dtype = dtypes.as_dtype(dtype)
self._shape = tensor_shape.as_shape(shape)
def __eq__(self, other):
return (isinstance(other, TensorStructure) and tensor_spec.TensorSpec(
self._shape, self._dtype) == tensor_spec.TensorSpec(
other._shape, other._dtype))
def __hash__(self):
return hash(tensor_spec.TensorSpec(self._shape, self._dtype))
@property
def _flat_shapes(self):
return [self._shape]
@property
def _flat_types(self):
return [self._dtype]
def is_compatible_with(self, other):
return (isinstance(other, TensorStructure) and
self._dtype.is_compatible_with(other._dtype) and
self._shape.is_compatible_with(other._shape))
def _to_tensor_list(self, value):
if not self.is_compatible_with(Structure.from_value(value)):
raise ValueError("Value %r is not convertible to a tensor with dtype %s "
"and shape %s." % (value, self._dtype, self._shape))
return [value]
def _to_batched_tensor_list(self, value):
if self._shape.merge_with(value.shape).ndims == 0:
raise ValueError("Unbatching a tensor is only supported for rank >= 1")
return [value]
def _from_tensor_list(self, flat_value):
if len(flat_value) != 1:
raise ValueError("TensorStructure corresponds to a single tf.Tensor.")
if not self.is_compatible_with(Structure.from_value(flat_value[0])):
raise ValueError("Cannot convert %r to a tensor with dtype %s and shape "
"%s." % (flat_value[0], self._dtype, self._shape))
return self._from_compatible_tensor_list(flat_value)
def _from_compatible_tensor_list(self, flat_value):
# TODO(b/112266545): It would be cleaner to create a new `ensure_shape()`
# op here and return that, instead of mutating the input's shape using
# `Tensor.set_shape()`. However, that would add extra ops on the arguments
# of each `tf.data` function, which could impact performance. When this
# bug is resolved, we should be able to add the `ensure_shape()` ops and
# optimize them away using contextual shape information.
flat_value[0].set_shape(self._shape)
return flat_value[0]
@staticmethod
def from_value(value):
return TensorStructure(value.dtype, value.shape)
def _to_legacy_output_types(self):
return self._dtype
def _to_legacy_output_shapes(self):
return self._shape
def _to_legacy_output_classes(self):
return ops.Tensor
def _batch(self, batch_size):
return TensorStructure(
self._dtype,
tensor_shape.TensorShape([batch_size]).concatenate(self._shape))
def _unbatch(self):
if self._shape.ndims == 0:
raise ValueError("Unbatching a tensor is only supported for rank >= 1")
return TensorStructure(self._dtype, self._shape[1:])
@tf_export("data.experimental.SparseTensorStructure")
class SparseTensorStructure(Structure):
"""Represents structural information about a `tf.SparseTensor`."""
def __init__(self, dtype, dense_shape):
self._dtype = dtypes.as_dtype(dtype)
self._dense_shape = tensor_shape.as_shape(dense_shape)
def __eq__(self, other):
return (isinstance(other, SparseTensorStructure) and tensor_spec.TensorSpec(
self._dense_shape, self._dtype) == tensor_spec.TensorSpec(
other._dense_shape, other._dtype))
def __hash__(self):
return hash(tensor_spec.TensorSpec(self._dense_shape, self._dtype))
@property
def _flat_shapes(self):
# NOTE(mrry): The default flat shape of a boxed `SparseTensor` is `(3,)`,
# but a `SparseTensorStructure` can also represent a batch of boxed
# `SparseTensor` objects with shape `(?, 3)` (and batches of batches, etc.),
# so the flat shape must be unknown.
return [tensor_shape.unknown_shape(None)]
@property
def _flat_types(self):
return [dtypes.variant]
def is_compatible_with(self, other):
return (isinstance(other, SparseTensorStructure) and
self._dtype.is_compatible_with(other._dtype) and
self._dense_shape.is_compatible_with(other._dense_shape))
def _to_tensor_list(self, value):
return [sparse_ops.serialize_sparse(value, out_type=dtypes.variant)]
def _to_batched_tensor_list(self, value):
if self._dense_shape.merge_with(
tensor_util.constant_value_as_shape(value.dense_shape)).ndims == 0:
raise ValueError(
"Unbatching a sparse tensor is only supported for rank >= 1")
return [sparse_ops.serialize_many_sparse(value, out_type=dtypes.variant)]
def _from_tensor_list(self, flat_value):
if (len(flat_value) != 1 or flat_value[0].dtype != dtypes.variant or
not flat_value[0].shape.is_compatible_with(tensor_shape.vector(3))):
raise ValueError("SparseTensorStructure corresponds to a single "
"tf.variant vector of length 3.")
return self._from_compatible_tensor_list(flat_value)
def _from_compatible_tensor_list(self, flat_value):
ret = sparse_ops.deserialize_sparse(
flat_value[0], dtype=self._dtype, rank=self._dense_shape.ndims)
ret.indices.set_shape([None, self._dense_shape.ndims])
ret.dense_shape.set_shape([self._dense_shape.ndims])
return ret
@staticmethod
def from_value(value):
sparse_tensor = sparse_tensor_lib.SparseTensor.from_value(value)
return SparseTensorStructure(
sparse_tensor.dtype,
tensor_util.constant_value_as_shape(sparse_tensor.dense_shape))
def _to_legacy_output_types(self):
return self._dtype
def _to_legacy_output_shapes(self):
return self._dense_shape
def _to_legacy_output_classes(self):
return sparse_tensor_lib.SparseTensor
def _batch(self, batch_size):
return SparseTensorStructure(
self._dtype,
tensor_shape.TensorShape([batch_size]).concatenate(self._dense_shape))
def _unbatch(self):
if self._dense_shape.ndims == 0:
raise ValueError("Unbatching a tensor is only supported for rank >= 1")
return SparseTensorStructure(self._dtype, self._dense_shape[1:])
@tf_export("data.experimental.TensorArrayStructure")
class TensorArrayStructure(Structure):
"""Represents structural information about a `tf.TensorArray`."""
def __init__(self, dtype, element_shape, dynamic_size, infer_shape):
self._dtype = dtypes.as_dtype(dtype)
self._element_shape = tensor_shape.as_shape(element_shape)
self._dynamic_size = dynamic_size
self._infer_shape = infer_shape
def __eq__(self, other):
return (isinstance(other, TensorArrayStructure) and tensor_spec.TensorSpec(
self._element_shape, self._dtype) == tensor_spec.TensorSpec(
other._element_shape, other._dtype) and
self._dynamic_size == other._dynamic_size and
self._infer_shape == other._infer_shape)
def __hash__(self):
return hash((tensor_spec.TensorSpec(self._element_shape, self._dtype),
self._dynamic_size, self._infer_shape))
@property
def _flat_shapes(self):
# A TensorArray is represented via its variant object, which is a scalar.
return [tensor_shape.scalar()]
@property
def _flat_types(self):
return [dtypes.variant]
def is_compatible_with(self, other):
return (isinstance(other, TensorArrayStructure) and
self._dtype.is_compatible_with(other._dtype) and
self._element_shape.is_compatible_with(other._element_shape) and
self._dynamic_size == other._dynamic_size)
def _to_tensor_list(self, value):
if not isinstance(value, tensor_array_ops.TensorArray):
raise TypeError("value must be a TensorArray, but saw: {}"
.format(type(value)))
if value.flow is not None and value.flow.dtype == dtypes.variant:
return [value.flow]
else:
# Convert to a TF2-style TensorArray.
# TODO(ebrevdo): Add an "_as_variant" method to TensorArray class, or
# "implementation / as_variant" arg to TensorArray constructor.
with ops.name_scope("convert_tensor_array"):
flow = list_ops.tensor_list_from_tensor(
tensor=value.stack(), element_shape=value.element_shape)
return [flow]
def _to_batched_tensor_list(self, value):
raise NotImplementedError("TensorArrayStructure._to_batched_tensor_list")
def _from_tensor_list(self, flat_value):
if (len(flat_value) != 1 or flat_value[0].dtype != dtypes.variant or
not flat_value[0].shape.is_compatible_with(tensor_shape.scalar())):
raise ValueError("TensorArrayStructure corresponds to a single "
"tf.variant scalar.")
return self._from_compatible_tensor_list(flat_value)
def _from_compatible_tensor_list(self, flat_value):
# This will return a TF2 Graph-style TensorArray because flat_value[0] is
# a variant object. size == -1 implies unknown size.
ret = tensor_array_ops.TensorArray(
dtype=self._dtype,
flow=flat_value[0],
dynamic_size=self._dynamic_size,
infer_shape=self._infer_shape)
ret._element_shape = [self._element_shape]
return ret
@staticmethod
def from_value(value):
if not isinstance(value, tensor_array_ops.TensorArray):
raise TypeError("Expected value to be a TensorArray, but saw: {}".
format(type(value)))
return TensorArrayStructure(
dtype=value.dtype,
element_shape=value.element_shape,
dynamic_size=value.dynamic_size,
infer_shape=value._infer_shape)
def _to_legacy_output_types(self):
return self._dtype
def _to_legacy_output_shapes(self):
# Sneak the dynamic_size and infer_shape values into the legacy shape.
return (tensor_shape.matrix(self._dynamic_size, self._infer_shape)
.concatenate(self._element_shape))
def _to_legacy_output_classes(self):
return tensor_array_ops.TensorArray
def _batch(self, batch_size):
raise NotImplementedError("TensorArrayStructure._batch")
def _unbatch(self):
raise NotImplementedError("TensorArrayStructure._unbatch")
@tf_export("data.experimental.RaggedTensorStructure")
class RaggedTensorStructure(Structure):
"""Represents structural information about a `tf.RaggedTensor`."""
def __init__(self, dtype, shape, ragged_rank):
self._dtype = dtypes.as_dtype(dtype)
self._shape = tensor_shape.as_shape(shape)
self._ragged_rank = ragged_rank
def __eq__(self, other):
return (isinstance(other, RaggedTensorStructure) and tensor_spec.TensorSpec(
self._shape, self._dtype) == tensor_spec.TensorSpec(
other._shape, other._dtype) and
self._ragged_rank == other._ragged_rank)
def __hash__(self):
return hash((tensor_spec.TensorSpec(self._shape, self._dtype),
self._ragged_rank))
@property
def _flat_shapes(self):
# A list of shapes matching the shapes of `self._to_tensor_list()`.
# NOTE(mishragaurav): The default flat shape of a boxed `RaggedTensor` is
# `[]` (scalar), but a `RaggedTensorStructure` can also represent a batch of
# boxed `RaggedTensor` objects with shape `(?)` (and batches of batches,
# etc.), so the flat shape must be unknown.
return [tensor_shape.unknown_shape(None)]
@property
def _flat_types(self):
return [dtypes.variant]
def is_compatible_with(self, other):
return (isinstance(other, RaggedTensorStructure) and
self._dtype.is_compatible_with(other._dtype) and
self._shape.is_compatible_with(other._shape) and
self._ragged_rank == other._ragged_rank)
def _to_tensor_list(self, value):
return [value._to_variant()]
def _to_batched_tensor_list(self, value):
return [value._to_variant(batched_input=True)]
def _from_tensor_list(self, flat_value):
if (len(flat_value) != 1 or flat_value[0].dtype != dtypes.variant):
raise ValueError("RaggedTensorStructure corresponds to a single "
"tf.variant scalar.")
return self._from_compatible_tensor_list(flat_value)
def _from_compatible_tensor_list(self, flat_value):
if self._ragged_rank <= 0:
raise ValueError(
"ragged_rank must be greater than zero. Found ragged_rank: %d" %
self._ragged_rank)
result = ragged_tensor.RaggedTensor._from_variant(
flat_value[0], dtype=self._dtype, output_ragged_rank=self._ragged_rank)
if self._shape.ndims is not None:
outer_dim = tensor_shape.dimension_value(self._shape[0])
if outer_dim is not None:
result.row_splits.set_shape([outer_dim + 1])
result.flat_values.set_shape(
tensor_shape.TensorShape([None]).concatenate(
self._shape[1 + self._ragged_rank:]))
return result
@staticmethod
def from_value(value):
return RaggedTensorStructure(value.dtype, value.shape, value.ragged_rank)
def _to_legacy_output_types(self):
return self._dtype
def _to_legacy_output_shapes(self):
return self._shape
def _to_legacy_output_classes(self):
return self
def _batch(self, batch_size):
return RaggedTensorStructure(
self._dtype,
tensor_shape.TensorShape([batch_size]).concatenate(self._shape),
self._ragged_rank + 1)
def _unbatch(self):
# Note: Any ragged_rank is allowed here because the dataset could be
# subsequently batched again. Errors are handled in
# RaggedTensorStructure._from_compatible_tensor_list()
return RaggedTensorStructure(self._dtype, self._shape[1:],
self._ragged_rank - 1)
# Re-register SparseTensorValue -- it's a subclass of tuple, but we don't
# want the NestedStructure registration to take precedence.
type_spec.register_type_spec_from_value_converter(
sparse_tensor_lib.SparseTensorValue,
sparse_tensor_lib.SparseTensorSpec.from_value)

View File

@ -29,7 +29,9 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variables
@ -50,16 +52,16 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
# pylint: disable=g-long-lambda,protected-access
@parameterized.named_parameters(
("Tensor", lambda: constant_op.constant(37.0), structure.TensorStructure,
("Tensor", lambda: constant_op.constant(37.0), tensor_spec.TensorSpec,
[dtypes.float32], [[]]),
("TensorArray", lambda: tensor_array_ops.TensorArray(
dtype=dtypes.float32, element_shape=(3,), size=0),
structure.TensorArrayStructure, [dtypes.variant], [None, 3]),
tensor_array_ops.TensorArraySpec, [dtypes.variant], [[]]),
("SparseTensor", lambda: sparse_tensor.SparseTensor(
indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
structure.SparseTensorStructure, [dtypes.variant], [None]),
sparse_tensor.SparseTensorSpec, [dtypes.variant], [None]),
("RaggedTensor", lambda: ragged_factory_ops.constant([[1, 2], [], [4]]),
structure.RaggedTensorStructure, [dtypes.variant], [None]),
ragged_tensor.RaggedTensorSpec, [dtypes.variant], [None]),
("Nested_0",
lambda: (constant_op.constant(37.0), constant_op.constant([1, 2, 3])),
structure.NestedStructure, [dtypes.float32, dtypes.int32], [[], [3]]),
@ -80,13 +82,15 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
def testFlatStructure(self, value_fn, expected_structure, expected_types,
expected_shapes):
value = value_fn()
s = structure.Structure.from_value(value)
s = type_spec.type_spec_from_value(value)
self.assertIsInstance(s, expected_structure)
self.assertEqual(expected_types, s._flat_types)
self.assertLen(s._flat_shapes, len(expected_shapes))
for expected, actual in zip(expected_shapes, s._flat_shapes):
self.assertTrue(actual.is_compatible_with(expected))
self.assertTrue(
tensor_shape.as_shape(expected).is_compatible_with(actual))
if expected is None:
self.assertEqual(actual.ndims, None)
else:
self.assertEqual(actual.as_list(), expected)
@parameterized.named_parameters(
("Tensor", lambda: constant_op.constant(37.0), lambda: [
@ -162,15 +166,15 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
original_value = original_value_fn()
compatible_values = compatible_values_fn()
incompatible_values = incompatible_values_fn()
s = structure.Structure.from_value(original_value)
s = type_spec.type_spec_from_value(original_value)
for compatible_value in compatible_values:
self.assertTrue(
s.is_compatible_with(
structure.Structure.from_value(compatible_value)))
type_spec.type_spec_from_value(compatible_value)))
for incompatible_value in incompatible_values:
self.assertFalse(
s.is_compatible_with(
structure.Structure.from_value(incompatible_value)))
type_spec.type_spec_from_value(incompatible_value)))
@parameterized.named_parameters(
("Tensor",
@ -215,8 +219,8 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
def testStructureFromValueEquality(self, value1_fn, value2_fn,
*not_equal_value_fns):
# pylint: disable=g-generic-assert
s1 = structure.Structure.from_value(value1_fn())
s2 = structure.Structure.from_value(value2_fn())
s1 = type_spec.type_spec_from_value(value1_fn())
s2 = type_spec.type_spec_from_value(value2_fn())
self.assertEqual(s1, s1) # check __eq__ operator.
self.assertEqual(s1, s2) # check __eq__ operator.
self.assertFalse(s1 != s1) # check __ne__ operator.
@ -224,7 +228,7 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
self.assertEqual(hash(s1), hash(s1))
self.assertEqual(hash(s1), hash(s2))
for value_fn in not_equal_value_fns:
s3 = structure.Structure.from_value(value_fn())
s3 = type_spec.type_spec_from_value(value_fn())
self.assertNotEqual(s1, s3) # check __ne__ operator.
self.assertNotEqual(s2, s3) # check __ne__ operator.
self.assertFalse(s1 == s3) # check __eq_ operator.
@ -272,9 +276,9 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
}),
)
def testHash(self, value1_fn, value2_fn, value3_fn):
s1 = structure.Structure.from_value(value1_fn())
s2 = structure.Structure.from_value(value2_fn())
s3 = structure.Structure.from_value(value3_fn())
s1 = type_spec.type_spec_from_value(value1_fn())
s2 = type_spec.type_spec_from_value(value2_fn())
s3 = type_spec.type_spec_from_value(value3_fn())
self.assertEqual(hash(s1), hash(s1))
self.assertEqual(hash(s1), hash(s2))
self.assertNotEqual(hash(s1), hash(s3))
@ -314,7 +318,7 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
)
def testRoundTripConversion(self, value_fn):
value = value_fn()
s = structure.Structure.from_value(value)
s = type_spec.type_spec_from_value(value)
def maybe_stack_ta(v):
if isinstance(v, tensor_array_ops.TensorArray):
@ -344,7 +348,7 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
def preserveStaticShape(self):
rt = ragged_factory_ops.constant([[1, 2], [], [3]])
rt_s = structure.Structure.from_value(rt)
rt_s = type_spec.type_spec_from_value(rt)
rt_after = rt_s._from_tensor_list(rt_s._to_tensor_list(rt))
self.assertEqual(rt_after.row_splits.shape.as_list(),
rt.row_splits.shape.as_list())
@ -352,7 +356,7 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
st = sparse_tensor.SparseTensor(
indices=[[3, 4]], values=[-1], dense_shape=[4, 5])
st_s = structure.Structure.from_value(st)
st_s = type_spec.type_spec_from_value(st)
st_after = st_s._from_tensor_list(st_s._to_tensor_list(st))
self.assertEqual(st_after.indices.shape.as_list(),
[None, 2])
@ -367,19 +371,19 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
# 2. Using one structure to restructre a flattened value with an
# incompatible structure fails.
value_tensor = constant_op.constant(42.0)
s_tensor = structure.Structure.from_value(value_tensor)
s_tensor = type_spec.type_spec_from_value(value_tensor)
flat_tensor = s_tensor._to_tensor_list(value_tensor)
value_sparse_tensor = sparse_tensor.SparseTensor(
indices=[[0, 0]], values=[1], dense_shape=[1, 1])
s_sparse_tensor = structure.Structure.from_value(value_sparse_tensor)
s_sparse_tensor = type_spec.type_spec_from_value(value_sparse_tensor)
flat_sparse_tensor = s_sparse_tensor._to_tensor_list(value_sparse_tensor)
value_nest = {
"a": constant_op.constant(37.0),
"b": constant_op.constant([1, 2, 3])
}
s_nest = structure.Structure.from_value(value_nest)
s_nest = type_spec.type_spec_from_value(value_nest)
flat_nest = s_nest._to_tensor_list(value_nest)
with self.assertRaisesRegexp(
@ -391,38 +395,34 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
r"dtype.*float32.* and shape \(\)"):
s_tensor._to_tensor_list(value_nest)
with self.assertRaisesRegexp(TypeError, "Input must be a SparseTensor"):
with self.assertRaisesRegexp(
TypeError, "Neither a SparseTensor nor SparseTensorValue"):
s_sparse_tensor._to_tensor_list(value_tensor)
with self.assertRaisesRegexp(TypeError, "Input must be a SparseTensor"):
with self.assertRaisesRegexp(
TypeError, "Neither a SparseTensor nor SparseTensorValue"):
s_sparse_tensor._to_tensor_list(value_nest)
with self.assertRaisesRegexp(
ValueError, "Tensor.* not compatible with the nested structure "
".*TensorStructure.*TensorStructure"):
".*TensorSpec.*TensorSpec"):
s_nest._to_tensor_list(value_tensor)
with self.assertRaisesRegexp(
ValueError, "SparseTensor.* not compatible with the nested structure "
".*TensorStructure.*TensorStructure"):
".*TensorSpec.*TensorSpec"):
s_nest._to_tensor_list(value_sparse_tensor)
with self.assertRaisesRegexp(
ValueError, r"Cannot convert.*with dtype.*float32.* and shape \(\)"):
with self.assertRaisesRegexp(ValueError, r"Incompatible input:"):
s_tensor._from_tensor_list(flat_sparse_tensor)
with self.assertRaisesRegexp(
ValueError, "TensorStructure corresponds to a single tf.Tensor."):
with self.assertRaisesRegexp(ValueError, "Incompatible input: "):
s_tensor._from_tensor_list(flat_nest)
with self.assertRaisesRegexp(
ValueError, "SparseTensorStructure corresponds to a single tf.variant "
"vector of length 3."):
with self.assertRaisesRegexp(ValueError, "Incompatible input: "):
s_sparse_tensor._from_tensor_list(flat_tensor)
with self.assertRaisesRegexp(
ValueError, "SparseTensorStructure corresponds to a single tf.variant "
"vector of length 3."):
with self.assertRaisesRegexp(ValueError, "Incompatible input: "):
s_sparse_tensor._from_tensor_list(flat_nest)
with self.assertRaisesRegexp(
@ -445,7 +445,7 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
"a": constant_op.constant(37.0),
"b": constant_op.constant([1, 2, 3])
}
s_0 = structure.Structure.from_value(value_0)
s_0 = type_spec.type_spec_from_value(value_0)
flat_s_0 = s_0._to_tensor_list(value_0)
# `value_1` has compatible nested structure with `value_0`, but different
@ -457,7 +457,7 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
sparse_tensor.SparseTensor(
indices=[[0, 0]], values=[1], dense_shape=[1, 1])
}
s_1 = structure.Structure.from_value(value_1)
s_1 = type_spec.type_spec_from_value(value_1)
flat_s_1 = s_1._to_tensor_list(value_1)
# `value_2` has incompatible nested structure with `value_0` and `value_1`.
@ -469,27 +469,27 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
sparse_tensor.SparseTensor(
indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
}
s_2 = structure.Structure.from_value(value_2)
s_2 = type_spec.type_spec_from_value(value_2)
flat_s_2 = s_2._to_tensor_list(value_2)
with self.assertRaisesRegexp(
ValueError, "SparseTensor.* not compatible with the nested structure "
".*TensorStructure"):
ValueError, ".*SparseTensor.* not compatible with the nested structure "
".*TensorSpec"):
s_0._to_tensor_list(value_1)
with self.assertRaisesRegexp(
ValueError, "SparseTensor.*SparseTensor.* not compatible with the "
"nested structure .*TensorStructure"):
ValueError, ".*SparseTensor.*SparseTensor.* not compatible with the "
"nested structure .*TensorSpec"):
s_0._to_tensor_list(value_2)
with self.assertRaisesRegexp(
ValueError, "Tensor.* not compatible with the nested structure "
".*SparseTensorStructure"):
ValueError, ".*Tensor.* not compatible with the nested structure "
".*SparseTensorSpec"):
s_1._to_tensor_list(value_0)
with self.assertRaisesRegexp(
ValueError, "SparseTensor.*SparseTensor.* not compatible with the "
"nested structure .*TensorStructure"):
ValueError, ".*SparseTensor.*SparseTensor.* not compatible with the "
"nested structure .*TensorSpec"):
s_0._to_tensor_list(value_2)
# NOTE(mrry): The repr of the dictionaries is not sorted, so the regexp
@ -497,29 +497,27 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
# adding a deterministic repr for these error messages (among other
# improvements).
with self.assertRaisesRegexp(
ValueError, "Tensor.*Tensor.* not compatible with the nested structure "
".*(TensorStructure.*SparseTensorStructure.*SparseTensorStructure|"
"SparseTensorStructure.*SparseTensorStructure.*TensorStructure)"):
ValueError,
".*Tensor.*Tensor.* not compatible with the nested structure "
".*(TensorSpec.*SparseTensorSpec.*SparseTensorSpec|"
"SparseTensorSpec.*SparseTensorSpec.*TensorSpec)"):
s_2._to_tensor_list(value_0)
with self.assertRaisesRegexp(
ValueError, "(Tensor.*SparseTensor|SparseTensor.*Tensor).* "
"not compatible with the nested structure .*"
"(TensorStructure.*SparseTensorStructure.*SparseTensorStructure|"
"SparseTensorStructure.*SparseTensorStructure.*TensorStructure)"):
"(TensorSpec.*SparseTensorSpec.*SparseTensorSpec|"
"SparseTensorSpec.*SparseTensorSpec.*TensorSpec)"):
s_2._to_tensor_list(value_1)
with self.assertRaisesRegexp(
ValueError, r"Cannot convert.*with dtype.*int32.* and shape \(3,\)"):
with self.assertRaisesRegexp(ValueError, r"Incompatible input:"):
s_0._from_tensor_list(flat_s_1)
with self.assertRaisesRegexp(
ValueError, "Expected 2 flat values in NestedStructure but got 3."):
s_0._from_tensor_list(flat_s_2)
with self.assertRaisesRegexp(
ValueError, "SparseTensorStructure corresponds to a single tf.variant "
"vector of length 3."):
with self.assertRaisesRegexp(ValueError, "Incompatible input: "):
s_1._from_tensor_list(flat_s_0)
with self.assertRaisesRegexp(
@ -552,9 +550,9 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
[True, False, 2, 2]), tensor_array_ops.TensorArray,
structure.TensorArrayStructure(
dtypes.int32, [2, 2], dynamic_size=True, infer_shape=False)),
("RaggedTensor", dtypes.int32, tensor_shape.matrix(2, 2),
structure.RaggedTensorStructure(dtypes.int32, [2, 2], 1),
structure.RaggedTensorStructure(dtypes.int32, [2, 2], 1)),
("RaggedTensor", dtypes.int32, tensor_shape.matrix(2, None),
structure.RaggedTensorStructure(dtypes.int32, [2, None], 1),
structure.RaggedTensorStructure(dtypes.int32, [2, None], 1)),
("Nested", {
"a": dtypes.float32,
"b": (dtypes.int32, dtypes.string)
@ -576,8 +574,7 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
output_classes, expected_structure):
actual_structure = structure.convert_legacy_structure(
output_types, output_shapes, output_classes)
self.assertTrue(expected_structure.is_compatible_with(actual_structure))
self.assertTrue(actual_structure.is_compatible_with(expected_structure))
self.assertEqual(actual_structure, expected_structure)
def testNestedNestedStructure(self):
# Although `Structure.from_value()` will not construct one, a nested
@ -639,10 +636,7 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
def testBatch(self, element_structure, batch_size,
expected_batched_structure):
batched_structure = element_structure._batch(batch_size)
self.assertTrue(
batched_structure.is_compatible_with(expected_batched_structure))
self.assertTrue(
expected_batched_structure.is_compatible_with(batched_structure))
self.assertEqual(batched_structure, expected_batched_structure)
@parameterized.named_parameters(
("Tensor", structure.TensorStructure(dtypes.float32, [32]),
@ -656,11 +650,11 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
structure.SparseTensorStructure(dtypes.float32, [None, 4]),
structure.SparseTensorStructure(dtypes.float32, [4])),
("RaggedTensor",
structure.RaggedTensorStructure(dtypes.float32, [32, 4, None], 2),
structure.RaggedTensorStructure(dtypes.float32, [4, None], 1)),
structure.RaggedTensorStructure(dtypes.float32, [32, None, None], 2),
structure.RaggedTensorStructure(dtypes.float32, [None, None], 1)),
("RaggedTensorUnknown",
structure.RaggedTensorStructure(dtypes.float32, [None, None, 4], 2),
structure.RaggedTensorStructure(dtypes.float32, [None, 4], 1)),
structure.RaggedTensorStructure(dtypes.float32, [None, None, None], 2),
structure.RaggedTensorStructure(dtypes.float32, [None, None], 1)),
("Nested", structure.NestedStructure({
"a": structure.TensorStructure(dtypes.float32, [128]),
"b": (structure.SparseTensorStructure(dtypes.int32, [128, 2, 2]),
@ -672,10 +666,7 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
)
def testUnbatch(self, element_structure, expected_unbatched_structure):
unbatched_structure = element_structure._unbatch()
self.assertTrue(
unbatched_structure.is_compatible_with(expected_unbatched_structure))
self.assertTrue(
expected_unbatched_structure.is_compatible_with(unbatched_structure))
self.assertEqual(unbatched_structure, expected_unbatched_structure)
# pylint: disable=g-long-lambda
@parameterized.named_parameters(
@ -697,7 +688,7 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
)
def testToBatchedTensorList(self, value_fn, element_0_fn):
batched_value = value_fn()
s = structure.Structure.from_value(batched_value)
s = type_spec.type_spec_from_value(batched_value)
batched_tensor_list = s._to_batched_tensor_list(batched_value)
# The batch dimension is 2 for all of the test cases.

View File

@ -29,6 +29,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_conversion_registry
from tensorflow.python.framework import tensor_like
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import type_spec
from tensorflow.python.util.lazy_loader import LazyLoader
from tensorflow.python.util.tf_export import tf_export
@ -175,6 +176,66 @@ IndexedSlicesValue = collections.namedtuple(
"IndexedSlicesValue", ["values", "indices", "dense_shape"])
# TODO(b/133606651) Export this as tf.IndexedSlicesSpec.
class IndexedSlicesSpec(type_spec.TypeSpec):
"""Type specification for a `tf.IndexedSlices`."""
__slots__ = ["_shape", "_values_dtype", "_indices_dtype",
"_dense_shape_dtype", "_indices_shape"]
value_type = property(lambda self: IndexedSlices)
def __init__(self, shape=None, dtype=dtypes.float32,
indices_dtype=dtypes.int64, dense_shape_dtype=True,
indices_shape=None):
"""Constructs a type specification for a `tf.IndexedSlices`.
Args:
shape: The dense shape of the `IndexedSlices`, or `None` to allow any
dense shape.
dtype: `tf.DType` of values in the `IndexedSlices`.
indices_dtype: `tf.DType` of the `indices` in the `IndexedSlices`. One
of `tf.int32` or `tf.int64`.
dense_shape_dtype: `tf.DType` of the `dense_shape` in the `IndexedSlices`.
One of `tf.int32`, `tf.int64`, or `None` (if the `IndexedSlices` has
no `dense_shape` tensor).
indices_shape: The shape of the `indices` component, which indicates
how many slices are in the `IndexedSlices`.
"""
self._shape = tensor_shape.as_shape(shape)
self._values_dtype = dtypes.as_dtype(dtype)
self._indices_dtype = dtypes.as_dtype(indices_dtype)
if dense_shape_dtype is None:
self._dense_shape_dtype = None
else:
self._dense_shape_dtype = dtypes.as_dtype(dense_shape_dtype)
self._indices_shape = tensor_shape.as_shape(indices_shape)
def _serialize(self):
return (self._shape, self._values_dtype, self._indices_dtype,
self._dense_shape_dtype, self._indices_shape)
@property
def _component_specs(self):
value_shape = self._indices_shape.concatenate(self._shape[1:])
specs = [
tensor_spec.TensorSpec(value_shape, self._values_dtype),
tensor_spec.TensorSpec(self._indices_shape, self._indices_dtype)]
if self._dense_shape_dtype is not None:
specs.append(
tensor_spec.TensorSpec([self._shape.ndims], self._dense_shape_dtype))
return specs
def _to_components(self, value):
if value.dense_shape is None:
return (value.values, value.indices)
else:
return (value.values, value.indices, value.dense_shape)
def _from_components(self, tensor_list):
return IndexedSlices(*tensor_list)
@tf_export(v1=["convert_to_tensor_or_indexed_slices"])
def convert_to_tensor_or_indexed_slices(value, dtype=None, name=None):
"""Converts the given object to a `Tensor` or an `IndexedSlices`.

View File

@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import collections
import numpy as np
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import composite_tensor
@ -28,6 +29,8 @@ from tensorflow.python.framework import tensor_like
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import gen_sparse_ops
from tensorflow.python.util.tf_export import tf_export
# pylint: disable=protected-access
@ -117,6 +120,7 @@ class SparseTensor(_TensorLike, composite_tensor.CompositeTensor):
dense_shape: A 1-D int64 tensor of shape `[ndims]`.
"""
if isinstance(indices, tensor_spec.TensorSpec):
# TODO(b/133606651) Remove this code path -- replaced by TypeSpec.
if not isinstance(values, tensor_spec.TensorSpec):
raise TypeError("Expected values to be a TensorSpec")
if not isinstance(dense_shape, tensor_spec.TensorSpec):
@ -271,6 +275,120 @@ tf_export(v1=["SparseTensorValue"])(SparseTensorValue)
pywrap_tensorflow.RegisterType("SparseTensorValue", SparseTensorValue)
# TODO(b/133606651) Export this as tf.SparseTensorSpec.
class SparseTensorSpec(type_spec.BatchableTypeSpec):
"""Type specification for a `tf.SparseTensor`."""
__slots__ = ["_dense_shape", "_dtype"]
value_type = property(lambda self: SparseTensor)
def __init__(self, dense_shape=None, dtype=dtypes.float32):
"""Constructs a type specification for a `tf.SparseTensor`.
Args:
dense_shape: The dense shape of the `SparseTensor`, or `None` to allow
any dense shape.
dtype: `tf.DType` of values in the `SparseTensor`.
"""
self._dense_shape = tensor_shape.as_shape(dense_shape)
self._dtype = dtypes.as_dtype(dtype)
def _serialize(self):
return (self._dense_shape, self._dtype)
@property
def _component_specs(self):
rank = self._dense_shape.ndims
num_values = None
return [
tensor_spec.TensorSpec([num_values, rank], dtypes.int64),
tensor_spec.TensorSpec([num_values], self._dtype),
tensor_spec.TensorSpec([rank], dtypes.int64)]
def _to_components(self, value):
if isinstance(value, SparseTensorValue):
value = SparseTensor.from_value(value)
return [value.indices, value.values, value.dense_shape]
def _from_components(self, tensor_list):
return SparseTensor(*tensor_list)
# The SparseTensorSpec tensor_list encoding uses (de)serialize_sparse ops
# to (un)box the component tensors in a way that allows for batching &
# unbatching.
@property
def _flat_tensor_specs(self):
# NOTE(mrry): The default flat shape of a boxed `SparseTensor` is `(3,)`,
# but a `SparseTensorSpec` can also represent a batch of boxed
# `SparseTensor` objects with shape `(..., 3)` (and batches of batches,
# etc.), so the flat shape must be unknown.
return [tensor_spec.TensorSpec(None, dtypes.variant)]
def _to_tensor_list(self, value):
value = SparseTensor.from_value(value)
return [gen_sparse_ops.serialize_sparse(
value.indices, value.values, value.dense_shape,
out_type=dtypes.variant)]
def _to_batched_tensor_list(self, value):
dense_shape = tensor_util.constant_value_as_shape(value.dense_shape)
if self._dense_shape.merge_with(dense_shape).ndims == 0:
raise ValueError(
"Unbatching a sparse tensor is only supported for rank >= 1")
return [gen_sparse_ops.serialize_many_sparse(
value.indices, value.values, value.dense_shape,
out_type=dtypes.variant)]
def _from_compatible_tensor_list(self, tensor_list):
tensor_list = gen_sparse_ops.deserialize_sparse(tensor_list[0], self._dtype)
result = SparseTensor(*tensor_list)
rank = self._dense_shape.ndims
result.indices.set_shape([None, rank])
result.dense_shape.set_shape([rank])
return result
def _batch(self, batch_size):
return SparseTensorSpec(
tensor_shape.TensorShape([batch_size]).concatenate(self._dense_shape),
self._dtype)
def _unbatch(self):
if self._dense_shape.ndims == 0:
raise ValueError("Unbatching a tensor is only supported for rank >= 1")
return SparseTensorSpec(self._dense_shape[1:], self._dtype)
def _to_legacy_output_types(self):
return self._dtype
def _to_legacy_output_shapes(self):
return self._dense_shape
def _to_legacy_output_classes(self):
return SparseTensor
@classmethod
def from_value(cls, value):
if isinstance(value, SparseTensor):
return cls(value.shape, value.dtype)
if isinstance(value, SparseTensorValue):
if isinstance(value.values, np.ndarray):
return cls(value.dense_shape, value.values.dtype)
else:
return cls.from_value(SparseTensor.from_value(value))
else:
raise TypeError("Expected SparseTensor or SparseTensorValue")
# TODO(b/133606651) Delete the SparseTensor registration when CompositeTensor
# is updated to define a _type_spec field (since registration will be
# automatic). Do *not* delete the SparseTensorValue registration.
type_spec.register_type_spec_from_value_converter(
SparseTensor, SparseTensorSpec.from_value)
type_spec.register_type_spec_from_value_converter(
SparseTensorValue, SparseTensorSpec.from_value)
@tf_export(v1=["convert_to_tensor_or_sparse_tensor"])
def convert_to_tensor_or_sparse_tensor(value, dtype=None, name=None):
"""Converts value to a `SparseTensor` or `Tensor`.

View File

@ -20,15 +20,17 @@ from __future__ import print_function
import numpy as np
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import type_spec
from tensorflow.python.util.tf_export import tf_export
@tf_export("TensorSpec")
class TensorSpec(object):
class TensorSpec(type_spec.BatchableTypeSpec):
"""Describes a tf.Tensor.
Metadata for describing the `tf.Tensor` objects accepted or returned
@ -97,7 +99,8 @@ class TensorSpec(object):
Returns:
True if spec_or_tensor is compatible with self.
"""
return (self._dtype.is_compatible_with(spec_or_tensor.dtype) and
return (isinstance(spec_or_tensor, (TensorSpec, ops.Tensor)) and
self._dtype.is_compatible_with(spec_or_tensor.dtype) and
self._shape.is_compatible_with(spec_or_tensor.shape))
def __repr__(self):
@ -108,17 +111,80 @@ class TensorSpec(object):
return hash((self._shape_tuple, self.dtype))
def __eq__(self, other):
return (self._shape_tuple == other._shape_tuple # pylint: disable=protected-access
and self.dtype == other.dtype
and self._name == other._name) # pylint: disable=protected-access
# pylint: disable=protected-access
return (type(self) is type(other) and
self._shape_tuple == other._shape_tuple
and self._dtype == other._dtype
and self._name == other._name)
def __ne__(self, other):
return not self == other
def __reduce__(self):
return TensorSpec, (self._shape, self._dtype, self._name)
value_type = property(lambda self: ops.Tensor)
def most_specific_compatible_type(self, other):
if (type(self) is not type(other)) or (self._dtype != other.dtype):
raise ValueError("Types are not compatible: %r vs %r" % (self, other))
shape = self._shape.most_specific_compatible_shape(other.shape)
name = self._name if self._name == other.name else None
return TensorSpec(shape, self._dtype, name)
def _serialize(self):
return (self._shape, self._dtype, self._name)
_component_specs = property(lambda self: self)
def _to_components(self, value):
try:
value = ops.convert_to_tensor(value, self._dtype)
except (TypeError, ValueError):
raise ValueError("Value %r is not convertible to a tensor with dtype %s "
"and shape %s." % (value, self._dtype, self._shape))
if not value.shape.is_compatible_with(self._shape):
raise ValueError("Value %r is not convertible to a tensor with dtype %s "
"and shape %s." % (value, self._dtype, self._shape))
return value
def _from_components(self, components):
return components
def _from_compatible_tensor_list(self, tensor_list):
# TODO(b/112266545): It would be cleaner to create a new `ensure_shape()`
# op here and return that, instead of mutating the input's shape using
# `Tensor.set_shape()`. However, that would add extra ops, which could
# impact performance. When this bug is resolved, we should be able to add
# the `ensure_shape()` ops and optimize them away using contextual shape
# information.
assert len(tensor_list) == 1
tensor_list[0].set_shape(self._shape)
return tensor_list[0]
def _to_batchable_tensor_list(self, value, batched=False):
if batched and self._shape.merge_with(value.shape).ndims == 0:
raise ValueError("Unbatching a tensor is only supported for rank >= 1")
return self._to_components(value)
def _batch(self, batch_size):
return TensorSpec(
tensor_shape.TensorShape([batch_size]).concatenate(self._shape),
self._dtype)
def _unbatch(self):
if self._shape.ndims == 0:
raise ValueError("Unbatching a tensor is only supported for rank >= 1")
return TensorSpec(self._shape[1:], self._dtype)
def _to_legacy_output_types(self):
return self._dtype
def _to_legacy_output_shapes(self):
return self._shape
def _to_legacy_output_classes(self):
return ops.Tensor
# TODO(b/133606651): Should is_compatible_with should check min/max bounds?
class BoundedTensorSpec(TensorSpec):
"""A `TensorSpec` that specifies minimum and maximum values.
@ -216,3 +282,15 @@ class BoundedTensorSpec(TensorSpec):
def __reduce__(self):
return BoundedTensorSpec, (self._shape, self._dtype, self._minimum,
self._maximum, self._name)
def _serialize(self):
return (self._shape, self._dtype, self._minimum, self._maximum, self._name)
pywrap_tensorflow.RegisterType("TensorSpec", TensorSpec)
# Note: we do not include Tensor names when constructing TypeSpecs.
type_spec.register_type_spec_from_value_converter(
ops.Tensor,
lambda tensor: TensorSpec(tensor.shape, tensor.dtype))

View File

@ -24,14 +24,21 @@ import six
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util.lazy_loader import LazyLoader
from tensorflow.python.util.tf_export import tf_export
# Use LazyLoader to avoid circular dependencies.
tensor_spec = LazyLoader(
"tensor_spec", globals(),
"tensorflow.python.framework.tensor_spec")
ops = LazyLoader(
"ops", globals(),
"tensorflow.python.framework.ops")
# TODO(b/133606651) Export this as "TypeSpec" (or experimental.TypeSpec?) and
# deprecate the tf.data.experimental.Structure endpoint.
@ -81,7 +88,7 @@ class TypeSpec(object):
# compatibility using their `is_comptaible_with` method; and all other
# types are considered compatible if they are equal).
if not isinstance(spec_or_value, TypeSpec):
spec_or_value = TypeSpec.from_value(spec_or_value)
spec_or_value = type_spec_from_value(spec_or_value)
if type(self) is not type(spec_or_value):
return False
return self.__is_compatible(self._serialize(),
@ -111,44 +118,6 @@ class TypeSpec(object):
self._serialize(), other._serialize()) # pylint: disable=protected-access
return self._deserialize(merged)
@staticmethod
def from_value(value):
"""Returns a `TypeSpec` that represents the given `value`.
Args:
value: A value that can be accepted or returned by TensorFlow APIs.
Returns:
A `TypeSpec` that is compatible with `value`.
Raises:
TypeError: If a TypeSpec cannot be built for `value`, because its type
is not supported.
"""
if isinstance(value, ops.Tensor):
# Note: we do not include Tensor names when constructing TypeSpecs.
return tensor_spec.TensorSpec(value.shape, value.dtype)
# TODO(b/133606651) Uncomment the following two lines when CompositeTensor
# is updated to define a _type_spec field:
#
# if isinstance(value, composite_tensor.CompositeTensor):
# return value._type_spec # pylint: disable=protected-access
for entry in reversed(TypeSpec._TYPE_CONVERSION_FUNCTION_REGISTRY):
type_object, converter_fn, allow_subclass = entry
if ((type(value) is type_object) or # pylint: disable=unidiomatic-typecheck
(allow_subclass and isinstance(value, type_object))):
return converter_fn(value)
# Fallback: try converting value to a tensor.
try:
tensor = ops.convert_to_tensor(value)
except (ValueError, TypeError):
raise TypeError("Could not build a TypeSpec for %r with type %s" %
(value, type(value).__name__))
return TypeSpec.from_value(tensor)
# === Component encoding for values ===
@abc.abstractmethod
@ -331,7 +300,7 @@ class TypeSpec(object):
def __check_tensor_list(self, tensor_list):
expected = self._flat_tensor_specs
specs = [tensor_spec.TensorSpec.from_tensor(t) for t in tensor_list]
specs = [type_spec_from_value(t) for t in tensor_list]
if len(specs) != len(expected):
raise ValueError("Incompatible input: wrong number of tensors")
for i, (s1, s2) in enumerate(zip(specs, expected)):
@ -346,8 +315,7 @@ class TypeSpec(object):
def __make_cmp_key(self, value):
"""Converts `value` to a hashable key."""
if isinstance(value, (int, float, bool, dtypes.DType, TypeSpec,
tensor_spec.TensorSpec)):
if isinstance(value, (int, float, bool, dtypes.DType, TypeSpec)):
return value
elif isinstance(value, compat.bytes_or_text_types):
return value
@ -386,8 +354,7 @@ class TypeSpec(object):
if isinstance(a, tuple):
return (len(a) == len(b) and
all(TypeSpec.__is_compatible(x, y) for (x, y) in zip(a, b)))
elif isinstance(a, (TypeSpec, tensor_shape.TensorShape, dtypes.DType,
tensor_spec.TensorSpec)):
elif isinstance(a, (TypeSpec, tensor_shape.TensorShape, dtypes.DType)):
return a.is_compatible_with(b)
else:
return a == b
@ -428,12 +395,6 @@ class TypeSpec(object):
for (x, y) in zip(a, b))
elif isinstance(a, tensor_shape.TensorShape):
return a.most_specific_compatible_shape(b)
elif isinstance(a, tensor_spec.TensorSpec):
if a.dtype != b.dtype:
raise ValueError("Types are not compatible: %r vs %r" % (a, b))
shape = a.shape.most_specific_compatible_shape(b.shape)
name = a.name if a.name == b.name else None
return tensor_spec.TensorSpec(shape, a.dtype, name)
elif isinstance(a, list):
raise AssertionError("_serialize() should not return list values.")
elif isinstance(a, TypeSpec):
@ -491,6 +452,55 @@ class BatchableTypeSpec(TypeSpec):
return tensor_list
def type_spec_from_value(value):
"""Returns a `TypeSpec` that represents the given `value`.
Args:
value: A value that can be accepted or returned by TensorFlow APIs.
Returns:
A `TypeSpec` that is compatible with `value`.
Raises:
TypeError: If a TypeSpec cannot be built for `value`, because its type
is not supported.
"""
spec = _type_spec_from_value(value)
if spec is not None:
return spec
# Fallback: try converting value to a tensor.
try:
tensor = ops.convert_to_tensor(value)
spec = _type_spec_from_value(tensor)
if spec is not None:
return spec
except (ValueError, TypeError):
pass
raise TypeError("Could not build a TypeSpec for %r with type %s" %
(value, type(value).__name__))
def _type_spec_from_value(value):
"""Returns a `TypeSpec` that represents the given `value`."""
if isinstance(value, ops.Tensor):
# Note: we do not include Tensor names when constructing TypeSpecs.
return tensor_spec.TensorSpec(value.shape, value.dtype)
# TODO(b/133606651) Uncomment the following two lines when CompositeTensor
# is updated to define a _type_spec field:
#
# if isinstance(value, composite_tensor.CompositeTensor):
# return value._type_spec # pylint: disable=protected-access
for entry in reversed(_TYPE_CONVERSION_FUNCTION_REGISTRY):
type_object, converter_fn, allow_subclass = entry
if ((type(value) is type_object) or # pylint: disable=unidiomatic-typecheck
(allow_subclass and isinstance(value, type_object))):
return converter_fn(value)
_TYPE_CONVERSION_FUNCTION_REGISTRY = []

View File

@ -278,5 +278,10 @@ class TypeSpecTest(test_util.TensorFlowTestCase, parameterized.TestCase):
(tensor_shape.TensorShape([5, 3]), dtypes.int32,
tensor_shape.TensorShape(None), dtypes.bool, "red"))
def testFromValue(self):
value = TwoTensors([1, 2, 3], [1.0, 2.0], "red")
spec = type_spec.type_spec_from_value(value)
self.assertEqual(spec, TwoTensorsSpec.from_value(value))
if __name__ == "__main__":
googletest.main()

View File

@ -695,6 +695,29 @@ class LSTMV2Test(keras_parameterized.TestCase):
},
input_shape=(num_samples, timesteps, embedding_dim))
def test_bidirectional(self):
batch = 128
timestep = 20
vocab_size = 1000
model = keras.Sequential([
keras.layers.Embedding(vocab_size, 64),
keras.layers.Bidirectional(rnn.LSTM(
64, return_sequences=True)),
keras.layers.Bidirectional(rnn.LSTM(32)),
keras.layers.Dense(64, activation='relu'),
keras.layers.Dense(1, activation='sigmoid')
])
model.compile(loss='binary_crossentropy',
optimizer='adam',
metrics=['accuracy'])
x = np.random.randint(0, vocab_size, size=(batch, timestep))
y = np.random.randint(0, 1, size=(batch))
model.fit(x, y, epochs=1, shuffle=False)
model.evaluate(x, y)
model.predict(x)
class LSTMLayerGraphOnlyTest(test.TestCase):
@ -803,7 +826,7 @@ class LSTMLayerGraphOnlyTest(test.TestCase):
existing_loss = loss_value
class UnifiedLSTMPerformanceTest(test.Benchmark):
class UnifiedLSTMPerformanceTest(test.TestCase):
def _measure_performance(self, test_config, model, x_train, y_train):
batch = test_config['batch']
@ -913,29 +936,29 @@ class UnifiedLSTMPerformanceTest(test.Benchmark):
cudnn_vs_unified = cudnn_sec_per_epoch / unified_lstm_sec_per_epoch
unified_vs_normal = normal_lstm_sec_per_epoch / unified_lstm_sec_per_epoch
self.report_benchmark(name='keras_cudnn_lstm_' + mode,
wall_time=cudnn_sec_per_epoch,
iters=test_config['epoch'],
extras=test_config)
self.report_benchmark(name='keras_unified_lstm_' + mode,
wall_time=unified_lstm_sec_per_epoch,
iters=test_config['epoch'],
extras=test_config)
self.report_benchmark(name='keras_canonical_lstm_' + mode,
wall_time=normal_lstm_sec_per_epoch,
iters=test_config['epoch'],
extras=test_config)
# self.report_benchmark(name='keras_cudnn_lstm_' + mode,
# wall_time=cudnn_sec_per_epoch,
# iters=test_config['epoch'],
# extras=test_config)
# self.report_benchmark(name='keras_unified_lstm_' + mode,
# wall_time=unified_lstm_sec_per_epoch,
# iters=test_config['epoch'],
# extras=test_config)
# self.report_benchmark(name='keras_canonical_lstm_' + mode,
# wall_time=normal_lstm_sec_per_epoch,
# iters=test_config['epoch'],
# extras=test_config)
logging.info('Expect the performance of Unified LSTM is within 80% of '
'CuDNN LSTM, got {0:.2f}%'.format(cudnn_vs_unified * 100))
logging.info('Expect the performance of Unified LSTM is more than 5 times'
' of normal LSTM, got {0:.2f}'.format(unified_vs_normal))
def benchmark_performance_graph(self):
def test_performance_graph(self):
with context.graph_mode(), session_lib.Session(config=_config):
self._benchmark_performance_with_standard_cudnn_impl()
def benchmark_performance_eager(self):
def test_performance_eager(self):
with context.eager_mode():
self._benchmark_performance_with_standard_cudnn_impl()

View File

@ -398,28 +398,23 @@ class GRU(recurrent.DropoutRNNCellMixin, recurrent.GRU):
else:
last_output, outputs, new_h, runtime = standard_gru(**normal_gru_kwargs)
else:
api_name = 'gru_' + str(uuid.uuid4())
defun_standard_gru = _generate_defun_backend(
api_name, _CPU_DEVICE_NAME, standard_gru)
defun_cudnn_gru = _generate_defun_backend(
api_name, _GPU_DEVICE_NAME, cudnn_gru)
# Call the normal GRU impl and register the CuDNN impl function. The
# grappler will kick in during session execution to optimize the graph.
last_output, outputs, new_h, runtime = defun_standard_gru(
**normal_gru_kwargs)
def register_cudnn_defun():
function.register(defun_cudnn_gru, **cudnn_gru_kwargs)
# return some dummy value since the tf.cond require some return value.
return 0
if mask is None:
register_cudnn_defun()
last_output, outputs, new_h, runtime = gru_with_backend_selection(
normal_gru_kwargs, cudnn_gru_kwargs)
else:
# Only when seq_right_padded=True, CuDNN kernel can support that
# properly.
control_flow_ops.cond(is_sequence_right_padded(mask, self.time_major),
true_fn=register_cudnn_defun,
false_fn=lambda: 0)
def with_mask_support():
# TODO(b/134702514): Change to use backend selection.
# return gru_with_backend_selection(normal_gru_kwargs,
# cudnn_gru_kwargs)
return standard_gru(**normal_gru_kwargs)
def without_mask_support():
return standard_gru(**normal_gru_kwargs)
last_output, outputs, new_h, runtime = control_flow_ops.cond(
is_sequence_right_padded(mask, self.time_major),
true_fn=with_mask_support,
false_fn=without_mask_support)
states = [new_h]
return last_output, outputs, runtime, states
@ -532,22 +527,20 @@ def cudnn_gru(inputs, init_h, kernel, recurrent_kernel, bias, mask, time_major,
if mask is not None:
sequence_length = calculate_sequence_by_mask(mask, time_major)
if go_backwards:
inputs = array_ops.reverse_sequence_v2(inputs, sequence_length,
seq_axis=0, batch_axis=1)
outputs, h, _, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv3(
inputs, input_h=init_h, input_c=0, params=params, is_training=True,
rnn_mode='gru', sequence_lengths=sequence_length)
else:
# Fill the array with shape [batch] with value of max timesteps.
sequence_length = array_ops.fill([array_ops.shape(inputs)[1]],
array_ops.shape(inputs)[0])
if go_backwards:
inputs = array_ops.reverse_sequence_v2(inputs, sequence_length, seq_axis=0,
batch_axis=1)
if go_backwards:
# Reverse axis 0 since the input is already convert to time major.
inputs = array_ops.reverse(inputs, axis=[0])
outputs, h, _, _ = gen_cudnn_rnn_ops.cudnn_rnn(
inputs, input_h=init_h, input_c=0, params=params, is_training=True,
rnn_mode='gru')
outputs, h, _, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv3(
inputs,
input_h=init_h,
input_c=0,
params=params,
is_training=True,
rnn_mode='gru',
sequence_lengths=sequence_length)
last_output = outputs[-1]
if not time_major:
outputs = array_ops.transpose(outputs, perm=[1, 0, 2])
@ -565,6 +558,44 @@ def cudnn_gru(inputs, init_h, kernel, recurrent_kernel, bias, mask, time_major,
return last_output, outputs, h, _runtime(_RUNTIME_GPU)
def gru_with_backend_selection(normal_gru_params, cudnn_gru_params):
"""Call the GRU with optimized backend kernel selection.
Under the hood, this function will create two TF function, one with the most
generic kernel and can run on all device condition, and the second one with
CuDNN specific kernel, which can only run on GPU.
The first function will be called with normal_lstm_params, while the second
function is not called, but only registered in the graph. The Grappler will
do the proper graph rewrite and swap the optimized TF function based on the
device placement.
Args:
normal_gru_params: Dict, parameters for the generic TF function.
cudnn_gru_params: Dict, parameters for the CuDNN specific TF function.
Returns:
List of output tensors, same as standard_gru.
"""
# Each time a `tf.function` is called, we will give it a unique
# identifiable API name, so that Grappler won't get confused when it
# sees multiple GRU layers added into same graph, and it will be able
# to pair up the different implementations across them.
api_name = 'gru_' + str(uuid.uuid4())
defun_standard_gru = _generate_defun_backend(
api_name, _CPU_DEVICE_NAME, standard_gru)
defun_cudnn_gru = _generate_defun_backend(
api_name, _GPU_DEVICE_NAME, cudnn_gru)
# Call the normal GRU impl and register the CuDNN impl function. The
# grappler will kick in during session execution to optimize the graph.
last_output, outputs, new_h, runtime = defun_standard_gru(
**normal_gru_params)
function.register(defun_cudnn_gru, **cudnn_gru_params)
return last_output, outputs, new_h, runtime
@keras_export('keras.layers.LSTMCell', v1=[])
class LSTMCell(recurrent.LSTMCell):
"""Cell class for the LSTM layer.
@ -876,33 +907,25 @@ class LSTM(recurrent.DropoutRNNCellMixin, recurrent.LSTM):
last_output, outputs, new_h, new_c, runtime = standard_lstm(
**normal_lstm_kwargs)
else:
# Each time a `tf.function` is called, we will give it a unique
# identifiable API name, so that Grappler won't get confused when it
# sees multiple LSTM layers added into same graph, and it will be able
# to pair up the different implementations across them.
api_name = 'lstm_' + str(uuid.uuid4())
defun_standard_lstm = _generate_defun_backend(
api_name, _CPU_DEVICE_NAME, standard_lstm)
defun_cudnn_lstm = _generate_defun_backend(
api_name, _GPU_DEVICE_NAME, cudnn_lstm)
# Call the normal LSTM impl and register the CuDNN impl function. The
# grappler will kick in during session execution to optimize the graph.
last_output, outputs, new_h, new_c, runtime = defun_standard_lstm(
**normal_lstm_kwargs)
def register_cudnn_defun():
function.register(defun_cudnn_lstm, **cudnn_lstm_kwargs)
# return some dummy value since the tf.cond require some return value.
return 0
if mask is None:
register_cudnn_defun()
(last_output, outputs,
new_h, new_c, runtime) = lstm_with_backend_selection(
normal_lstm_kwargs, cudnn_lstm_kwargs)
else:
# Only when seq_right_padded=True, CuDNN kernel can support that
# properly.
control_flow_ops.cond(is_sequence_right_padded(mask, self.time_major),
true_fn=register_cudnn_defun,
false_fn=lambda: 0)
def with_mask_support():
# TODO(b/134702514): Change to use backend selection.
# return lstm_with_backend_selection(normal_lstm_kwargs,
# cudnn_lstm_kwargs)
return standard_lstm(**normal_lstm_kwargs)
def without_mask_support():
return standard_lstm(**normal_lstm_kwargs)
(last_output, outputs,
new_h, new_c, runtime) = control_flow_ops.cond(
is_sequence_right_padded(mask, self.time_major),
true_fn=with_mask_support,
false_fn=without_mask_support)
states = [new_h, new_c]
if self.stateful:
@ -1076,25 +1099,31 @@ def cudnn_lstm(inputs, init_h, init_c, kernel, recurrent_kernel, bias, mask,
# so that mathematically it is same as the canonical LSTM implementation.
full_bias = array_ops.concat((array_ops.zeros_like(bias), bias), 0)
if mask is not None:
sequence_length = calculate_sequence_by_mask(mask, time_major)
else:
# Fill the array with shape [batch] with value of max timesteps.
sequence_length = array_ops.fill([array_ops.shape(inputs)[1]],
array_ops.shape(inputs)[0])
if go_backwards:
inputs = array_ops.reverse_sequence_v2(inputs, sequence_length, seq_axis=0,
batch_axis=1)
params = _canonical_to_params(
weights=weights,
biases=array_ops.split(full_bias, 8),
shape=constant_op.constant([-1]),
transpose_weights=True)
outputs, h, c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv3(
inputs, input_h=init_h, input_c=init_c, params=params, is_training=True,
rnn_mode='lstm', sequence_lengths=sequence_length)
if mask is not None:
sequence_length = calculate_sequence_by_mask(mask, time_major)
if go_backwards:
inputs = array_ops.reverse_sequence_v2(inputs, sequence_length,
seq_axis=0, batch_axis=1)
outputs, h, c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv3(
inputs, input_h=init_h, input_c=init_c, params=params, is_training=True,
rnn_mode='lstm', sequence_lengths=sequence_length)
else:
# # Fill the array with shape [batch] with value of max timesteps.
# sequence_length = array_ops.fill([array_ops.shape(inputs)[1]],
# array_ops.shape(inputs)[0])
if go_backwards:
# Reverse axis 0 since the input is already convert to time major.
inputs = array_ops.reverse(inputs, axis=[0])
outputs, h, c, _ = gen_cudnn_rnn_ops.cudnn_rnn(
inputs, input_h=init_h, input_c=init_c, params=params, is_training=True,
rnn_mode='lstm')
last_output = outputs[-1]
if not time_major:
outputs = array_ops.transpose(outputs, perm=[1, 0, 2])
@ -1112,6 +1141,44 @@ def cudnn_lstm(inputs, init_h, init_c, kernel, recurrent_kernel, bias, mask,
return last_output, outputs, h, c, _runtime(_RUNTIME_GPU)
def lstm_with_backend_selection(normal_lstm_params, cudnn_lstm_params):
"""Call the LSTM with optimized backend kernel selection.
Under the hood, this function will create two TF function, one with the most
generic kernel and can run on all device condition, and the second one with
CuDNN specific kernel, which can only run on GPU.
The first function will be called with normal_lstm_params, while the second
function is not called, but only registered in the graph. The Grappler will
do the proper graph rewrite and swap the optimized TF function based on the
device placement.
Args:
normal_lstm_params: Dict, parameters for the generic TF function.
cudnn_lstm_params: Dict, parameters for the CuDNN specific TF function.
Returns:
List of output tensors, same as standard_lstm.
"""
# Each time a `tf.function` is called, we will give it a unique
# identifiable API name, so that Grappler won't get confused when it
# sees multiple LSTM layers added into same graph, and it will be able
# to pair up the different implementations across them.
api_name = 'lstm_' + str(uuid.uuid4())
defun_standard_lstm = _generate_defun_backend(
api_name, _CPU_DEVICE_NAME, standard_lstm)
defun_cudnn_lstm = _generate_defun_backend(
api_name, _GPU_DEVICE_NAME, cudnn_lstm)
# Call the normal LSTM impl and register the CuDNN impl function. The
# grappler will kick in during session execution to optimize the graph.
last_output, outputs, new_h, new_c, runtime = defun_standard_lstm(
**normal_lstm_params)
function.register(defun_cudnn_lstm, **cudnn_lstm_params)
return last_output, outputs, new_h, new_c, runtime
def is_sequence_right_padded(mask, time_major):
"""Check the mask tensor and see if it right padded.
@ -1186,6 +1253,9 @@ def _generate_defun_backend(unique_api_name, preferred_device, func):
function_attributes = {
_DEFUN_API_NAME_ATTRIBUTE: unique_api_name,
_DEFUN_DEVICE_ATTRIBUTE: preferred_device,
# TODO(b/133178886): The function is auto inlined in eager context, which
# make grappler fail to do the optimization. Force it to not inline here.
'_noinline': True,
}
return function.defun_with_attributes(func=func,
attributes=function_attributes)

View File

@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.client import session
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
@ -27,6 +29,7 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
@ -241,6 +244,7 @@ class RaggedTensor(composite_tensor.CompositeTensor):
"of the factory methods instead (e.g., "
"RaggedTensor.from_row_lengths())")
# TODO(b/133606651) Remove is_tensor_spec code paths -- replaced by TypeSpec
is_tensor_spec = isinstance(row_splits, tensor_spec.TensorSpec)
if is_tensor_spec:
if not (isinstance(values, tensor_spec.TensorSpec) or
@ -1916,6 +1920,161 @@ def match_row_splits_dtypes(*tensors, **kwargs):
return tensors
#===============================================================================
# RaggedTensorSpec
#===============================================================================
# TODO(b/133606651) Export this as tf.RaggedTensorSpec.
class RaggedTensorSpec(type_spec.BatchableTypeSpec):
"""Type specification for a `tf.RaggedTensor`."""
__slots__ = ["_shape", "_dtype", "_ragged_rank", "_row_splits_dtype"]
value_type = property(lambda self: RaggedTensor)
def __init__(self, shape=None, dtype=dtypes.float32, ragged_rank=None,
row_splits_dtype=dtypes.int64):
"""Constructs a type specification for a `tf.RaggedTensor`.
Args:
shape: The shape of the RaggedTensor, or `None` to allow any shape. If
a shape is specified, then all ragged dimensions must have size `None`.
dtype: `tf.DType` of values in the RaggedTensor.
ragged_rank: Python integer, the ragged rank of the RaggedTensor
to be described. Defaults to `shape.ndims - 1`.
row_splits_dtype: `dtype` for the RaggedTensor's `row_splits` tensor.
One of `tf.int32` or `tf.int64`.
"""
self._shape = tensor_shape.as_shape(shape)
self._dtype = dtypes.as_dtype(dtype)
self._row_splits_dtype = dtypes.as_dtype(row_splits_dtype)
rank = self._shape.ndims
if ragged_rank is None:
if rank is None:
raise ValueError("Must specify ragged_rank or "
"a shape with a known rank.")
ragged_rank = rank - 1
self._ragged_rank = ragged_rank
if not isinstance(self._ragged_rank, int):
raise TypeError("ragged_rank must be an int")
if rank is not None:
if ragged_rank >= rank:
raise ValueError("ragged_rank must be less than rank.")
def _serialize(self):
return (self._shape, self._dtype, self._ragged_rank, self._row_splits_dtype)
@property
def _component_specs(self):
if self._ragged_rank == 0:
return [tensor_spec.TensorSpec(self._shape, self._dtype)]
flat_values_shape = tensor_shape.TensorShape([None]).concatenate(
self._shape[self._ragged_rank + 1:])
outer_dim = tensor_shape.dimension_at_index(self._shape, 0)
outer_splits_shape = [None if outer_dim is None else outer_dim + 1]
inner_splits_spec = tensor_spec.TensorSpec([None], self._row_splits_dtype)
specs = (
[tensor_spec.TensorSpec(flat_values_shape, self._dtype),
tensor_spec.TensorSpec(outer_splits_shape, self._row_splits_dtype)] +
[inner_splits_spec for _ in range(self._ragged_rank - 1)])
return specs
def _to_components(self, value):
if is_ragged(value):
return [value.flat_values] + list(value.nested_row_splits)
else:
return [value]
def _from_components(self, tensor_list):
# Currently, Keras converts tensors to numpy and then calls from_components
# with those np.arrays. So if we see np.ndarrays, convert them to tensors.
# TODO(b/133606651) Update Keras to do something different here. Consider
# adding something like TypeSpec.from_numpy_components?
if isinstance(tensor_list[0], np.ndarray):
tensor_list = [ops.convert_to_tensor(t) for t in tensor_list]
result = tensor_list[0]
for row_splits in reversed(tensor_list[1:]):
result = RaggedTensor(result, row_splits, internal=True)
return result
# The RaggedTensorSpec tensor_list encoding uses to/from_variant ops
# to (un)box the component tensors in a way that allows for batching &
# unbatching.
@property
def _flat_tensor_specs(self):
# NOTE(mishragaurav): The default flat shape of a boxed `RaggedTensor` is
# `[]` (scalar), but a `RaggedTensorSpec` can also represent a batch of
# boxed `RaggedTensor` objects with shape `(...)` (and batches of batches,
# etc.), so the flat shape must be unknown.
return [tensor_spec.TensorSpec(None, dtypes.variant)]
def _to_tensor_list(self, value):
# pylint: disable=protected-access
return [value._to_variant(batched_input=False)]
def _to_batched_tensor_list(self, value):
# pylint: disable=protected-access
return [value._to_variant(batched_input=True)]
def _from_compatible_tensor_list(self, tensor_list):
if self._ragged_rank <= 0:
raise ValueError(
"ragged_rank must be non-negative; got %s." % self._ragged_rank)
result = RaggedTensor._from_variant( # pylint: disable=protected-access
tensor_list[0], dtype=self._dtype,
output_ragged_rank=self._ragged_rank)
if self._shape.ndims is not None:
outer_dim = tensor_shape.dimension_value(self._shape[0])
if outer_dim is not None:
result.row_splits.set_shape([outer_dim + 1])
result.flat_values.set_shape(
tensor_shape.TensorShape([None]).concatenate(
self._shape[1 + self._ragged_rank:]))
return result
def _batch(self, batch_size):
return RaggedTensorSpec(
tensor_shape.TensorShape([batch_size]).concatenate(self._shape),
self._dtype,
self._ragged_rank + 1)
def _unbatch(self):
# Note: Negative ragged_rank is allowed here because the dataset could
# be subsequently batched again. Errors are handled in
# RaggedTensorSpec._from_compatible_tensor_list()
return RaggedTensorSpec(self._shape[1:], self._dtype,
self._ragged_rank - 1)
def _to_legacy_output_types(self):
return self._dtype
def _to_legacy_output_shapes(self):
return self._shape
def _to_legacy_output_classes(self):
return self
@classmethod
def from_value(cls, value):
return cls(shape=value.shape,
dtype=value.values.dtype,
ragged_rank=value.ragged_rank,
row_splits_dtype=value.row_splits.dtype)
# TODO(b/133606651) Delete the RaggedTensor registration when CompositeTensor
# is updated to define a _type_spec field (since registration will be
# automatic). Do *not* delete the RaggedTensorValue registration.
type_spec.register_type_spec_from_value_converter(
RaggedTensor, RaggedTensorSpec.from_value)
type_spec.register_type_spec_from_value_converter(
ragged_tensor_value.RaggedTensorValue, RaggedTensorSpec.from_value)
#===============================================================================
# Convert value -> tensor
#===============================================================================

View File

@ -29,7 +29,9 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import gen_control_flow_ops
@ -1342,3 +1344,115 @@ def _check_dtypes(value, dtype):
"in future versions of TensorFlow. Traceback:\n{}".format(
value, str(value.dtype), str(dtype),
"".join(traceback.format_stack())))
# TODO(b/133606651) Export this as tf.TensorArraySpec.
class TensorArraySpec(type_spec.TypeSpec):
"""Type specification for a `tf.TensorArray`."""
__slots__ = ["_element_shape", "_dtype", "_dynamic_size", "_infer_shape"]
value_type = property(lambda self: TensorArray)
def __init__(self, element_shape=None, dtype=dtypes.float32,
dynamic_size=False, infer_shape=True):
"""Constructs a type specification for a `tf.TensorArray`.
Args:
element_shape: The shape of each element in the `TensorArray`.
dtype: Data type of the `TensorArray`.
dynamic_size: Whether the `TensorArray` can grow past its initial size.
infer_shape: Whether shape inference is enabled.
"""
self._element_shape = tensor_shape.as_shape(element_shape)
self._dtype = dtypes.as_dtype(dtype)
self._dynamic_size = dynamic_size
self._infer_shape = infer_shape
def is_compatible_with(self, other):
# We check all fields *except* infer_shape.
# TODO(b/133606651) Verify that this is the correct behavior.
# pylint: disable=protected-access
if not isinstance(other, type_spec.TypeSpec):
other = type_spec.type_spec_from_value(other)
return (isinstance(other, TensorArraySpec) and
self._dtype.is_compatible_with(other._dtype) and
self._element_shape.is_compatible_with(other._element_shape) and
self._dynamic_size == other._dynamic_size)
def most_specific_compatible_type(self, other):
# TODO(b/133606651) Verify that this is the correct behavior for combining
# infer_shape values.
# pylint: disable=protected-access
if not self.is_compatible_with(other):
raise ValueError("Types are not compatible")
infer_shape = self._infer_shape or other._infer_shape
return TensorArraySpec(
self._element_shape.most_specific_compatible_shape(
other._element_shape),
self._dtype, self._dynamic_size, infer_shape)
def _serialize(self):
return (self._element_shape, self._dtype, self._dynamic_size,
self._infer_shape)
@property
def _component_specs(self):
return [tensor_spec.TensorSpec([], dtypes.variant)]
def _to_components(self, value):
if not isinstance(value, TensorArray):
raise TypeError("value must be a TensorArray, but saw: {}"
.format(type(value)))
if value.flow is not None and value.flow.dtype == dtypes.variant:
return [value.flow]
else:
# Convert to a TF2-style TensorArray.
# TODO(ebrevdo): Add an "_as_variant" method to TensorArray class, or
# "implementation / as_variant" arg to TensorArray constructor.
with ops.name_scope("convert_tensor_array"):
flow = list_ops.tensor_list_from_tensor(
tensor=value.stack(), element_shape=value.element_shape)
return [flow]
def _from_components(self, tensor_list):
# This will return a TF2 Graph-style TensorArray because tensor_list[0] is
# a variant object. size == -1 implies unknown size.
ret = TensorArray(
dtype=self._dtype,
flow=tensor_list[0],
dynamic_size=self._dynamic_size,
infer_shape=self._infer_shape)
ret._element_shape = [self._element_shape] # pylint: disable=protected-access
return ret
@staticmethod
def from_value(value):
if not isinstance(value, TensorArray):
raise TypeError("Expected value to be a TensorArray, but saw: {}".
format(type(value)))
return TensorArraySpec(
dtype=value.dtype,
element_shape=value.element_shape,
dynamic_size=value.dynamic_size,
infer_shape=value._infer_shape) # pylint: disable=protected-access
def _to_legacy_output_types(self):
return self._dtype
def _to_legacy_output_shapes(self):
# Sneak the dynamic_size and infer_shape values into the legacy shape.
return (tensor_shape.matrix(self._dynamic_size, self._infer_shape)
.concatenate(self._element_shape))
def _to_legacy_output_classes(self):
return TensorArray
# Register the TypeSpec for TensorArray. If TensorArray is updated to be a
# CompositeTensor, then this registration can be deleted.
type_spec.register_type_spec_from_value_converter(
TensorArray, TensorArraySpec.from_value, allow_subclass=True)

View File

@ -1,6 +1,8 @@
path: "tensorflow.TensorSpec"
tf_class {
is_instance: "<class \'tensorflow.python.framework.tensor_spec.TensorSpec\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.BatchableTypeSpec\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
is_instance: "<type \'object\'>"
member {
name: "dtype"
@ -14,6 +16,10 @@ tf_class {
name: "shape"
mtype: "<type \'property\'>"
}
member {
name: "value_type"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'shape\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
@ -30,4 +36,8 @@ tf_class {
name: "is_compatible_with"
argspec: "args=[\'self\', \'spec_or_tensor\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "most_specific_compatible_type"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -1,11 +1,15 @@
path: "tensorflow.data.experimental.DatasetStructure"
tf_class {
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetStructure\'>"
is_instance: "<class \'tensorflow.python.data.util.structure.Structure\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
is_instance: "<type \'object\'>"
member {
name: "value_type"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'element_structure\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'element_spec\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_value"
@ -13,6 +17,10 @@ tf_class {
}
member_method {
name: "is_compatible_with"
argspec: "args=[\'self\', \'spec_or_value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "most_specific_compatible_type"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -1,8 +1,13 @@
path: "tensorflow.data.experimental.NestedStructure"
tf_class {
is_instance: "<class \'tensorflow.python.data.util.structure.NestedStructure\'>"
is_instance: "<class \'tensorflow.python.data.util.structure.Structure\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.BatchableTypeSpec\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
is_instance: "<type \'object\'>"
member {
name: "value_type"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'nested_structure\'], varargs=None, keywords=None, defaults=None"
@ -15,4 +20,8 @@ tf_class {
name: "is_compatible_with"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "most_specific_compatible_type"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -1,8 +1,12 @@
path: "tensorflow.data.experimental.OptionalStructure"
tf_class {
is_instance: "<class \'tensorflow.python.data.ops.optional_ops.OptionalStructure\'>"
is_instance: "<class \'tensorflow.python.data.util.structure.Structure\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
is_instance: "<type \'object\'>"
member {
name: "value_type"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'value_structure\'], varargs=None, keywords=None, defaults=None"
@ -13,6 +17,10 @@ tf_class {
}
member_method {
name: "is_compatible_with"
argspec: "args=[\'self\', \'spec_or_value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "most_specific_compatible_type"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -1,18 +0,0 @@
path: "tensorflow.data.experimental.RaggedTensorStructure"
tf_class {
is_instance: "<class \'tensorflow.python.data.util.structure.RaggedTensorStructure\'>"
is_instance: "<class \'tensorflow.python.data.util.structure.Structure\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'dtype\', \'shape\', \'ragged_rank\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_value"
argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "is_compatible_with"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -1,18 +0,0 @@
path: "tensorflow.data.experimental.SparseTensorStructure"
tf_class {
is_instance: "<class \'tensorflow.python.data.util.structure.SparseTensorStructure\'>"
is_instance: "<class \'tensorflow.python.data.util.structure.Structure\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'dtype\', \'dense_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_value"
argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "is_compatible_with"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -1,16 +1,20 @@
path: "tensorflow.data.experimental.Structure"
tf_class {
is_instance: "<class \'tensorflow.python.data.util.structure.Structure\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
is_instance: "<type \'object\'>"
member {
name: "value_type"
mtype: "<class \'abc.abstractproperty\'>"
}
member_method {
name: "__init__"
}
member_method {
name: "from_value"
argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
name: "is_compatible_with"
argspec: "args=[\'self\', \'spec_or_value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "is_compatible_with"
name: "most_specific_compatible_type"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -1,18 +0,0 @@
path: "tensorflow.data.experimental.TensorArrayStructure"
tf_class {
is_instance: "<class \'tensorflow.python.data.util.structure.TensorArrayStructure\'>"
is_instance: "<class \'tensorflow.python.data.util.structure.Structure\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'dtype\', \'element_shape\', \'dynamic_size\', \'infer_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_value"
argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "is_compatible_with"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -1,18 +0,0 @@
path: "tensorflow.data.experimental.TensorStructure"
tf_class {
is_instance: "<class \'tensorflow.python.data.util.structure.TensorStructure\'>"
is_instance: "<class \'tensorflow.python.data.util.structure.Structure\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'dtype\', \'shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_value"
argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "is_compatible_with"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -44,10 +44,6 @@ tf_module {
name: "OptionalStructure"
mtype: "<type \'type\'>"
}
member {
name: "RaggedTensorStructure"
mtype: "<type \'type\'>"
}
member {
name: "RandomDataset"
mtype: "<type \'type\'>"
@ -56,10 +52,6 @@ tf_module {
name: "Reducer"
mtype: "<type \'type\'>"
}
member {
name: "SparseTensorStructure"
mtype: "<type \'type\'>"
}
member {
name: "SqlDataset"
mtype: "<type \'type\'>"
@ -80,14 +72,6 @@ tf_module {
name: "TFRecordWriter"
mtype: "<type \'type\'>"
}
member {
name: "TensorArrayStructure"
mtype: "<type \'type\'>"
}
member {
name: "TensorStructure"
mtype: "<type \'type\'>"
}
member {
name: "ThreadingOptions"
mtype: "<type \'type\'>"
@ -100,6 +84,22 @@ tf_module {
name: "Counter"
argspec: "args=[\'start\', \'step\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\"], "
}
member_method {
name: "RaggedTensorStructure"
argspec: "args=[\'dtype\', \'shape\', \'ragged_rank\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "SparseTensorStructure"
argspec: "args=[\'dtype\', \'shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "TensorArrayStructure"
argspec: "args=[\'dtype\', \'element_shape\', \'dynamic_size\', \'infer_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "TensorStructure"
argspec: "args=[\'dtype\', \'shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "bucket_by_sequence_length"
argspec: "args=[\'element_length_func\', \'bucket_boundaries\', \'bucket_batch_sizes\', \'padded_shapes\', \'padding_values\', \'pad_to_bucket_boundary\', \'no_padding\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'False\', \'False\'], "

View File

@ -1,6 +1,8 @@
path: "tensorflow.TensorSpec"
tf_class {
is_instance: "<class \'tensorflow.python.framework.tensor_spec.TensorSpec\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.BatchableTypeSpec\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
is_instance: "<type \'object\'>"
member {
name: "dtype"
@ -14,6 +16,10 @@ tf_class {
name: "shape"
mtype: "<type \'property\'>"
}
member {
name: "value_type"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'shape\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
@ -30,4 +36,8 @@ tf_class {
name: "is_compatible_with"
argspec: "args=[\'self\', \'spec_or_tensor\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "most_specific_compatible_type"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -1,11 +1,15 @@
path: "tensorflow.data.experimental.DatasetStructure"
tf_class {
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetStructure\'>"
is_instance: "<class \'tensorflow.python.data.util.structure.Structure\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
is_instance: "<type \'object\'>"
member {
name: "value_type"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'element_structure\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'element_spec\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_value"
@ -13,6 +17,10 @@ tf_class {
}
member_method {
name: "is_compatible_with"
argspec: "args=[\'self\', \'spec_or_value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "most_specific_compatible_type"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -1,8 +1,13 @@
path: "tensorflow.data.experimental.NestedStructure"
tf_class {
is_instance: "<class \'tensorflow.python.data.util.structure.NestedStructure\'>"
is_instance: "<class \'tensorflow.python.data.util.structure.Structure\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.BatchableTypeSpec\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
is_instance: "<type \'object\'>"
member {
name: "value_type"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'nested_structure\'], varargs=None, keywords=None, defaults=None"
@ -15,4 +20,8 @@ tf_class {
name: "is_compatible_with"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "most_specific_compatible_type"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -1,8 +1,12 @@
path: "tensorflow.data.experimental.OptionalStructure"
tf_class {
is_instance: "<class \'tensorflow.python.data.ops.optional_ops.OptionalStructure\'>"
is_instance: "<class \'tensorflow.python.data.util.structure.Structure\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
is_instance: "<type \'object\'>"
member {
name: "value_type"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'value_structure\'], varargs=None, keywords=None, defaults=None"
@ -13,6 +17,10 @@ tf_class {
}
member_method {
name: "is_compatible_with"
argspec: "args=[\'self\', \'spec_or_value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "most_specific_compatible_type"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -1,18 +0,0 @@
path: "tensorflow.data.experimental.RaggedTensorStructure"
tf_class {
is_instance: "<class \'tensorflow.python.data.util.structure.RaggedTensorStructure\'>"
is_instance: "<class \'tensorflow.python.data.util.structure.Structure\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'dtype\', \'shape\', \'ragged_rank\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_value"
argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "is_compatible_with"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -1,18 +0,0 @@
path: "tensorflow.data.experimental.SparseTensorStructure"
tf_class {
is_instance: "<class \'tensorflow.python.data.util.structure.SparseTensorStructure\'>"
is_instance: "<class \'tensorflow.python.data.util.structure.Structure\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'dtype\', \'dense_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_value"
argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "is_compatible_with"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -1,16 +1,20 @@
path: "tensorflow.data.experimental.Structure"
tf_class {
is_instance: "<class \'tensorflow.python.data.util.structure.Structure\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
is_instance: "<type \'object\'>"
member {
name: "value_type"
mtype: "<class \'abc.abstractproperty\'>"
}
member_method {
name: "__init__"
}
member_method {
name: "from_value"
argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
name: "is_compatible_with"
argspec: "args=[\'self\', \'spec_or_value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "is_compatible_with"
name: "most_specific_compatible_type"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -1,18 +0,0 @@
path: "tensorflow.data.experimental.TensorArrayStructure"
tf_class {
is_instance: "<class \'tensorflow.python.data.util.structure.TensorArrayStructure\'>"
is_instance: "<class \'tensorflow.python.data.util.structure.Structure\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'dtype\', \'element_shape\', \'dynamic_size\', \'infer_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_value"
argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "is_compatible_with"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -1,18 +0,0 @@
path: "tensorflow.data.experimental.TensorStructure"
tf_class {
is_instance: "<class \'tensorflow.python.data.util.structure.TensorStructure\'>"
is_instance: "<class \'tensorflow.python.data.util.structure.Structure\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'dtype\', \'shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_value"
argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "is_compatible_with"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -44,10 +44,6 @@ tf_module {
name: "OptionalStructure"
mtype: "<type \'type\'>"
}
member {
name: "RaggedTensorStructure"
mtype: "<type \'type\'>"
}
member {
name: "RandomDataset"
mtype: "<type \'type\'>"
@ -56,10 +52,6 @@ tf_module {
name: "Reducer"
mtype: "<type \'type\'>"
}
member {
name: "SparseTensorStructure"
mtype: "<type \'type\'>"
}
member {
name: "SqlDataset"
mtype: "<type \'type\'>"
@ -80,14 +72,6 @@ tf_module {
name: "TFRecordWriter"
mtype: "<type \'type\'>"
}
member {
name: "TensorArrayStructure"
mtype: "<type \'type\'>"
}
member {
name: "TensorStructure"
mtype: "<type \'type\'>"
}
member {
name: "ThreadingOptions"
mtype: "<type \'type\'>"
@ -100,6 +84,22 @@ tf_module {
name: "Counter"
argspec: "args=[\'start\', \'step\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\"], "
}
member_method {
name: "RaggedTensorStructure"
argspec: "args=[\'dtype\', \'shape\', \'ragged_rank\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "SparseTensorStructure"
argspec: "args=[\'dtype\', \'shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "TensorArrayStructure"
argspec: "args=[\'dtype\', \'element_shape\', \'dynamic_size\', \'infer_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "TensorStructure"
argspec: "args=[\'dtype\', \'shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "bucket_by_sequence_length"
argspec: "args=[\'element_length_func\', \'bucket_boundaries\', \'bucket_batch_sizes\', \'padded_shapes\', \'padding_values\', \'pad_to_bucket_boundary\', \'no_padding\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'False\', \'False\'], "