Merge pull request #29670 from goldiegadde/ggadde-cp5
Cherrypick important fixes to r2.0 branch.
commit d08e899087
@@ -140,6 +140,8 @@ Status ImplementationSelector::MaybeOptimizeFunctionCall(

for (const auto& attr_name : function_attribute_names) {
string function_name = node_def->attr().at(attr_name).func().name();
// Skip the function if its already optimized by function optimizer.
if (::absl::StrContains(function_name, "_specialized_for_")) continue;
std::vector<string> equiv_func_names;
TF_RETURN_IF_ERROR(lib_info_->GetEquivalentImplementations(
function_name, &equiv_func_names));
@@ -153,7 +155,8 @@ Status ImplementationSelector::MaybeOptimizeFunctionCall(
}
}

if (lib_info_->GetApiInfo(node_def->op()) != nullptr) {
if (lib_info_->GetApiInfo(node_def->op()) != nullptr &&
!::absl::StrContains(node_def->op(), "_specialized_for_")) {
std::vector<string> equiv_func_names;
TF_RETURN_IF_ERROR(lib_info_->GetEquivalentImplementations(
node_def->op(), &equiv_func_names));

@@ -328,6 +328,33 @@ class FromSavedModelTest(TestModels):
self.assertIn('This converter can only convert a single ConcreteFunction',
str(error.exception))

def testKerasSequentialModel(self):
"""Test a simple sequential tf.Keras model."""
self.skipTest('b/134660903')
input_data = constant_op.constant(1., shape=[1, 1])

x = np.array([[1.], [2.]])
y = np.array([[2.], [4.]])

model = keras.models.Sequential([
keras.layers.Dropout(0.2),
keras.layers.Dense(1),
])
model.compile(optimizer='sgd', loss='mean_squared_error')
model.fit(x, y, epochs=1)

save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
save(model, save_dir)

# Convert model and ensure model is not None.
converter = lite.TFLiteConverterV2.from_saved_model(save_dir)
tflite_model = converter.convert()

# Check values from converted model.
expected_value = model.predict(input_data)
actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
self.assertEqual(expected_value, actual_value)


class FromKerasModelTest(TestModels):

@@ -337,11 +364,13 @@ class FromKerasModelTest(TestModels):
input_data = constant_op.constant(1., shape=[1, 1])

# Create a simple Keras model.
x = [-1, 0, 1, 2, 3, 4]
y = [-3, -1, 1, 3, 5, 7]
x = np.array([[1.], [2.]])
y = np.array([[2.], [4.]])

model = keras.models.Sequential(
[keras.layers.Dense(units=1, input_shape=[1])])
model = keras.models.Sequential([
keras.layers.Dropout(0.2),
keras.layers.Dense(units=1, input_shape=[1])
])
model.compile(optimizer='sgd', loss='mean_squared_error')
model.fit(x, y, epochs=1)

@@ -958,6 +958,7 @@ py_library(
":tensor_shape",
":tf2",
":traceable_stack",
":type_spec",
":util",
":versions",
"//tensorflow/core:protos_all_py",
@@ -1138,6 +1139,7 @@ py_library(
":framework_ops",
":tensor_like",
":tensor_util",
":type_spec",
],
)

@@ -1244,6 +1246,7 @@ py_library(
":common_shapes",
":dtypes",
":tensor_shape",
":type_spec",
":util",
"//third_party/py/numpy",
],

@@ -23,6 +23,7 @@ from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import structure
from tensorflow.python.framework import ops
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import gen_experimental_dataset_ops
from tensorflow.python.util.tf_export import tf_export

@@ -40,7 +41,7 @@ class _ScanDataset(dataset_ops.UnaryDataset):
# Compute initial values for the state classes, shapes and types based on
# the initial state. The shapes may be refined by running `tf_scan_func` one
# or more times below.
self._state_structure = structure.Structure.from_value(self._initial_state)
self._state_structure = type_spec.type_spec_from_value(self._initial_state)

# Iteratively rerun the scan function until reaching a fixed point on
# `self._state_shapes`.

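For context, a small sketch of the call this hunk switches to. It uses the internal `type_spec` module imported above; the state value below is illustrative and not taken from the TensorFlow sources:

    from tensorflow.python.framework import type_spec
    import tensorflow as tf

    # A nested initial state: the derived spec mirrors the nesting, and each
    # leaf records a dtype and the best-known shape.
    initial_state = {"count": tf.constant(0), "total": tf.constant([0.0, 0.0])}
    state_structure = type_spec.type_spec_from_value(initial_state)
    # The loop mentioned in the comment above then reruns the scan function
    # and widens these shapes until they stop changing (the fixed point).
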
@@ -38,6 +38,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
@@ -290,7 +291,7 @@ class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
)
def testDatasetStructure(self, tf_value_fn, expected_element_structure):
dataset = dataset_ops.Dataset.from_tensors(0).map(lambda _: tf_value_fn())
dataset_structure = structure.Structure.from_value(dataset)
dataset_structure = type_spec.type_spec_from_value(dataset)
self.assertIsInstance(dataset_structure, dataset_ops.DatasetStructure)

# TODO(b/110122868): Add a public API to `tf.data.Dataset` for accessing

@@ -34,6 +34,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
@@ -279,7 +280,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertTrue(
opt.value_structure.is_compatible_with(expected_value_structure))

opt_structure = structure.Structure.from_value(opt)
opt_structure = type_spec.type_spec_from_value(opt)
self.assertIsInstance(opt_structure, optional_ops.OptionalStructure)
self.assertTrue(opt_structure.is_compatible_with(opt_structure))
self.assertTrue(opt_structure._value_structure.is_compatible_with(
@@ -289,7 +290,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):

# All OptionalStructure objects are not compatible with a non-optional
# value.
non_optional_structure = structure.Structure.from_value(
non_optional_structure = type_spec.type_spec_from_value(
constant_op.constant(42.0))
self.assertFalse(opt_structure.is_compatible_with(non_optional_structure))

@@ -340,7 +341,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertIsInstance(next_elem, optional_ops.Optional)
self.assertTrue(
next_elem.value_structure.is_compatible_with(
structure.Structure.from_value(tf_value_fn())))
type_spec.type_spec_from_value(tf_value_fn())))
self.assertTrue(next_elem.has_value())
self._assertElementValueEqual(np_value, next_elem.get_value())
# After exhausting the iterator, `next_elem.has_value()` will evaluate to
@@ -356,7 +357,7 @@ class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
self.assertIsInstance(next_elem, optional_ops.Optional)
self.assertTrue(
next_elem.value_structure.is_compatible_with(
structure.Structure.from_value(tf_value_fn())))
type_spec.type_spec_from_value(tf_value_fn())))
# Before initializing the iterator, evaluating the optional fails with
# a FailedPreconditionError. This is only relevant in graph mode.
elem_has_value_t = next_elem.has_value()

@@ -54,6 +54,7 @@ from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_dataset_ops
@@ -1391,7 +1392,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):

with ops.name_scope("initial_state"):
initial_state = structure_lib.normalize_tensors(initial_state)
state_structure = structure_lib.Structure.from_value(initial_state)
state_structure = type_spec.type_spec_from_value(initial_state)

# Iteratively rerun the reduce function until reaching a fixed point on
# `state_structure`.
@@ -2264,7 +2265,7 @@ class TensorDataset(DatasetSource):
def __init__(self, tensors):
"""See `Dataset.from_tensors()` for details."""
tensors = structure_lib.normalize_tensors(tensors)
self._structure = structure_lib.Structure.from_value(tensors)
self._structure = type_spec.type_spec_from_value(tensors)
self._tensors = self._structure._to_tensor_list(tensors) # pylint: disable=protected-access

variant_tensor = gen_dataset_ops.tensor_dataset(
@@ -2284,7 +2285,7 @@ class TensorSliceDataset(DatasetSource):
with ops.name_scope("tensors"):
tensors = structure_lib.normalize_tensors(tensors)

batched_structure = structure_lib.Structure.from_value(tensors)
batched_structure = type_spec.type_spec_from_value(tensors)
# pylint: disable=protected-access
self._tensors = batched_structure._to_batched_tensor_list(tensors)
self._structure = batched_structure._unbatch()
@@ -2377,51 +2378,33 @@ def to_variant(dataset):
return dataset._variant_tensor # pylint: disable=protected-access


# TODO(b/133606651) Rename this class to DatasetSpec
@tf_export("data.experimental.DatasetStructure")
class DatasetStructure(structure_lib.Structure):
"""Represents a `Dataset` of structured values."""
class DatasetStructure(type_spec.TypeSpec):
"""Type specification for `tf.data.Dataset`."""

def __init__(self, element_structure):
self._element_structure = element_structure
__slots__ = ["_element_structure"]

def __eq__(self, other):
# pylint: disable=protected-access
return (isinstance(other, DatasetStructure) and
self._element_structure == other._element_structure)

def __hash__(self):
return hash(self._element_structure)
def __init__(self, element_spec):
self._element_structure = element_spec

@property
def _flat_shapes(self):
return [tensor_shape.scalar()]
def value_type(self):
return _VariantDataset

def _serialize(self):
return (self._element_structure,)

@property
def _flat_types(self):
return [dtypes.variant]
def _component_specs(self):
return tensor_spec.TensorSpec([], dtypes.variant)

def is_compatible_with(self, other):
def _to_components(self, value):
return value._variant_tensor # pylint: disable=protected-access

def _from_components(self, components):
# pylint: disable=protected-access
return (isinstance(other, DatasetStructure) and
self._element_structure.is_compatible_with(
other._element_structure))

def _to_tensor_list(self, value):
return [value._variant_tensor] # pylint: disable=protected-access

def _to_batched_tensor_list(self, value):
raise NotImplementedError("Unbatching for `tf.data.Dataset` objects.")

def _from_tensor_list(self, flat_value):
if (len(flat_value) != 1 or flat_value[0].dtype != dtypes.variant or
not flat_value[0].shape.is_compatible_with(tensor_shape.scalar())):
raise ValueError(
"DatasetStructure corresponds to a single tf.variant scalar.")
return self._from_compatible_tensor_list(flat_value)

def _from_compatible_tensor_list(self, flat_value):
# pylint: disable=protected-access
return _VariantDataset(flat_value[0], self._element_structure)
return _VariantDataset(components, self._element_structure)

@staticmethod
def from_value(value):
def _to_legacy_output_classes(self):
|
||||
return self
|
||||
|
||||
def _batch(self, batch_size):
|
||||
raise NotImplementedError("Batching for `tf.data.Dataset` objects.")
|
||||
|
||||
def _unbatch(self):
|
||||
raise NotImplementedError("Unbatching for `tf.data.Dataset` objects.")
|
||||
|
||||
|
||||
# pylint: disable=protected-access
|
||||
structure_lib.Structure._register_custom_converter(DatasetV2,
|
||||
DatasetStructure.from_value)
|
||||
# pylint: enable=protected-access
|
||||
# TODO(b/133606651) Delete this registration when CompositeTensor is updated
|
||||
# to define a _type_spec field (since registration will be automatic).
|
||||
type_spec.register_type_spec_from_value_converter(
|
||||
DatasetV2, DatasetStructure.from_value, allow_subclass=True)
|
||||
|
||||
|
||||
class StructuredFunctionWrapper(object):
|
||||
@@ -2565,7 +2542,7 @@ class StructuredFunctionWrapper(object):
ret = tuple(ret)

try:
self._output_structure = structure_lib.Structure.from_value(ret)
self._output_structure = type_spec.type_spec_from_value(ret)
except (ValueError, TypeError):
raise TypeError("Unsupported return value from function passed to "
"%s: %s." % (transformation_name, ret))
@@ -2601,12 +2578,7 @@ class StructuredFunctionWrapper(object):
# TODO(b/124254153): Enable autograph once the overhead is low enough.
# TODO(mdan): Make sure autograph recurses into _wrapper_helper when on.
@eager_function.defun_with_attributes(
input_signature=[
tensor_spec.TensorSpec(input_shape, input_type) # pylint: disable=g-complex-comprehension
for input_shape, input_type in zip(
self._input_structure._flat_shapes,
self._input_structure._flat_types)
],
input_signature=self._input_structure._flat_tensor_specs,
autograph=False,
attributes=defun_kwargs)
def wrapper_fn(*args): # pylint: disable=missing-docstring
@@ -2697,7 +2669,7 @@ class _GeneratorDataset(DatasetSource):
"""
self._init_args = init_args

self._init_structure = structure_lib.Structure.from_value(init_args)
self._init_structure = type_spec.type_spec_from_value(init_args)

self._init_func = StructuredFunctionWrapper(
init_func,

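The DatasetStructure rewrite above follows the standard TypeSpec recipe: instead of hand-written `__eq__`, `__hash__`, `_flat_shapes` and `_flat_types`, a TypeSpec subclass supplies `value_type`, `_serialize`, `_component_specs`, `_to_components` and `_from_components`. A self-contained toy sketch of that recipe; every name below other than the imported modules is illustrative and not part of TensorFlow:

    from tensorflow.python.framework import dtypes, tensor_spec, type_spec

    class VariantBox(object):
      """Toy value class: carries its payload as a single scalar variant tensor."""
      def __init__(self, variant_tensor, element_spec):
        self.variant_tensor = variant_tensor
        self.element_spec = element_spec

    class VariantBoxSpec(type_spec.TypeSpec):
      """Toy TypeSpec for VariantBox, mirroring the new DatasetStructure layout."""
      __slots__ = ["_element_spec"]

      def __init__(self, element_spec):
        self._element_spec = element_spec

      @property
      def value_type(self):
        return VariantBox

      def _serialize(self):
        # Equality, hashing and repr are derived from this tuple by TypeSpec.
        return (self._element_spec,)

      @property
      def _component_specs(self):
        return tensor_spec.TensorSpec([], dtypes.variant)

      def _to_components(self, value):
        return value.variant_tensor

      def _from_components(self, components):
        return VariantBox(components, self._element_spec)
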
@@ -21,11 +21,12 @@ import abc

import six

from tensorflow.python.data.util import structure
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.util.tf_export import tf_export

@@ -93,7 +94,7 @@ class Optional(composite_tensor.CompositeTensor):
"""
with ops.name_scope("optional") as scope:
with ops.name_scope("value"):
value_structure = structure.Structure.from_value(value)
value_structure = type_spec.type_spec_from_value(value)
encoded_value = value_structure._to_tensor_list(value) # pylint: disable=protected-access

return _OptionalImpl(
@@ -167,49 +168,31 @@ class _OptionalImpl(Optional):
return hasattr(self._variant_tensor, "graph")


# TODO(b/133606651) Rename this class to OptionalSpec
@tf_export("data.experimental.OptionalStructure")
class OptionalStructure(structure.Structure):
class OptionalStructure(type_spec.TypeSpec):
"""Represents an optional potentially containing a structured value."""

__slots__ = ["_value_structure"]

def __init__(self, value_structure):
self._value_structure = value_structure

def __eq__(self, other):
# pylint: disable=protected-access
return (isinstance(other, OptionalStructure) and
self._value_structure == other._value_structure)
@property
def value_type(self):
return _OptionalImpl

def __hash__(self):
return hash(self._value_structure)
def _serialize(self):
return (self._value_structure,)

@property
def _flat_shapes(self):
return [tensor_shape.scalar()]
def _component_specs(self):
return [tensor_spec.TensorSpec((), dtypes.variant)]

@property
def _flat_types(self):
return [dtypes.variant]

def is_compatible_with(self, other):
# pylint: disable=protected-access
return (isinstance(other, OptionalStructure) and
self._value_structure.is_compatible_with(other._value_structure))

def _to_tensor_list(self, value):
def _to_components(self, value):
return [value._variant_tensor] # pylint: disable=protected-access

def _to_batched_tensor_list(self, value):
raise NotImplementedError(
"Unbatching for `tf.data.experimental.Optional` objects.")

def _from_tensor_list(self, flat_value):
if (len(flat_value) != 1 or flat_value[0].dtype != dtypes.variant or
not flat_value[0].shape.is_compatible_with(tensor_shape.scalar())):
raise ValueError(
"OptionalStructure corresponds to a single tf.variant scalar.")
return self._from_compatible_tensor_list(flat_value)

def _from_compatible_tensor_list(self, flat_value):
def _from_components(self, flat_value):
# pylint: disable=protected-access
return _OptionalImpl(flat_value[0], self._value_structure)

@@ -226,16 +209,8 @@ class OptionalStructure(structure.Structure):
def _to_legacy_output_classes(self):
return self

def _batch(self, batch_size):
raise NotImplementedError(
"Batching for `tf.data.experimental.Optional` objects.")

def _unbatch(self):
raise NotImplementedError(
"Unbatching for `tf.data.experimental.Optional` objects.")


# pylint: disable=protected-access
structure.Structure._register_custom_converter(Optional,
OptionalStructure.from_value)
# pylint: enable=protected-access
type_spec.register_type_spec_from_value_converter(Optional,
OptionalStructure.from_value)
type_spec.register_type_spec_from_value_converter(_OptionalImpl,
OptionalStructure.from_value)

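With the registrations above in place, deriving a spec from an `Optional` goes through `OptionalStructure.from_value` rather than the removed `Structure` registry. A short illustrative use, mirroring the optional_test.py changes earlier in this diff (the constant is arbitrary):

    from tensorflow.python.data.ops import optional_ops
    from tensorflow.python.framework import type_spec
    import tensorflow as tf

    opt = optional_ops.Optional.from_value(tf.constant(42.0))
    spec = type_spec.type_spec_from_value(opt)  # resolves via the converter registered above
    assert spec.is_compatible_with(spec)
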
@@ -17,263 +17,55 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc

import six

from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_tensor_value
from tensorflow.python.util.tf_export import tf_export


_STRUCTURE_CONVERSION_FUNCTION_REGISTRY = {}


@tf_export("data.experimental.Structure")
@six.add_metaclass(abc.ABCMeta)
class Structure(object):
"""Represents structural information, such as type and shape, about a value.
# Define backwards-compatiblity wrappers for using TypeSpec and its subclasses
# to replace Structure and its subclasses. Note that the constructor argument
# order is different in many cases -- in particular, TypeSpec follows TensorSpec
# and uses the order (shape, dtype); but most Structure subclasses use the
# order (dtype, shape).
#
# TODO(b/133606651) Update tf.data to use TypeSpec directly, and then remove
# these compatibility wrappers.

A `Structure` generalizes the `tf.Tensor.dtype` and `tf.Tensor.shape`
properties, so that we can define generic containers of objects including:

* `tf.Tensor`
* `tf.SparseTensor`
* Nested structures of the above.
Structure = type_spec.TypeSpec

TODO(b/110122868): In the future, a single `Structure` will replace the
`tf.data.Dataset.output_types`, `tf.data.Dataset.output_shapes`,
and `tf.data.Dataset.output_classes`, and similar properties and arguments in
the `tf.compat.v1.data.Iterator` and `Optional` classes.
"""

@abc.abstractmethod
def __eq__(self, other):
"""Returns the this structure and the input structure are equal.
# pylint: disable=invalid-name

Args:
other: the structure to use for equality check

Returns:
`True` if this and the input structure are equal and `False` otherwise.
"""
raise NotImplementedError("Structure.__eq__()")
@tf_export("data.experimental.TensorStructure")
def TensorStructure(dtype, shape):
return tensor_spec.TensorSpec(shape, dtype)

def __ne__(self, other):
return not self == other

@abc.abstractmethod
def __hash__(self):
"""Returns the hash of this structure.
@tf_export("data.experimental.SparseTensorStructure")
def SparseTensorStructure(dtype, shape):
return sparse_tensor_lib.SparseTensorSpec(shape, dtype)

Returns:
The hash of this structure.
"""
raise NotImplementedError("Structure.__hash__()")

@abc.abstractproperty
def _flat_shapes(self):
"""A list of shapes matching the shapes of `self._to_tensor_list()`.
@tf_export("data.experimental.TensorArrayStructure")
def TensorArrayStructure(dtype, element_shape, dynamic_size, infer_shape):
return tensor_array_ops.TensorArraySpec(element_shape, dtype,
dynamic_size, infer_shape)

Returns:
A list of `tf.TensorShape` objects.
"""
raise NotImplementedError("Structure._flat_shapes")

@abc.abstractproperty
def _flat_types(self):
"""A list of types matching the types of `self._to_tensor_list()`.

Returns:
A list of `tf.DType` objects.
"""
raise NotImplementedError("Structure._flat_shapes")

@abc.abstractmethod
def is_compatible_with(self, other):
"""Returns `True` if `other` is compatible with this structure.

A structure `t` is a "subtype" of `s` if:

* `s` and `t` are instances of the same `Structure` subclass.
* The nested structures (if any) of `s` and `t` are the same, according to
`tf.nest.assert_same_structure`, and each nested
structure of `t` is a "subtype" of the corresponding nested structure of
`s`.
* Any `tf.DType` components of `t` are the same as the corresponding
components in `s`.
* Any `tf.TensorShape` components of `t` are compatible with the
corresponding components in `s`, according to
`tf.TensorShape.is_compatible_with`.

Args:
other: A `Structure`.

Returns:
`True` if `other` is a subtype of this structure, otherwise `False`.
"""
raise NotImplementedError("Structure.is_compatible_with()")

@abc.abstractmethod
def _to_tensor_list(self, value):
"""Returns a flat list of `tf.Tensor` representing `value`.

This method can be used, along with `self._flat_shapes` and
`self._flat_types` to represent structured values in lower level APIs
(such as plain TensorFlow operations) that do not understand structure.

Requires: `self.is_compatible_with(Structure.from_value(value))`.

Args:
value: A value with compatible structure.

Returns:
A flat list of `tf.Tensor` representing `value`.
"""
raise NotImplementedError("Structure._to_tensor_list()")

@abc.abstractmethod
def _to_batched_tensor_list(self, value):
"""Returns a flat list of rank >= 1 `tf.Tensor` representing `value`.

This method can be used, along with `self._flat_shapes` and
`self._flat_types` to represent structured values in lower level APIs
(such as plain TensorFlow operations) that do not understand structure,
*and* that require that the plain tensors have a rank of at least one
(e.g. for the purpose of slicing the tensors).

Requires: `self.is_compatible_with(Structure.from_value(value))`.

Args:
value: A value with compatible structure.

Returns:
A flat list of `tf.Tensor` representing `value`.
"""
raise NotImplementedError("Structure._to_batched_tensor_list()")

@abc.abstractmethod
def _from_tensor_list(self, flat_value):
"""Builds a flat list of `tf.Tensor` into a value matching this structure.

Args:
flat_value: A list of `tf.Tensor` with compatible flat structure.

Returns:
A structured object matching this structure.

Raises:
ValueError: If the shapes and types of the tensors in `flat_value` are not
compatible with `self._flat_shapes` and `self._flat_types` respectively.
"""
raise NotImplementedError("Structure._from_tensor_list()")

def _from_compatible_tensor_list(self, flat_value):
"""A version of `_from_tensor_list()` that may avoid performing checks.

NOTE: This method should be used to avoid checks for performance reasons,
when the validity of `flat_value` has been validated by other means.
The shapes and types of the tensors in `flat_value` must be compatible with
`self._flat_shapes` and `self._flat_types` respectively. The behavior is
undefined if this requirement is not met.

Args:
flat_value: A list of `tf.Tensor` with compatible flat structure.

Returns:
A structured object matching this structure.
"""
return self._from_tensor_list(flat_value)

@abc.abstractmethod
def _batch(self, batch_size):
"""Returns a structure representing a batch of objects with this structure.

Args:
batch_size: An `int` representing the number of elements in a batch,
or `None` if the batch size may vary.

Returns:
A `Structure` representing a batch of objects with this structure.
"""
raise NotImplementedError("Structure._batch()")

@abc.abstractmethod
def _unbatch(self):
raise NotImplementedError("Structure._unbatch()")

@staticmethod
def from_value(value):
"""Returns a `Structure` that represents the given `value`.

Args:
value: A potentially structured value.

Returns:
A `Structure` that is compatible with `value`.

Raises:
TypeError: If a structure cannot be built for `value`, because its type
or one of its component types is not supported.
"""
# TODO(b/110122868): Add support for custom types and Dataset to this
# method.
if isinstance(
value,
(sparse_tensor_lib.SparseTensor, sparse_tensor_lib.SparseTensorValue)):
return SparseTensorStructure.from_value(value)
elif isinstance(value, tensor_array_ops.TensorArray):
return TensorArrayStructure.from_value(value)
elif isinstance(
value,
(ragged_tensor.RaggedTensor, ragged_tensor_value.RaggedTensorValue)):
return RaggedTensorStructure.from_value(value)
elif isinstance(value, (tuple, dict)):
return NestedStructure.from_value(value)
else:
for converter_type, converter_fn in (
_STRUCTURE_CONVERSION_FUNCTION_REGISTRY.items()):
if isinstance(value, converter_type):
return converter_fn(value)
try:
tensor = ops.convert_to_tensor(value)
except (ValueError, TypeError):
raise TypeError("Could not build a structure for %r" % value)
return TensorStructure.from_value(tensor)

@staticmethod
def _register_custom_converter(type_object, converter_fn):
"""Registers `converter_fn` for converting values of the given type.

Args:
type_object: A Python `type` object representing the type of values
accepted by `converter_fn`.
converter_fn: A function that takes one argument (an instance of the
type represented by `type_object`) and returns a `Structure`.
"""
_STRUCTURE_CONVERSION_FUNCTION_REGISTRY[type_object] = converter_fn

@abc.abstractmethod
def _to_legacy_output_types(self):
raise NotImplementedError("Structure._to_legacy_output_types()")

@abc.abstractmethod
def _to_legacy_output_shapes(self):
raise NotImplementedError("Structure._to_legacy_output_shapes()")

@abc.abstractmethod
def _to_legacy_output_classes(self):
raise NotImplementedError("Structure._to_legacy_output_classes()")
@tf_export("data.experimental.RaggedTensorStructure")
def RaggedTensorStructure(dtype, shape, ragged_rank):
return ragged_tensor.RaggedTensorSpec(shape, dtype, ragged_rank)


def normalize_tensors(tensors):
@@ -340,7 +132,7 @@ def convert_legacy_structure(output_types, output_shapes, output_classes):
flat_ret = []
for flat_type, flat_shape, flat_class in zip(flat_types, flat_shapes,
flat_classes):
if isinstance(flat_class, Structure):
if isinstance(flat_class, type_spec.TypeSpec):
flat_ret.append(flat_class)
elif issubclass(flat_class, sparse_tensor_lib.SparseTensor):
flat_ret.append(SparseTensorStructure(flat_type, flat_shape))
@@ -361,88 +153,102 @@ def convert_legacy_structure(output_types, output_shapes, output_classes):
"Could not build a structure for output class %r" % (flat_class,))

ret = nest.pack_sequence_as(output_classes, flat_ret)
if isinstance(ret, Structure):
if isinstance(ret, type_spec.TypeSpec):
return ret
else:
return NestedStructure(ret)


# NOTE(mrry): The following classes make extensive use of non-public methods of
# their base class, so we disable the protected-access lint warning once here.
# pylint: disable=protected-access
# TODO(b/133606651) Update the tf.data code to use nests of TypeSpec rather
# than NestedStructure; and then delete this class.
@tf_export("data.experimental.NestedStructure")
class NestedStructure(Structure):
"""Represents a nested structure in which each leaf is a `Structure`."""
class NestedStructure(type_spec.BatchableTypeSpec):
"""Represents a nested structure in which each leaf is a `TypeSpec`."""

# NOTE(edloper): This class makes extensive use of non-public TypeSpec
# methods, so we disable the protected-access lint warning once here.
# pylint: disable=protected-access

__slots__ = ["_nested_structure", "_flat_nested_structure",
"__flat_tensor_specs"]

def __init__(self, nested_structure):
self._nested_structure = nested_structure
self._flat_nested_structure = nest.flatten(nested_structure)
self._flat_shapes_list = []
self._flat_types_list = []
for s in nest.flatten(nested_structure):
if not isinstance(s, Structure):
self.__flat_tensor_specs = []
for s in self._flat_nested_structure:
if not isinstance(s, type_spec.TypeSpec):
raise TypeError("nested_structure must be a (potentially nested) tuple "
"or dictionary of Structure objects.")
self._flat_shapes_list.extend(s._flat_shapes)
self._flat_types_list.extend(s._flat_types)
"or dictionary of TypeSpec objects.")
self.__flat_tensor_specs.extend(s._flat_tensor_specs)

value_type = property(lambda self: type(self._nested_structure))

def _serialize(self):
return self._nested_structure

@classmethod
def _deserialize(cls, nested_structure):
return cls(nested_structure)

def most_specific_compatible_type(self, other):
if type(self) is not type(other):
raise ValueError("Incompatible types")
return nest.map_structure(lambda a, b: a.most_specific_compatible_type(b),
self._nested_structure, other._nested_structure)

def __eq__(self, other):
if not isinstance(other, NestedStructure):
return False
try:
# pylint: disable=protected-access
nest.assert_same_structure(self._nested_structure,
other._nested_structure)
except (ValueError, TypeError):
return False

return nest.flatten(self._nested_structure) == nest.flatten(
other._nested_structure)
return (nest.flatten(self._nested_structure) ==
nest.flatten(other._nested_structure))

def __hash__(self):
return hash(tuple(nest.flatten(self._nested_structure)))

@property
def _flat_shapes(self):
return self._flat_shapes_list

@property
def _flat_types(self):
return self._flat_types_list

def is_compatible_with(self, other):
if not isinstance(other, NestedStructure):
return False
try:
# pylint: disable=protected-access
nest.assert_same_structure(self._nested_structure,
other._nested_structure)
except (ValueError, TypeError):
return False

# pylint: disable=g-complex-comprehension
return all(
substructure.is_compatible_with(other_substructure)
for substructure, other_substructure in zip(
nest.flatten(self._nested_structure),
nest.flatten(other._nested_structure)))

_component_specs = property(lambda self: self._nested_structure)
_flat_tensor_specs = property(lambda self: self.__flat_tensor_specs)

def _to_components(self, value):
return nest.map_structure_up_to(
self._nested_structure, lambda t, v: t._to_components(v),
self._nested_structure, value)

def _from_components(self, value):
return nest.map_structure_up_to(
self._nested_structure, lambda t, v: t._from_components(v),
self._nested_structure, value)

def _to_tensor_list(self, value):
ret = []

try:
flat_value = nest.flatten_up_to(self._nested_structure, value)
except (ValueError, TypeError):
raise ValueError("The value %r is not compatible with the nested "
"structure %r." % (value, self._nested_structure))

for sub_value, structure in zip(flat_value, self._flat_nested_structure):
if not structure.is_compatible_with(Structure.from_value(sub_value)):
raise ValueError("Component value %r is not compatible with the nested "
"structure %r." % (sub_value, structure))
ret.extend(structure._to_tensor_list(sub_value))
return ret
return self.__value_to_tensors(
value, lambda struct, val: struct._to_tensor_list(val))

def _to_batched_tensor_list(self, value):
return self.__value_to_tensors(
value, lambda struct, val: struct._to_batched_tensor_list(val))

def __value_to_tensors(self, value, to_tensor_list_fn):
ret = []

try:
@@ -452,34 +258,31 @@ class NestedStructure(Structure):
"structure %r." % (value, self._nested_structure))

for sub_value, structure in zip(flat_value, self._flat_nested_structure):
if not structure.is_compatible_with(Structure.from_value(sub_value)):
if not structure.is_compatible_with(
type_spec.type_spec_from_value(sub_value)):
raise ValueError("Component value %r is not compatible with the nested "
"structure %r." % (sub_value, structure))
ret.extend(structure._to_batched_tensor_list(sub_value))
ret.extend(to_tensor_list_fn(structure, sub_value))
return ret

def _from_tensor_list(self, flat_value):
if len(flat_value) != len(self._flat_types):
def _from_tensor_list(self, value):
return self.__tensors_to_value(
value, lambda struct, val: struct._from_tensor_list(val))

def _from_compatible_tensor_list(self, value):
return self.__tensors_to_value(
value, lambda struct, val: struct._from_compatible_tensor_list(val))

def __tensors_to_value(self, flat_value, from_tensor_list_fn):
if len(flat_value) != len(self._flat_tensor_specs):
raise ValueError("Expected %d flat values in NestedStructure but got %d."
% (len(self._flat_types), len(flat_value)))
% (len(self._flat_tensor_specs), len(flat_value)))
flat_ret = []
i = 0
for structure in self._flat_nested_structure:
num_flat_values = len(structure._flat_types)
num_flat_values = len(structure._flat_tensor_specs)
sub_value = flat_value[i:i + num_flat_values]
flat_ret.append(structure._from_tensor_list(sub_value))
i += num_flat_values

return nest.pack_sequence_as(self._nested_structure, flat_ret)

def _from_compatible_tensor_list(self, flat_value):
flat_ret = []
i = 0
for structure in self._flat_nested_structure:
num_flat_values = len(structure._flat_types)
sub_value = flat_value[i:i + num_flat_values]
flat_ret.append(structure._from_compatible_tensor_list(sub_value))
flat_ret.append(from_tensor_list_fn(structure, sub_value))
i += num_flat_values

return nest.pack_sequence_as(self._nested_structure, flat_ret)
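The `__tensors_to_value` helper introduced here is the inverse of the flattening above: it walks the flat tensor list, hands each leaf spec as many tensors as it declares, and repacks the results. A sketch of the same round trip against the public `tf.nest` API (illustrative, not the code under review):

    import tensorflow as tf

    value = {"label": tf.constant(1),
             "features": (tf.constant([1.0]), tf.constant([2.0]))}
    flat = tf.nest.flatten(value)                    # leaf tensors in a fixed order
    rebuilt = tf.nest.pack_sequence_as(value, flat)  # inverse of the flatten step
    # NestedStructure does the same walk, except each leaf is handled by its own
    # TypeSpec, which may contribute more than one flat tensor (e.g. a boxed
    # SparseTensor), hence the `num_flat_values` bookkeeping above.
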
@@ -487,7 +290,8 @@ class NestedStructure(Structure):
@staticmethod
def from_value(value):
flat_nested_structure = [
Structure.from_value(sub_value) for sub_value in nest.flatten(value)
type_spec.type_spec_from_value(sub_value)
for sub_value in nest.flatten(value)
]
return NestedStructure(nest.pack_sequence_as(value, flat_nested_structure))

@@ -512,362 +316,14 @@ class NestedStructure(Structure):
lambda s: s._unbatch(), self._nested_structure))


@tf_export("data.experimental.TensorStructure")
class TensorStructure(Structure):
"""Represents structural information about a `tf.Tensor`."""
type_spec.register_type_spec_from_value_converter(
tuple, NestedStructure.from_value, allow_subclass=True)
type_spec.register_type_spec_from_value_converter(
dict, NestedStructure.from_value, allow_subclass=True)

def __init__(self, dtype, shape):
self._dtype = dtypes.as_dtype(dtype)
self._shape = tensor_shape.as_shape(shape)

def __eq__(self, other):
return (isinstance(other, TensorStructure) and tensor_spec.TensorSpec(
self._shape, self._dtype) == tensor_spec.TensorSpec(
other._shape, other._dtype))

def __hash__(self):
return hash(tensor_spec.TensorSpec(self._shape, self._dtype))

@property
def _flat_shapes(self):
return [self._shape]

@property
def _flat_types(self):
return [self._dtype]

def is_compatible_with(self, other):
return (isinstance(other, TensorStructure) and
self._dtype.is_compatible_with(other._dtype) and
self._shape.is_compatible_with(other._shape))

def _to_tensor_list(self, value):
if not self.is_compatible_with(Structure.from_value(value)):
raise ValueError("Value %r is not convertible to a tensor with dtype %s "
"and shape %s." % (value, self._dtype, self._shape))
return [value]

def _to_batched_tensor_list(self, value):
if self._shape.merge_with(value.shape).ndims == 0:
raise ValueError("Unbatching a tensor is only supported for rank >= 1")
return [value]

def _from_tensor_list(self, flat_value):
if len(flat_value) != 1:
raise ValueError("TensorStructure corresponds to a single tf.Tensor.")
if not self.is_compatible_with(Structure.from_value(flat_value[0])):
raise ValueError("Cannot convert %r to a tensor with dtype %s and shape "
"%s." % (flat_value[0], self._dtype, self._shape))
return self._from_compatible_tensor_list(flat_value)

def _from_compatible_tensor_list(self, flat_value):
# TODO(b/112266545): It would be cleaner to create a new `ensure_shape()`
# op here and return that, instead of mutating the input's shape using
# `Tensor.set_shape()`. However, that would add extra ops on the arguments
# of each `tf.data` function, which could impact performance. When this
# bug is resolved, we should be able to add the `ensure_shape()` ops and
# optimize them away using contextual shape information.
flat_value[0].set_shape(self._shape)
return flat_value[0]

@staticmethod
def from_value(value):
return TensorStructure(value.dtype, value.shape)

def _to_legacy_output_types(self):
return self._dtype

def _to_legacy_output_shapes(self):
return self._shape

def _to_legacy_output_classes(self):
return ops.Tensor

def _batch(self, batch_size):
return TensorStructure(
self._dtype,
tensor_shape.TensorShape([batch_size]).concatenate(self._shape))

def _unbatch(self):
if self._shape.ndims == 0:
raise ValueError("Unbatching a tensor is only supported for rank >= 1")
return TensorStructure(self._dtype, self._shape[1:])


@tf_export("data.experimental.SparseTensorStructure")
class SparseTensorStructure(Structure):
"""Represents structural information about a `tf.SparseTensor`."""

def __init__(self, dtype, dense_shape):
self._dtype = dtypes.as_dtype(dtype)
self._dense_shape = tensor_shape.as_shape(dense_shape)

def __eq__(self, other):
return (isinstance(other, SparseTensorStructure) and tensor_spec.TensorSpec(
self._dense_shape, self._dtype) == tensor_spec.TensorSpec(
other._dense_shape, other._dtype))

def __hash__(self):
return hash(tensor_spec.TensorSpec(self._dense_shape, self._dtype))

@property
def _flat_shapes(self):
# NOTE(mrry): The default flat shape of a boxed `SparseTensor` is `(3,)`,
# but a `SparseTensorStructure` can also represent a batch of boxed
# `SparseTensor` objects with shape `(?, 3)` (and batches of batches, etc.),
# so the flat shape must be unknown.
return [tensor_shape.unknown_shape(None)]

@property
def _flat_types(self):
return [dtypes.variant]

def is_compatible_with(self, other):
return (isinstance(other, SparseTensorStructure) and
self._dtype.is_compatible_with(other._dtype) and
self._dense_shape.is_compatible_with(other._dense_shape))

def _to_tensor_list(self, value):
return [sparse_ops.serialize_sparse(value, out_type=dtypes.variant)]

def _to_batched_tensor_list(self, value):
if self._dense_shape.merge_with(
tensor_util.constant_value_as_shape(value.dense_shape)).ndims == 0:
raise ValueError(
"Unbatching a sparse tensor is only supported for rank >= 1")
return [sparse_ops.serialize_many_sparse(value, out_type=dtypes.variant)]

def _from_tensor_list(self, flat_value):
if (len(flat_value) != 1 or flat_value[0].dtype != dtypes.variant or
not flat_value[0].shape.is_compatible_with(tensor_shape.vector(3))):
raise ValueError("SparseTensorStructure corresponds to a single "
"tf.variant vector of length 3.")
return self._from_compatible_tensor_list(flat_value)

def _from_compatible_tensor_list(self, flat_value):
ret = sparse_ops.deserialize_sparse(
flat_value[0], dtype=self._dtype, rank=self._dense_shape.ndims)
ret.indices.set_shape([None, self._dense_shape.ndims])
ret.dense_shape.set_shape([self._dense_shape.ndims])
return ret

@staticmethod
def from_value(value):
sparse_tensor = sparse_tensor_lib.SparseTensor.from_value(value)
return SparseTensorStructure(
sparse_tensor.dtype,
tensor_util.constant_value_as_shape(sparse_tensor.dense_shape))

def _to_legacy_output_types(self):
return self._dtype

def _to_legacy_output_shapes(self):
return self._dense_shape

def _to_legacy_output_classes(self):
return sparse_tensor_lib.SparseTensor

def _batch(self, batch_size):
return SparseTensorStructure(
self._dtype,
tensor_shape.TensorShape([batch_size]).concatenate(self._dense_shape))

def _unbatch(self):
if self._dense_shape.ndims == 0:
raise ValueError("Unbatching a tensor is only supported for rank >= 1")
return SparseTensorStructure(self._dtype, self._dense_shape[1:])


@tf_export("data.experimental.TensorArrayStructure")
class TensorArrayStructure(Structure):
"""Represents structural information about a `tf.TensorArray`."""

def __init__(self, dtype, element_shape, dynamic_size, infer_shape):
self._dtype = dtypes.as_dtype(dtype)
self._element_shape = tensor_shape.as_shape(element_shape)
self._dynamic_size = dynamic_size
self._infer_shape = infer_shape

def __eq__(self, other):
return (isinstance(other, TensorArrayStructure) and tensor_spec.TensorSpec(
self._element_shape, self._dtype) == tensor_spec.TensorSpec(
other._element_shape, other._dtype) and
self._dynamic_size == other._dynamic_size and
self._infer_shape == other._infer_shape)

def __hash__(self):
return hash((tensor_spec.TensorSpec(self._element_shape, self._dtype),
self._dynamic_size, self._infer_shape))

@property
def _flat_shapes(self):
# A TensorArray is represented via its variant object, which is a scalar.
return [tensor_shape.scalar()]

@property
def _flat_types(self):
return [dtypes.variant]

def is_compatible_with(self, other):
return (isinstance(other, TensorArrayStructure) and
self._dtype.is_compatible_with(other._dtype) and
self._element_shape.is_compatible_with(other._element_shape) and
self._dynamic_size == other._dynamic_size)

def _to_tensor_list(self, value):
if not isinstance(value, tensor_array_ops.TensorArray):
raise TypeError("value must be a TensorArray, but saw: {}"
.format(type(value)))
if value.flow is not None and value.flow.dtype == dtypes.variant:
return [value.flow]
else:
# Convert to a TF2-style TensorArray.
# TODO(ebrevdo): Add an "_as_variant" method to TensorArray class, or
# "implementation / as_variant" arg to TensorArray constructor.
with ops.name_scope("convert_tensor_array"):
flow = list_ops.tensor_list_from_tensor(
tensor=value.stack(), element_shape=value.element_shape)
return [flow]

def _to_batched_tensor_list(self, value):
raise NotImplementedError("TensorArrayStructure._to_batched_tensor_list")

def _from_tensor_list(self, flat_value):
if (len(flat_value) != 1 or flat_value[0].dtype != dtypes.variant or
not flat_value[0].shape.is_compatible_with(tensor_shape.scalar())):
raise ValueError("TensorArrayStructure corresponds to a single "
"tf.variant scalar.")
return self._from_compatible_tensor_list(flat_value)

def _from_compatible_tensor_list(self, flat_value):
# This will return a TF2 Graph-style TensorArray because flat_value[0] is
# a variant object. size == -1 implies unknown size.
ret = tensor_array_ops.TensorArray(
dtype=self._dtype,
flow=flat_value[0],
dynamic_size=self._dynamic_size,
infer_shape=self._infer_shape)
ret._element_shape = [self._element_shape]
return ret

@staticmethod
def from_value(value):
if not isinstance(value, tensor_array_ops.TensorArray):
raise TypeError("Expected value to be a TensorArray, but saw: {}".
format(type(value)))

return TensorArrayStructure(
dtype=value.dtype,
element_shape=value.element_shape,
dynamic_size=value.dynamic_size,
infer_shape=value._infer_shape)

def _to_legacy_output_types(self):
return self._dtype

def _to_legacy_output_shapes(self):
# Sneak the dynamic_size and infer_shape values into the legacy shape.
return (tensor_shape.matrix(self._dynamic_size, self._infer_shape)
.concatenate(self._element_shape))

def _to_legacy_output_classes(self):
return tensor_array_ops.TensorArray

def _batch(self, batch_size):
raise NotImplementedError("TensorArrayStructure._batch")

def _unbatch(self):
raise NotImplementedError("TensorArrayStructure._unbatch")


@tf_export("data.experimental.RaggedTensorStructure")
class RaggedTensorStructure(Structure):
"""Represents structural information about a `tf.RaggedTensor`."""

def __init__(self, dtype, shape, ragged_rank):
self._dtype = dtypes.as_dtype(dtype)
self._shape = tensor_shape.as_shape(shape)
self._ragged_rank = ragged_rank

def __eq__(self, other):
return (isinstance(other, RaggedTensorStructure) and tensor_spec.TensorSpec(
self._shape, self._dtype) == tensor_spec.TensorSpec(
other._shape, other._dtype) and
self._ragged_rank == other._ragged_rank)

def __hash__(self):
return hash((tensor_spec.TensorSpec(self._shape, self._dtype),
self._ragged_rank))

@property
def _flat_shapes(self):
# A list of shapes matching the shapes of `self._to_tensor_list()`.
# NOTE(mishragaurav): The default flat shape of a boxed `RaggedTensor` is
# `[]` (scalar), but a `RaggedTensorStructure` can also represent a batch of
# boxed `RaggedTensor` objects with shape `(?)` (and batches of batches,
# etc.), so the flat shape must be unknown.
return [tensor_shape.unknown_shape(None)]

@property
def _flat_types(self):
return [dtypes.variant]

def is_compatible_with(self, other):
return (isinstance(other, RaggedTensorStructure) and
self._dtype.is_compatible_with(other._dtype) and
self._shape.is_compatible_with(other._shape) and
self._ragged_rank == other._ragged_rank)

def _to_tensor_list(self, value):
return [value._to_variant()]

def _to_batched_tensor_list(self, value):
return [value._to_variant(batched_input=True)]

def _from_tensor_list(self, flat_value):
if (len(flat_value) != 1 or flat_value[0].dtype != dtypes.variant):
raise ValueError("RaggedTensorStructure corresponds to a single "
"tf.variant scalar.")
return self._from_compatible_tensor_list(flat_value)

def _from_compatible_tensor_list(self, flat_value):
if self._ragged_rank <= 0:
raise ValueError(
"ragged_rank must be greater than zero. Found ragged_rank: %d" %
self._ragged_rank)
result = ragged_tensor.RaggedTensor._from_variant(
flat_value[0], dtype=self._dtype, output_ragged_rank=self._ragged_rank)
if self._shape.ndims is not None:
outer_dim = tensor_shape.dimension_value(self._shape[0])
if outer_dim is not None:
result.row_splits.set_shape([outer_dim + 1])
result.flat_values.set_shape(
tensor_shape.TensorShape([None]).concatenate(
self._shape[1 + self._ragged_rank:]))
return result

@staticmethod
def from_value(value):
return RaggedTensorStructure(value.dtype, value.shape, value.ragged_rank)

def _to_legacy_output_types(self):
return self._dtype

def _to_legacy_output_shapes(self):
return self._shape

def _to_legacy_output_classes(self):
return self

def _batch(self, batch_size):
return RaggedTensorStructure(
self._dtype,
tensor_shape.TensorShape([batch_size]).concatenate(self._shape),
self._ragged_rank + 1)

def _unbatch(self):
# Note: Any ragged_rank is allowed here because the dataset could be
# subsequently batched again. Errors are handled in
# RaggedTensorStructure._from_compatible_tensor_list()
return RaggedTensorStructure(self._dtype, self._shape[1:],
self._ragged_rank - 1)
# Re-register SparseTensorValue -- it's a subclass of tuple, but we don't
# want the NestedStructure registration to take precedence.
type_spec.register_type_spec_from_value_converter(
sparse_tensor_lib.SparseTensorValue,
sparse_tensor_lib.SparseTensorSpec.from_value)

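The wrappers defined earlier in this file keep the legacy `tf.data.experimental.*Structure` names exported while delegating to TypeSpec classes; note the swapped argument order called out in the comment near the top of the file. A small sketch of the correspondence (the dtype and shape values are illustrative):

    from tensorflow.python.data.util import structure
    from tensorflow.python.framework import dtypes, tensor_spec

    # Legacy name and legacy (dtype, shape) argument order ...
    legacy = structure.TensorStructure(dtypes.float32, [3])
    # ... now simply returns a TensorSpec, which uses (shape, dtype).
    assert legacy == tensor_spec.TensorSpec([3], dtypes.float32)
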
@@ -29,7 +29,9 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variables
@@ -50,16 +52,16 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,

# pylint: disable=g-long-lambda,protected-access
@parameterized.named_parameters(
("Tensor", lambda: constant_op.constant(37.0), structure.TensorStructure,
("Tensor", lambda: constant_op.constant(37.0), tensor_spec.TensorSpec,
[dtypes.float32], [[]]),
("TensorArray", lambda: tensor_array_ops.TensorArray(
dtype=dtypes.float32, element_shape=(3,), size=0),
structure.TensorArrayStructure, [dtypes.variant], [None, 3]),
tensor_array_ops.TensorArraySpec, [dtypes.variant], [[]]),
("SparseTensor", lambda: sparse_tensor.SparseTensor(
indices=[[3, 4]], values=[-1], dense_shape=[4, 5]),
structure.SparseTensorStructure, [dtypes.variant], [None]),
sparse_tensor.SparseTensorSpec, [dtypes.variant], [None]),
("RaggedTensor", lambda: ragged_factory_ops.constant([[1, 2], [], [4]]),
structure.RaggedTensorStructure, [dtypes.variant], [None]),
ragged_tensor.RaggedTensorSpec, [dtypes.variant], [None]),
("Nested_0",
lambda: (constant_op.constant(37.0), constant_op.constant([1, 2, 3])),
structure.NestedStructure, [dtypes.float32, dtypes.int32], [[], [3]]),
@ -80,13 +82,15 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
|
||||
def testFlatStructure(self, value_fn, expected_structure, expected_types,
|
||||
expected_shapes):
|
||||
value = value_fn()
|
||||
s = structure.Structure.from_value(value)
|
||||
s = type_spec.type_spec_from_value(value)
|
||||
self.assertIsInstance(s, expected_structure)
|
||||
self.assertEqual(expected_types, s._flat_types)
|
||||
self.assertLen(s._flat_shapes, len(expected_shapes))
|
||||
for expected, actual in zip(expected_shapes, s._flat_shapes):
|
||||
self.assertTrue(actual.is_compatible_with(expected))
|
||||
self.assertTrue(
|
||||
tensor_shape.as_shape(expected).is_compatible_with(actual))
|
||||
if expected is None:
|
||||
self.assertEqual(actual.ndims, None)
|
||||
else:
|
||||
self.assertEqual(actual.as_list(), expected)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("Tensor", lambda: constant_op.constant(37.0), lambda: [
|
||||
@ -162,15 +166,15 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
|
||||
original_value = original_value_fn()
|
||||
compatible_values = compatible_values_fn()
|
||||
incompatible_values = incompatible_values_fn()
|
||||
s = structure.Structure.from_value(original_value)
|
||||
s = type_spec.type_spec_from_value(original_value)
|
||||
for compatible_value in compatible_values:
|
||||
self.assertTrue(
|
||||
s.is_compatible_with(
|
||||
structure.Structure.from_value(compatible_value)))
|
||||
type_spec.type_spec_from_value(compatible_value)))
|
||||
for incompatible_value in incompatible_values:
|
||||
self.assertFalse(
|
||||
s.is_compatible_with(
|
||||
structure.Structure.from_value(incompatible_value)))
|
||||
type_spec.type_spec_from_value(incompatible_value)))
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("Tensor",
|
||||
@ -215,8 +219,8 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
|
||||
def testStructureFromValueEquality(self, value1_fn, value2_fn,
|
||||
*not_equal_value_fns):
|
||||
# pylint: disable=g-generic-assert
|
||||
s1 = structure.Structure.from_value(value1_fn())
|
||||
s2 = structure.Structure.from_value(value2_fn())
|
||||
s1 = type_spec.type_spec_from_value(value1_fn())
|
||||
s2 = type_spec.type_spec_from_value(value2_fn())
|
||||
self.assertEqual(s1, s1) # check __eq__ operator.
|
||||
self.assertEqual(s1, s2) # check __eq__ operator.
|
||||
self.assertFalse(s1 != s1) # check __ne__ operator.
|
||||
@ -224,7 +228,7 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
|
||||
self.assertEqual(hash(s1), hash(s1))
|
||||
self.assertEqual(hash(s1), hash(s2))
|
||||
for value_fn in not_equal_value_fns:
|
||||
s3 = structure.Structure.from_value(value_fn())
|
||||
s3 = type_spec.type_spec_from_value(value_fn())
|
||||
self.assertNotEqual(s1, s3) # check __ne__ operator.
|
||||
self.assertNotEqual(s2, s3) # check __ne__ operator.
|
||||
self.assertFalse(s1 == s3) # check __eq_ operator.
|
||||
@ -272,9 +276,9 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
|
||||
}),
|
||||
)
|
||||
def testHash(self, value1_fn, value2_fn, value3_fn):
|
||||
s1 = structure.Structure.from_value(value1_fn())
|
||||
s2 = structure.Structure.from_value(value2_fn())
|
||||
s3 = structure.Structure.from_value(value3_fn())
|
||||
s1 = type_spec.type_spec_from_value(value1_fn())
|
||||
s2 = type_spec.type_spec_from_value(value2_fn())
|
||||
s3 = type_spec.type_spec_from_value(value3_fn())
|
||||
self.assertEqual(hash(s1), hash(s1))
|
||||
self.assertEqual(hash(s1), hash(s2))
|
||||
self.assertNotEqual(hash(s1), hash(s3))
|
||||
@ -314,7 +318,7 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
|
||||
)
|
||||
def testRoundTripConversion(self, value_fn):
|
||||
value = value_fn()
|
||||
s = structure.Structure.from_value(value)
|
||||
s = type_spec.type_spec_from_value(value)
|
||||
|
||||
def maybe_stack_ta(v):
|
||||
if isinstance(v, tensor_array_ops.TensorArray):
|
||||
@ -344,7 +348,7 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
|
||||
|
||||
def preserveStaticShape(self):
|
||||
rt = ragged_factory_ops.constant([[1, 2], [], [3]])
|
||||
rt_s = structure.Structure.from_value(rt)
|
||||
rt_s = type_spec.type_spec_from_value(rt)
|
||||
rt_after = rt_s._from_tensor_list(rt_s._to_tensor_list(rt))
|
||||
self.assertEqual(rt_after.row_splits.shape.as_list(),
|
||||
rt.row_splits.shape.as_list())
|
||||
@ -352,7 +356,7 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
|
||||
|
||||
st = sparse_tensor.SparseTensor(
|
||||
indices=[[3, 4]], values=[-1], dense_shape=[4, 5])
|
||||
st_s = structure.Structure.from_value(st)
|
||||
st_s = type_spec.type_spec_from_value(st)
|
||||
st_after = st_s._from_tensor_list(st_s._to_tensor_list(st))
|
||||
self.assertEqual(st_after.indices.shape.as_list(),
|
||||
[None, 2])
|
||||
@ -367,19 +371,19 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
|
||||
# 2. Using one structure to restructre a flattened value with an
|
||||
# incompatible structure fails.
|
||||
value_tensor = constant_op.constant(42.0)
|
||||
s_tensor = structure.Structure.from_value(value_tensor)
|
||||
s_tensor = type_spec.type_spec_from_value(value_tensor)
|
||||
flat_tensor = s_tensor._to_tensor_list(value_tensor)
|
||||
|
||||
value_sparse_tensor = sparse_tensor.SparseTensor(
|
||||
indices=[[0, 0]], values=[1], dense_shape=[1, 1])
|
||||
s_sparse_tensor = structure.Structure.from_value(value_sparse_tensor)
|
||||
s_sparse_tensor = type_spec.type_spec_from_value(value_sparse_tensor)
|
||||
flat_sparse_tensor = s_sparse_tensor._to_tensor_list(value_sparse_tensor)
|
||||
|
||||
value_nest = {
|
||||
"a": constant_op.constant(37.0),
|
||||
"b": constant_op.constant([1, 2, 3])
|
||||
}
|
||||
s_nest = structure.Structure.from_value(value_nest)
|
||||
s_nest = type_spec.type_spec_from_value(value_nest)
|
||||
flat_nest = s_nest._to_tensor_list(value_nest)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
@ -391,38 +395,34 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
        r"dtype.*float32.* and shape \(\)"):
      s_tensor._to_tensor_list(value_nest)

    with self.assertRaisesRegexp(TypeError, "Input must be a SparseTensor"):
    with self.assertRaisesRegexp(
        TypeError, "Neither a SparseTensor nor SparseTensorValue"):
      s_sparse_tensor._to_tensor_list(value_tensor)

    with self.assertRaisesRegexp(TypeError, "Input must be a SparseTensor"):
    with self.assertRaisesRegexp(
        TypeError, "Neither a SparseTensor nor SparseTensorValue"):
      s_sparse_tensor._to_tensor_list(value_nest)

    with self.assertRaisesRegexp(
        ValueError, "Tensor.* not compatible with the nested structure "
        ".*TensorStructure.*TensorStructure"):
        ".*TensorSpec.*TensorSpec"):
      s_nest._to_tensor_list(value_tensor)

    with self.assertRaisesRegexp(
        ValueError, "SparseTensor.* not compatible with the nested structure "
        ".*TensorStructure.*TensorStructure"):
        ".*TensorSpec.*TensorSpec"):
      s_nest._to_tensor_list(value_sparse_tensor)

    with self.assertRaisesRegexp(
        ValueError, r"Cannot convert.*with dtype.*float32.* and shape \(\)"):
    with self.assertRaisesRegexp(ValueError, r"Incompatible input:"):
      s_tensor._from_tensor_list(flat_sparse_tensor)

    with self.assertRaisesRegexp(
        ValueError, "TensorStructure corresponds to a single tf.Tensor."):
    with self.assertRaisesRegexp(ValueError, "Incompatible input: "):
      s_tensor._from_tensor_list(flat_nest)

    with self.assertRaisesRegexp(
        ValueError, "SparseTensorStructure corresponds to a single tf.variant "
        "vector of length 3."):
    with self.assertRaisesRegexp(ValueError, "Incompatible input: "):
      s_sparse_tensor._from_tensor_list(flat_tensor)

    with self.assertRaisesRegexp(
        ValueError, "SparseTensorStructure corresponds to a single tf.variant "
        "vector of length 3."):
    with self.assertRaisesRegexp(ValueError, "Incompatible input: "):
      s_sparse_tensor._from_tensor_list(flat_nest)

    with self.assertRaisesRegexp(
@ -445,7 +445,7 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
        "a": constant_op.constant(37.0),
        "b": constant_op.constant([1, 2, 3])
    }
    s_0 = structure.Structure.from_value(value_0)
    s_0 = type_spec.type_spec_from_value(value_0)
    flat_s_0 = s_0._to_tensor_list(value_0)

    # `value_1` has compatible nested structure with `value_0`, but different
@ -457,7 +457,7 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
            sparse_tensor.SparseTensor(
                indices=[[0, 0]], values=[1], dense_shape=[1, 1])
    }
    s_1 = structure.Structure.from_value(value_1)
    s_1 = type_spec.type_spec_from_value(value_1)
    flat_s_1 = s_1._to_tensor_list(value_1)

    # `value_2` has incompatible nested structure with `value_0` and `value_1`.
@ -469,27 +469,27 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
            sparse_tensor.SparseTensor(
                indices=[[3, 4]], values=[-1], dense_shape=[4, 5]))
    }
    s_2 = structure.Structure.from_value(value_2)
    s_2 = type_spec.type_spec_from_value(value_2)
    flat_s_2 = s_2._to_tensor_list(value_2)

    with self.assertRaisesRegexp(
        ValueError, "SparseTensor.* not compatible with the nested structure "
        ".*TensorStructure"):
        ValueError, ".*SparseTensor.* not compatible with the nested structure "
        ".*TensorSpec"):
      s_0._to_tensor_list(value_1)

    with self.assertRaisesRegexp(
        ValueError, "SparseTensor.*SparseTensor.* not compatible with the "
        "nested structure .*TensorStructure"):
        ValueError, ".*SparseTensor.*SparseTensor.* not compatible with the "
        "nested structure .*TensorSpec"):
      s_0._to_tensor_list(value_2)

    with self.assertRaisesRegexp(
        ValueError, "Tensor.* not compatible with the nested structure "
        ".*SparseTensorStructure"):
        ValueError, ".*Tensor.* not compatible with the nested structure "
        ".*SparseTensorSpec"):
      s_1._to_tensor_list(value_0)

    with self.assertRaisesRegexp(
        ValueError, "SparseTensor.*SparseTensor.* not compatible with the "
        "nested structure .*TensorStructure"):
        ValueError, ".*SparseTensor.*SparseTensor.* not compatible with the "
        "nested structure .*TensorSpec"):
      s_0._to_tensor_list(value_2)

    # NOTE(mrry): The repr of the dictionaries is not sorted, so the regexp
@ -497,29 +497,27 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
    # adding a deterministic repr for these error messages (among other
    # improvements).
    with self.assertRaisesRegexp(
        ValueError, "Tensor.*Tensor.* not compatible with the nested structure "
        ".*(TensorStructure.*SparseTensorStructure.*SparseTensorStructure|"
        "SparseTensorStructure.*SparseTensorStructure.*TensorStructure)"):
        ValueError,
        ".*Tensor.*Tensor.* not compatible with the nested structure "
        ".*(TensorSpec.*SparseTensorSpec.*SparseTensorSpec|"
        "SparseTensorSpec.*SparseTensorSpec.*TensorSpec)"):
      s_2._to_tensor_list(value_0)

    with self.assertRaisesRegexp(
        ValueError, "(Tensor.*SparseTensor|SparseTensor.*Tensor).* "
        "not compatible with the nested structure .*"
        "(TensorStructure.*SparseTensorStructure.*SparseTensorStructure|"
        "SparseTensorStructure.*SparseTensorStructure.*TensorStructure)"):
        "(TensorSpec.*SparseTensorSpec.*SparseTensorSpec|"
        "SparseTensorSpec.*SparseTensorSpec.*TensorSpec)"):
      s_2._to_tensor_list(value_1)

    with self.assertRaisesRegexp(
        ValueError, r"Cannot convert.*with dtype.*int32.* and shape \(3,\)"):
    with self.assertRaisesRegexp(ValueError, r"Incompatible input:"):
      s_0._from_tensor_list(flat_s_1)

    with self.assertRaisesRegexp(
        ValueError, "Expected 2 flat values in NestedStructure but got 3."):
      s_0._from_tensor_list(flat_s_2)

    with self.assertRaisesRegexp(
        ValueError, "SparseTensorStructure corresponds to a single tf.variant "
        "vector of length 3."):
    with self.assertRaisesRegexp(ValueError, "Incompatible input: "):
      s_1._from_tensor_list(flat_s_0)

    with self.assertRaisesRegexp(
@ -552,9 +550,9 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
           [True, False, 2, 2]), tensor_array_ops.TensorArray,
       structure.TensorArrayStructure(
           dtypes.int32, [2, 2], dynamic_size=True, infer_shape=False)),
      ("RaggedTensor", dtypes.int32, tensor_shape.matrix(2, 2),
       structure.RaggedTensorStructure(dtypes.int32, [2, 2], 1),
       structure.RaggedTensorStructure(dtypes.int32, [2, 2], 1)),
      ("RaggedTensor", dtypes.int32, tensor_shape.matrix(2, None),
       structure.RaggedTensorStructure(dtypes.int32, [2, None], 1),
       structure.RaggedTensorStructure(dtypes.int32, [2, None], 1)),
      ("Nested", {
          "a": dtypes.float32,
          "b": (dtypes.int32, dtypes.string)
@ -576,8 +574,7 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
                                output_classes, expected_structure):
    actual_structure = structure.convert_legacy_structure(
        output_types, output_shapes, output_classes)
    self.assertTrue(expected_structure.is_compatible_with(actual_structure))
    self.assertTrue(actual_structure.is_compatible_with(expected_structure))
    self.assertEqual(actual_structure, expected_structure)

  def testNestedNestedStructure(self):
    # Although `Structure.from_value()` will not construct one, a nested
@ -639,10 +636,7 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
  def testBatch(self, element_structure, batch_size,
                expected_batched_structure):
    batched_structure = element_structure._batch(batch_size)
    self.assertTrue(
        batched_structure.is_compatible_with(expected_batched_structure))
    self.assertTrue(
        expected_batched_structure.is_compatible_with(batched_structure))
    self.assertEqual(batched_structure, expected_batched_structure)

  @parameterized.named_parameters(
      ("Tensor", structure.TensorStructure(dtypes.float32, [32]),
@ -656,11 +650,11 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
       structure.SparseTensorStructure(dtypes.float32, [None, 4]),
       structure.SparseTensorStructure(dtypes.float32, [4])),
      ("RaggedTensor",
       structure.RaggedTensorStructure(dtypes.float32, [32, 4, None], 2),
       structure.RaggedTensorStructure(dtypes.float32, [4, None], 1)),
       structure.RaggedTensorStructure(dtypes.float32, [32, None, None], 2),
       structure.RaggedTensorStructure(dtypes.float32, [None, None], 1)),
      ("RaggedTensorUnknown",
       structure.RaggedTensorStructure(dtypes.float32, [None, None, 4], 2),
       structure.RaggedTensorStructure(dtypes.float32, [None, 4], 1)),
       structure.RaggedTensorStructure(dtypes.float32, [None, None, None], 2),
       structure.RaggedTensorStructure(dtypes.float32, [None, None], 1)),
      ("Nested", structure.NestedStructure({
          "a": structure.TensorStructure(dtypes.float32, [128]),
          "b": (structure.SparseTensorStructure(dtypes.int32, [128, 2, 2]),
@ -672,10 +666,7 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
  )
  def testUnbatch(self, element_structure, expected_unbatched_structure):
    unbatched_structure = element_structure._unbatch()
    self.assertTrue(
        unbatched_structure.is_compatible_with(expected_unbatched_structure))
    self.assertTrue(
        expected_unbatched_structure.is_compatible_with(unbatched_structure))
    self.assertEqual(unbatched_structure, expected_unbatched_structure)

  # pylint: disable=g-long-lambda
  @parameterized.named_parameters(
@ -697,7 +688,7 @@ class StructureTest(test_base.DatasetTestBase, parameterized.TestCase,
  )
  def testToBatchedTensorList(self, value_fn, element_0_fn):
    batched_value = value_fn()
    s = structure.Structure.from_value(batched_value)
    s = type_spec.type_spec_from_value(batched_value)
    batched_tensor_list = s._to_batched_tensor_list(batched_value)

    # The batch dimension is 2 for all of the test cases.
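For orientation, the encode/decode round trip these tests exercise looks roughly like this sketch (not part of the diff; it relies on the private _to_tensor_list/_from_tensor_list methods shown above and assumes eager execution):

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import type_spec

value = constant_op.constant([1.0, 2.0, 3.0])

# Build a TypeSpec describing the value, flatten it to a tensor list, and
# rebuild an equivalent value from that list.
spec = type_spec.type_spec_from_value(value)
flat = spec._to_tensor_list(value)        # a list holding one float32 tensor of shape (3,)
restored = spec._from_tensor_list(flat)   # a tensor compatible with `spec`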
@ -29,6 +29,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_conversion_registry
from tensorflow.python.framework import tensor_like
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import type_spec
from tensorflow.python.util.lazy_loader import LazyLoader
from tensorflow.python.util.tf_export import tf_export

@ -175,6 +176,66 @@ IndexedSlicesValue = collections.namedtuple(
    "IndexedSlicesValue", ["values", "indices", "dense_shape"])


# TODO(b/133606651) Export this as tf.IndexedSlicesSpec.
class IndexedSlicesSpec(type_spec.TypeSpec):
  """Type specification for a `tf.IndexedSlices`."""

  __slots__ = ["_shape", "_values_dtype", "_indices_dtype",
               "_dense_shape_dtype", "_indices_shape"]

  value_type = property(lambda self: IndexedSlices)

  def __init__(self, shape=None, dtype=dtypes.float32,
               indices_dtype=dtypes.int64, dense_shape_dtype=None,
               indices_shape=None):
    """Constructs a type specification for a `tf.IndexedSlices`.

    Args:
      shape: The dense shape of the `IndexedSlices`, or `None` to allow any
        dense shape.
      dtype: `tf.DType` of values in the `IndexedSlices`.
      indices_dtype: `tf.DType` of the `indices` in the `IndexedSlices`. One
        of `tf.int32` or `tf.int64`.
      dense_shape_dtype: `tf.DType` of the `dense_shape` in the
        `IndexedSlices`. One of `tf.int32`, `tf.int64`, or `None` (if the
        `IndexedSlices` has no `dense_shape` tensor).
      indices_shape: The shape of the `indices` component, which indicates
        how many slices are in the `IndexedSlices`.
    """
    self._shape = tensor_shape.as_shape(shape)
    self._values_dtype = dtypes.as_dtype(dtype)
    self._indices_dtype = dtypes.as_dtype(indices_dtype)
    if dense_shape_dtype is None:
      self._dense_shape_dtype = None
    else:
      self._dense_shape_dtype = dtypes.as_dtype(dense_shape_dtype)
    self._indices_shape = tensor_shape.as_shape(indices_shape)

  def _serialize(self):
    return (self._shape, self._values_dtype, self._indices_dtype,
            self._dense_shape_dtype, self._indices_shape)

  @property
  def _component_specs(self):
    value_shape = self._indices_shape.concatenate(self._shape[1:])
    specs = [
        tensor_spec.TensorSpec(value_shape, self._values_dtype),
        tensor_spec.TensorSpec(self._indices_shape, self._indices_dtype)]
    if self._dense_shape_dtype is not None:
      specs.append(
          tensor_spec.TensorSpec([self._shape.ndims], self._dense_shape_dtype))
    return specs

  def _to_components(self, value):
    if value.dense_shape is None:
      return (value.values, value.indices)
    else:
      return (value.values, value.indices, value.dense_shape)

  def _from_components(self, tensor_list):
    return IndexedSlices(*tensor_list)


@tf_export(v1=["convert_to_tensor_or_indexed_slices"])
def convert_to_tensor_or_indexed_slices(value, dtype=None, name=None):
  """Converts the given object to a `Tensor` or an `IndexedSlices`.
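A rough illustration of the component encoding above (a sketch, not part of the diff; the spec class is not exported yet, so it is reached through the framework ops module):

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops

# Describe an IndexedSlices over a [10, 3] dense shape that holds two slices.
spec = ops.IndexedSlicesSpec(
    shape=[10, 3], dtype=dtypes.float32,
    dense_shape_dtype=dtypes.int64, indices_shape=[2])

# _component_specs yields (values, indices, dense_shape) TensorSpecs, roughly:
#   TensorSpec([2, 3], float32), TensorSpec([2], int64), TensorSpec([2], int64)
print(spec._component_specs)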
@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function

import collections
import numpy as np

from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import composite_tensor
@ -28,6 +29,8 @@ from tensorflow.python.framework import tensor_like
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import gen_sparse_ops
from tensorflow.python.util.tf_export import tf_export

# pylint: disable=protected-access
@ -117,6 +120,7 @@ class SparseTensor(_TensorLike, composite_tensor.CompositeTensor):
      dense_shape: A 1-D int64 tensor of shape `[ndims]`.
    """
    if isinstance(indices, tensor_spec.TensorSpec):
      # TODO(b/133606651) Remove this code path -- replaced by TypeSpec.
      if not isinstance(values, tensor_spec.TensorSpec):
        raise TypeError("Expected values to be a TensorSpec")
      if not isinstance(dense_shape, tensor_spec.TensorSpec):
@ -271,6 +275,120 @@ tf_export(v1=["SparseTensorValue"])(SparseTensorValue)
pywrap_tensorflow.RegisterType("SparseTensorValue", SparseTensorValue)


# TODO(b/133606651) Export this as tf.SparseTensorSpec.
class SparseTensorSpec(type_spec.BatchableTypeSpec):
  """Type specification for a `tf.SparseTensor`."""

  __slots__ = ["_dense_shape", "_dtype"]

  value_type = property(lambda self: SparseTensor)

  def __init__(self, dense_shape=None, dtype=dtypes.float32):
    """Constructs a type specification for a `tf.SparseTensor`.

    Args:
      dense_shape: The dense shape of the `SparseTensor`, or `None` to allow
        any dense shape.
      dtype: `tf.DType` of values in the `SparseTensor`.
    """
    self._dense_shape = tensor_shape.as_shape(dense_shape)
    self._dtype = dtypes.as_dtype(dtype)

  def _serialize(self):
    return (self._dense_shape, self._dtype)

  @property
  def _component_specs(self):
    rank = self._dense_shape.ndims
    num_values = None
    return [
        tensor_spec.TensorSpec([num_values, rank], dtypes.int64),
        tensor_spec.TensorSpec([num_values], self._dtype),
        tensor_spec.TensorSpec([rank], dtypes.int64)]

  def _to_components(self, value):
    if isinstance(value, SparseTensorValue):
      value = SparseTensor.from_value(value)
    return [value.indices, value.values, value.dense_shape]

  def _from_components(self, tensor_list):
    return SparseTensor(*tensor_list)

  # The SparseTensorSpec tensor_list encoding uses (de)serialize_sparse ops
  # to (un)box the component tensors in a way that allows for batching &
  # unbatching.
  @property
  def _flat_tensor_specs(self):
    # NOTE(mrry): The default flat shape of a boxed `SparseTensor` is `(3,)`,
    # but a `SparseTensorSpec` can also represent a batch of boxed
    # `SparseTensor` objects with shape `(..., 3)` (and batches of batches,
    # etc.), so the flat shape must be unknown.
    return [tensor_spec.TensorSpec(None, dtypes.variant)]

  def _to_tensor_list(self, value):
    value = SparseTensor.from_value(value)
    return [gen_sparse_ops.serialize_sparse(
        value.indices, value.values, value.dense_shape,
        out_type=dtypes.variant)]

  def _to_batched_tensor_list(self, value):
    dense_shape = tensor_util.constant_value_as_shape(value.dense_shape)
    if self._dense_shape.merge_with(dense_shape).ndims == 0:
      raise ValueError(
          "Unbatching a sparse tensor is only supported for rank >= 1")
    return [gen_sparse_ops.serialize_many_sparse(
        value.indices, value.values, value.dense_shape,
        out_type=dtypes.variant)]

  def _from_compatible_tensor_list(self, tensor_list):
    tensor_list = gen_sparse_ops.deserialize_sparse(tensor_list[0], self._dtype)
    result = SparseTensor(*tensor_list)
    rank = self._dense_shape.ndims
    result.indices.set_shape([None, rank])
    result.dense_shape.set_shape([rank])
    return result

  def _batch(self, batch_size):
    return SparseTensorSpec(
        tensor_shape.TensorShape([batch_size]).concatenate(self._dense_shape),
        self._dtype)

  def _unbatch(self):
    if self._dense_shape.ndims == 0:
      raise ValueError("Unbatching a tensor is only supported for rank >= 1")
    return SparseTensorSpec(self._dense_shape[1:], self._dtype)

  def _to_legacy_output_types(self):
    return self._dtype

  def _to_legacy_output_shapes(self):
    return self._dense_shape

  def _to_legacy_output_classes(self):
    return SparseTensor

  @classmethod
  def from_value(cls, value):
    if isinstance(value, SparseTensor):
      return cls(value.shape, value.dtype)
    if isinstance(value, SparseTensorValue):
      if isinstance(value.values, np.ndarray):
        return cls(value.dense_shape, value.values.dtype)
      else:
        return cls.from_value(SparseTensor.from_value(value))
    else:
      raise TypeError("Expected SparseTensor or SparseTensorValue")


# TODO(b/133606651) Delete the SparseTensor registration when CompositeTensor
# is updated to define a _type_spec field (since registration will be
# automatic). Do *not* delete the SparseTensorValue registration.
type_spec.register_type_spec_from_value_converter(
    SparseTensor, SparseTensorSpec.from_value)
type_spec.register_type_spec_from_value_converter(
    SparseTensorValue, SparseTensorSpec.from_value)


@tf_export(v1=["convert_to_tensor_or_sparse_tensor"])
def convert_to_tensor_or_sparse_tensor(value, dtype=None, name=None):
  """Converts value to a `SparseTensor` or `Tensor`.
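A small sketch of the variant boxing described in the NOTE above (not part of the diff; assumes eager execution):

from tensorflow.python.framework import sparse_tensor

st = sparse_tensor.SparseTensor(
    indices=[[3, 4]], values=[-1], dense_shape=[4, 5])

spec = sparse_tensor.SparseTensorSpec.from_value(st)   # SparseTensorSpec([4, 5], int32)

# The flat encoding is a single variant scalar that boxes all three components,
# so the same machinery can also describe batches of boxed SparseTensors.
boxed = spec._to_tensor_list(st)                 # [variant tensor]
restored = spec._from_compatible_tensor_list(boxed)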
@ -20,15 +20,17 @@ from __future__ import print_function

import numpy as np

from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import type_spec
from tensorflow.python.util.tf_export import tf_export


@tf_export("TensorSpec")
class TensorSpec(object):
class TensorSpec(type_spec.BatchableTypeSpec):
  """Describes a tf.Tensor.

  Metadata for describing the `tf.Tensor` objects accepted or returned
@ -97,7 +99,8 @@ class TensorSpec(object):
    Returns:
      True if spec_or_tensor is compatible with self.
    """
    return (self._dtype.is_compatible_with(spec_or_tensor.dtype) and
    return (isinstance(spec_or_tensor, (TensorSpec, ops.Tensor)) and
            self._dtype.is_compatible_with(spec_or_tensor.dtype) and
            self._shape.is_compatible_with(spec_or_tensor.shape))

  def __repr__(self):
@ -108,17 +111,80 @@ class TensorSpec(object):
    return hash((self._shape_tuple, self.dtype))

  def __eq__(self, other):
    return (self._shape_tuple == other._shape_tuple  # pylint: disable=protected-access
            and self.dtype == other.dtype
            and self._name == other._name)  # pylint: disable=protected-access
    # pylint: disable=protected-access
    return (type(self) is type(other) and
            self._shape_tuple == other._shape_tuple
            and self._dtype == other._dtype
            and self._name == other._name)

  def __ne__(self, other):
    return not self == other

  def __reduce__(self):
    return TensorSpec, (self._shape, self._dtype, self._name)
  value_type = property(lambda self: ops.Tensor)

  def most_specific_compatible_type(self, other):
    if (type(self) is not type(other)) or (self._dtype != other.dtype):
      raise ValueError("Types are not compatible: %r vs %r" % (self, other))
    shape = self._shape.most_specific_compatible_shape(other.shape)
    name = self._name if self._name == other.name else None
    return TensorSpec(shape, self._dtype, name)

  def _serialize(self):
    return (self._shape, self._dtype, self._name)

  _component_specs = property(lambda self: self)

  def _to_components(self, value):
    try:
      value = ops.convert_to_tensor(value, self._dtype)
    except (TypeError, ValueError):
      raise ValueError("Value %r is not convertible to a tensor with dtype %s "
                       "and shape %s." % (value, self._dtype, self._shape))
    if not value.shape.is_compatible_with(self._shape):
      raise ValueError("Value %r is not convertible to a tensor with dtype %s "
                       "and shape %s." % (value, self._dtype, self._shape))
    return value

  def _from_components(self, components):
    return components

  def _from_compatible_tensor_list(self, tensor_list):
    # TODO(b/112266545): It would be cleaner to create a new `ensure_shape()`
    # op here and return that, instead of mutating the input's shape using
    # `Tensor.set_shape()`. However, that would add extra ops, which could
    # impact performance. When this bug is resolved, we should be able to add
    # the `ensure_shape()` ops and optimize them away using contextual shape
    # information.
    assert len(tensor_list) == 1
    tensor_list[0].set_shape(self._shape)
    return tensor_list[0]

  def _to_batchable_tensor_list(self, value, batched=False):
    if batched and self._shape.merge_with(value.shape).ndims == 0:
      raise ValueError("Unbatching a tensor is only supported for rank >= 1")
    return self._to_components(value)

  def _batch(self, batch_size):
    return TensorSpec(
        tensor_shape.TensorShape([batch_size]).concatenate(self._shape),
        self._dtype)

  def _unbatch(self):
    if self._shape.ndims == 0:
      raise ValueError("Unbatching a tensor is only supported for rank >= 1")
    return TensorSpec(self._shape[1:], self._dtype)

  def _to_legacy_output_types(self):
    return self._dtype

  def _to_legacy_output_shapes(self):
    return self._shape

  def _to_legacy_output_classes(self):
    return ops.Tensor


# TODO(b/133606651): Should is_compatible_with check min/max bounds?
class BoundedTensorSpec(TensorSpec):
  """A `TensorSpec` that specifies minimum and maximum values.

@ -216,3 +282,15 @@ class BoundedTensorSpec(TensorSpec):
  def __reduce__(self):
    return BoundedTensorSpec, (self._shape, self._dtype, self._minimum,
                               self._maximum, self._name)

  def _serialize(self):
    return (self._shape, self._dtype, self._minimum, self._maximum, self._name)


pywrap_tensorflow.RegisterType("TensorSpec", TensorSpec)


# Note: we do not include Tensor names when constructing TypeSpecs.
type_spec.register_type_spec_from_value_converter(
    ops.Tensor,
    lambda tensor: TensorSpec(tensor.shape, tensor.dtype))
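A brief sketch of what the BatchableTypeSpec methods added above provide (not part of the diff):

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_spec

spec = tensor_spec.TensorSpec([None, 3], dtypes.float32)

batched = spec._batch(8)      # TensorSpec(shape=(8, None, 3), dtype=float32)
unbatched = spec._unbatch()   # TensorSpec(shape=(3,), dtype=float32)

# most_specific_compatible_type relaxes dimensions on which the specs disagree.
other = tensor_spec.TensorSpec([4, 3], dtypes.float32)
merged = spec.most_specific_compatible_type(other)   # TensorSpec(shape=(None, 3))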
@ -24,14 +24,21 @@ import six

from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util.lazy_loader import LazyLoader
from tensorflow.python.util.tf_export import tf_export

# Use LazyLoader to avoid circular dependencies.
tensor_spec = LazyLoader(
    "tensor_spec", globals(),
    "tensorflow.python.framework.tensor_spec")
ops = LazyLoader(
    "ops", globals(),
    "tensorflow.python.framework.ops")


# TODO(b/133606651) Export this as "TypeSpec" (or experimental.TypeSpec?) and
# deprecate the tf.data.experimental.Structure endpoint.
@ -81,7 +88,7 @@ class TypeSpec(object):
    # compatibility using their `is_compatible_with` method; and all other
    # types are considered compatible if they are equal).
    if not isinstance(spec_or_value, TypeSpec):
      spec_or_value = TypeSpec.from_value(spec_or_value)
      spec_or_value = type_spec_from_value(spec_or_value)
    if type(self) is not type(spec_or_value):
      return False
    return self.__is_compatible(self._serialize(),
@ -111,44 +118,6 @@ class TypeSpec(object):
        self._serialize(), other._serialize())  # pylint: disable=protected-access
    return self._deserialize(merged)

  @staticmethod
  def from_value(value):
    """Returns a `TypeSpec` that represents the given `value`.

    Args:
      value: A value that can be accepted or returned by TensorFlow APIs.

    Returns:
      A `TypeSpec` that is compatible with `value`.

    Raises:
      TypeError: If a TypeSpec cannot be built for `value`, because its type
        is not supported.
    """
    if isinstance(value, ops.Tensor):
      # Note: we do not include Tensor names when constructing TypeSpecs.
      return tensor_spec.TensorSpec(value.shape, value.dtype)

    # TODO(b/133606651) Uncomment the following two lines when CompositeTensor
    # is updated to define a _type_spec field:
    #
    # if isinstance(value, composite_tensor.CompositeTensor):
    #   return value._type_spec  # pylint: disable=protected-access

    for entry in reversed(TypeSpec._TYPE_CONVERSION_FUNCTION_REGISTRY):
      type_object, converter_fn, allow_subclass = entry
      if ((type(value) is type_object) or  # pylint: disable=unidiomatic-typecheck
          (allow_subclass and isinstance(value, type_object))):
        return converter_fn(value)

    # Fallback: try converting value to a tensor.
    try:
      tensor = ops.convert_to_tensor(value)
    except (ValueError, TypeError):
      raise TypeError("Could not build a TypeSpec for %r with type %s" %
                      (value, type(value).__name__))
    return TypeSpec.from_value(tensor)

  # === Component encoding for values ===

  @abc.abstractmethod
@ -331,7 +300,7 @@ class TypeSpec(object):

  def __check_tensor_list(self, tensor_list):
    expected = self._flat_tensor_specs
    specs = [tensor_spec.TensorSpec.from_tensor(t) for t in tensor_list]
    specs = [type_spec_from_value(t) for t in tensor_list]
    if len(specs) != len(expected):
      raise ValueError("Incompatible input: wrong number of tensors")
    for i, (s1, s2) in enumerate(zip(specs, expected)):
@ -346,8 +315,7 @@ class TypeSpec(object):

  def __make_cmp_key(self, value):
    """Converts `value` to a hashable key."""
    if isinstance(value, (int, float, bool, dtypes.DType, TypeSpec,
                          tensor_spec.TensorSpec)):
    if isinstance(value, (int, float, bool, dtypes.DType, TypeSpec)):
      return value
    elif isinstance(value, compat.bytes_or_text_types):
      return value
@ -386,8 +354,7 @@ class TypeSpec(object):
    if isinstance(a, tuple):
      return (len(a) == len(b) and
              all(TypeSpec.__is_compatible(x, y) for (x, y) in zip(a, b)))
    elif isinstance(a, (TypeSpec, tensor_shape.TensorShape, dtypes.DType,
                        tensor_spec.TensorSpec)):
    elif isinstance(a, (TypeSpec, tensor_shape.TensorShape, dtypes.DType)):
      return a.is_compatible_with(b)
    else:
      return a == b
@ -428,12 +395,6 @@ class TypeSpec(object):
                   for (x, y) in zip(a, b))
    elif isinstance(a, tensor_shape.TensorShape):
      return a.most_specific_compatible_shape(b)
    elif isinstance(a, tensor_spec.TensorSpec):
      if a.dtype != b.dtype:
        raise ValueError("Types are not compatible: %r vs %r" % (a, b))
      shape = a.shape.most_specific_compatible_shape(b.shape)
      name = a.name if a.name == b.name else None
      return tensor_spec.TensorSpec(shape, a.dtype, name)
    elif isinstance(a, list):
      raise AssertionError("_serialize() should not return list values.")
    elif isinstance(a, TypeSpec):
@ -491,6 +452,55 @@ class BatchableTypeSpec(TypeSpec):
    return tensor_list


def type_spec_from_value(value):
  """Returns a `TypeSpec` that represents the given `value`.

  Args:
    value: A value that can be accepted or returned by TensorFlow APIs.

  Returns:
    A `TypeSpec` that is compatible with `value`.

  Raises:
    TypeError: If a TypeSpec cannot be built for `value`, because its type
      is not supported.
  """
  spec = _type_spec_from_value(value)
  if spec is not None:
    return spec

  # Fallback: try converting value to a tensor.
  try:
    tensor = ops.convert_to_tensor(value)
    spec = _type_spec_from_value(tensor)
    if spec is not None:
      return spec
  except (ValueError, TypeError):
    pass

  raise TypeError("Could not build a TypeSpec for %r with type %s" %
                  (value, type(value).__name__))


def _type_spec_from_value(value):
  """Returns a `TypeSpec` that represents the given `value`."""
  if isinstance(value, ops.Tensor):
    # Note: we do not include Tensor names when constructing TypeSpecs.
    return tensor_spec.TensorSpec(value.shape, value.dtype)

  # TODO(b/133606651) Uncomment the following two lines when CompositeTensor
  # is updated to define a _type_spec field:
  #
  # if isinstance(value, composite_tensor.CompositeTensor):
  #   return value._type_spec  # pylint: disable=protected-access

  for entry in reversed(_TYPE_CONVERSION_FUNCTION_REGISTRY):
    type_object, converter_fn, allow_subclass = entry
    if ((type(value) is type_object) or  # pylint: disable=unidiomatic-typecheck
        (allow_subclass and isinstance(value, type_object))):
      return converter_fn(value)


_TYPE_CONVERSION_FUNCTION_REGISTRY = []
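For context, the registry consulted above is the same one that the Tensor and SparseTensor registrations elsewhere in this change hook into; a minimal sketch of registering a converter for a hypothetical wrapper type (all names below are made up for illustration):

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import type_spec


class Scaled(object):
  """A hypothetical value type wrapping a single tensor and a Python scale."""

  def __init__(self, tensor, scale):
    self.tensor = tensor
    self.scale = scale


# Hypothetical converter: a Scaled value is described by its tensor's spec.
type_spec.register_type_spec_from_value_converter(
    Scaled, lambda value: type_spec.type_spec_from_value(value.tensor))

# Dispatch now goes through the registry before the convert_to_tensor fallback.
spec = type_spec.type_spec_from_value(
    Scaled(constant_op.constant([1.0, 2.0]), 0.5))   # TensorSpec([2], float32)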
@ -278,5 +278,10 @@ class TypeSpecTest(test_util.TensorFlowTestCase, parameterized.TestCase):
        (tensor_shape.TensorShape([5, 3]), dtypes.int32,
         tensor_shape.TensorShape(None), dtypes.bool, "red"))

  def testFromValue(self):
    value = TwoTensors([1, 2, 3], [1.0, 2.0], "red")
    spec = type_spec.type_spec_from_value(value)
    self.assertEqual(spec, TwoTensorsSpec.from_value(value))

if __name__ == "__main__":
  googletest.main()
@ -695,6 +695,29 @@ class LSTMV2Test(keras_parameterized.TestCase):
        },
        input_shape=(num_samples, timesteps, embedding_dim))

  def test_bidirectional(self):
    batch = 128
    timestep = 20
    vocab_size = 1000
    model = keras.Sequential([
        keras.layers.Embedding(vocab_size, 64),
        keras.layers.Bidirectional(rnn.LSTM(
            64, return_sequences=True)),
        keras.layers.Bidirectional(rnn.LSTM(32)),
        keras.layers.Dense(64, activation='relu'),
        keras.layers.Dense(1, activation='sigmoid')
    ])

    model.compile(loss='binary_crossentropy',
                  optimizer='adam',
                  metrics=['accuracy'])

    x = np.random.randint(0, vocab_size, size=(batch, timestep))
    y = np.random.randint(0, 1, size=(batch))
    model.fit(x, y, epochs=1, shuffle=False)
    model.evaluate(x, y)
    model.predict(x)


class LSTMLayerGraphOnlyTest(test.TestCase):

@ -803,7 +826,7 @@ class LSTMLayerGraphOnlyTest(test.TestCase):
      existing_loss = loss_value


class UnifiedLSTMPerformanceTest(test.Benchmark):
class UnifiedLSTMPerformanceTest(test.TestCase):

  def _measure_performance(self, test_config, model, x_train, y_train):
    batch = test_config['batch']
@ -913,29 +936,29 @@ class UnifiedLSTMPerformanceTest(test.Benchmark):
    cudnn_vs_unified = cudnn_sec_per_epoch / unified_lstm_sec_per_epoch
    unified_vs_normal = normal_lstm_sec_per_epoch / unified_lstm_sec_per_epoch

    self.report_benchmark(name='keras_cudnn_lstm_' + mode,
                          wall_time=cudnn_sec_per_epoch,
                          iters=test_config['epoch'],
                          extras=test_config)
    self.report_benchmark(name='keras_unified_lstm_' + mode,
                          wall_time=unified_lstm_sec_per_epoch,
                          iters=test_config['epoch'],
                          extras=test_config)
    self.report_benchmark(name='keras_canonical_lstm_' + mode,
                          wall_time=normal_lstm_sec_per_epoch,
                          iters=test_config['epoch'],
                          extras=test_config)
    # self.report_benchmark(name='keras_cudnn_lstm_' + mode,
    #                       wall_time=cudnn_sec_per_epoch,
    #                       iters=test_config['epoch'],
    #                       extras=test_config)
    # self.report_benchmark(name='keras_unified_lstm_' + mode,
    #                       wall_time=unified_lstm_sec_per_epoch,
    #                       iters=test_config['epoch'],
    #                       extras=test_config)
    # self.report_benchmark(name='keras_canonical_lstm_' + mode,
    #                       wall_time=normal_lstm_sec_per_epoch,
    #                       iters=test_config['epoch'],
    #                       extras=test_config)

    logging.info('Expect the performance of Unified LSTM is within 80% of '
                 'CuDNN LSTM, got {0:.2f}%'.format(cudnn_vs_unified * 100))
    logging.info('Expect the performance of Unified LSTM is more than 5 times'
                 ' of normal LSTM, got {0:.2f}'.format(unified_vs_normal))

  def benchmark_performance_graph(self):
  def test_performance_graph(self):
    with context.graph_mode(), session_lib.Session(config=_config):
      self._benchmark_performance_with_standard_cudnn_impl()

  def benchmark_performance_eager(self):
  def test_performance_eager(self):
    with context.eager_mode():
      self._benchmark_performance_with_standard_cudnn_impl()
@ -398,28 +398,23 @@ class GRU(recurrent.DropoutRNNCellMixin, recurrent.GRU):
      else:
        last_output, outputs, new_h, runtime = standard_gru(**normal_gru_kwargs)
    else:
      api_name = 'gru_' + str(uuid.uuid4())
      defun_standard_gru = _generate_defun_backend(
          api_name, _CPU_DEVICE_NAME, standard_gru)
      defun_cudnn_gru = _generate_defun_backend(
          api_name, _GPU_DEVICE_NAME, cudnn_gru)
      # Call the normal GRU impl and register the CuDNN impl function. The
      # grappler will kick in during session execution to optimize the graph.
      last_output, outputs, new_h, runtime = defun_standard_gru(
          **normal_gru_kwargs)

      def register_cudnn_defun():
        function.register(defun_cudnn_gru, **cudnn_gru_kwargs)
        # Return some dummy value since tf.cond requires a return value.
        return 0
      if mask is None:
        register_cudnn_defun()
        last_output, outputs, new_h, runtime = gru_with_backend_selection(
            normal_gru_kwargs, cudnn_gru_kwargs)
      else:
        # Only when seq_right_padded=True, CuDNN kernel can support that
        # properly.
        control_flow_ops.cond(is_sequence_right_padded(mask, self.time_major),
                              true_fn=register_cudnn_defun,
                              false_fn=lambda: 0)
        def with_mask_support():
          # TODO(b/134702514): Change to use backend selection.
          # return gru_with_backend_selection(normal_gru_kwargs,
          #                                   cudnn_gru_kwargs)
          return standard_gru(**normal_gru_kwargs)
        def without_mask_support():
          return standard_gru(**normal_gru_kwargs)

        last_output, outputs, new_h, runtime = control_flow_ops.cond(
            is_sequence_right_padded(mask, self.time_major),
            true_fn=with_mask_support,
            false_fn=without_mask_support)

    states = [new_h]
    return last_output, outputs, runtime, states

@ -532,22 +527,20 @@ def cudnn_gru(inputs, init_h, kernel, recurrent_kernel, bias, mask, time_major,

  if mask is not None:
    sequence_length = calculate_sequence_by_mask(mask, time_major)
    if go_backwards:
      inputs = array_ops.reverse_sequence_v2(inputs, sequence_length,
                                             seq_axis=0, batch_axis=1)
    outputs, h, _, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv3(
        inputs, input_h=init_h, input_c=0, params=params, is_training=True,
        rnn_mode='gru', sequence_lengths=sequence_length)
  else:
    # Fill the array with shape [batch] with value of max timesteps.
    sequence_length = array_ops.fill([array_ops.shape(inputs)[1]],
                                     array_ops.shape(inputs)[0])
  if go_backwards:
    inputs = array_ops.reverse_sequence_v2(inputs, sequence_length, seq_axis=0,
                                           batch_axis=1)
    if go_backwards:
      # Reverse axis 0 since the input is already converted to time major.
      inputs = array_ops.reverse(inputs, axis=[0])
    outputs, h, _, _ = gen_cudnn_rnn_ops.cudnn_rnn(
        inputs, input_h=init_h, input_c=0, params=params, is_training=True,
        rnn_mode='gru')

  outputs, h, _, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv3(
      inputs,
      input_h=init_h,
      input_c=0,
      params=params,
      is_training=True,
      rnn_mode='gru',
      sequence_lengths=sequence_length)
  last_output = outputs[-1]
  if not time_major:
    outputs = array_ops.transpose(outputs, perm=[1, 0, 2])
@ -565,6 +558,44 @@ def cudnn_gru(inputs, init_h, kernel, recurrent_kernel, bias, mask, time_major,
  return last_output, outputs, h, _runtime(_RUNTIME_GPU)


def gru_with_backend_selection(normal_gru_params, cudnn_gru_params):
  """Call the GRU with optimized backend kernel selection.

  Under the hood, this function will create two TF functions: one with the
  most generic kernel, which can run on all devices, and a second one with the
  CuDNN-specific kernel, which can only run on GPU.

  The first function will be called with normal_gru_params, while the second
  function is not called, but only registered in the graph. The Grappler will
  do the proper graph rewrite and swap the optimized TF function based on the
  device placement.

  Args:
    normal_gru_params: Dict, parameters for the generic TF function.
    cudnn_gru_params: Dict, parameters for the CuDNN specific TF function.

  Returns:
    List of output tensors, same as standard_gru.
  """
  # Each time a `tf.function` is called, we will give it a unique
  # identifiable API name, so that Grappler won't get confused when it
  # sees multiple GRU layers added into same graph, and it will be able
  # to pair up the different implementations across them.
  api_name = 'gru_' + str(uuid.uuid4())
  defun_standard_gru = _generate_defun_backend(
      api_name, _CPU_DEVICE_NAME, standard_gru)
  defun_cudnn_gru = _generate_defun_backend(
      api_name, _GPU_DEVICE_NAME, cudnn_gru)

  # Call the normal GRU impl and register the CuDNN impl function. The
  # grappler will kick in during session execution to optimize the graph.
  last_output, outputs, new_h, runtime = defun_standard_gru(
      **normal_gru_params)

  function.register(defun_cudnn_gru, **cudnn_gru_params)
  return last_output, outputs, new_h, runtime


@keras_export('keras.layers.LSTMCell', v1=[])
class LSTMCell(recurrent.LSTMCell):
  """Cell class for the LSTM layer.
@ -876,33 +907,25 @@ class LSTM(recurrent.DropoutRNNCellMixin, recurrent.LSTM):
        last_output, outputs, new_h, new_c, runtime = standard_lstm(
            **normal_lstm_kwargs)
    else:
      # Each time a `tf.function` is called, we will give it a unique
      # identifiable API name, so that Grappler won't get confused when it
      # sees multiple LSTM layers added into same graph, and it will be able
      # to pair up the different implementations across them.
      api_name = 'lstm_' + str(uuid.uuid4())
      defun_standard_lstm = _generate_defun_backend(
          api_name, _CPU_DEVICE_NAME, standard_lstm)
      defun_cudnn_lstm = _generate_defun_backend(
          api_name, _GPU_DEVICE_NAME, cudnn_lstm)

      # Call the normal LSTM impl and register the CuDNN impl function. The
      # grappler will kick in during session execution to optimize the graph.
      last_output, outputs, new_h, new_c, runtime = defun_standard_lstm(
          **normal_lstm_kwargs)

      def register_cudnn_defun():
        function.register(defun_cudnn_lstm, **cudnn_lstm_kwargs)
        # Return some dummy value since tf.cond requires a return value.
        return 0
      if mask is None:
        register_cudnn_defun()
        (last_output, outputs,
         new_h, new_c, runtime) = lstm_with_backend_selection(
             normal_lstm_kwargs, cudnn_lstm_kwargs)
      else:
        # Only when seq_right_padded=True, CuDNN kernel can support that
        # properly.
        control_flow_ops.cond(is_sequence_right_padded(mask, self.time_major),
                              true_fn=register_cudnn_defun,
                              false_fn=lambda: 0)
        def with_mask_support():
          # TODO(b/134702514): Change to use backend selection.
          # return lstm_with_backend_selection(normal_lstm_kwargs,
          #                                    cudnn_lstm_kwargs)
          return standard_lstm(**normal_lstm_kwargs)
        def without_mask_support():
          return standard_lstm(**normal_lstm_kwargs)

        (last_output, outputs,
         new_h, new_c, runtime) = control_flow_ops.cond(
             is_sequence_right_padded(mask, self.time_major),
             true_fn=with_mask_support,
             false_fn=without_mask_support)

    states = [new_h, new_c]

    if self.stateful:
@ -1076,25 +1099,31 @@ def cudnn_lstm(inputs, init_h, init_c, kernel, recurrent_kernel, bias, mask,
  # so that mathematically it is same as the canonical LSTM implementation.
  full_bias = array_ops.concat((array_ops.zeros_like(bias), bias), 0)

  if mask is not None:
    sequence_length = calculate_sequence_by_mask(mask, time_major)
  else:
    # Fill the array with shape [batch] with value of max timesteps.
    sequence_length = array_ops.fill([array_ops.shape(inputs)[1]],
                                     array_ops.shape(inputs)[0])
  if go_backwards:
    inputs = array_ops.reverse_sequence_v2(inputs, sequence_length, seq_axis=0,
                                           batch_axis=1)

  params = _canonical_to_params(
      weights=weights,
      biases=array_ops.split(full_bias, 8),
      shape=constant_op.constant([-1]),
      transpose_weights=True)

  outputs, h, c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv3(
      inputs, input_h=init_h, input_c=init_c, params=params, is_training=True,
      rnn_mode='lstm', sequence_lengths=sequence_length)
  if mask is not None:
    sequence_length = calculate_sequence_by_mask(mask, time_major)
    if go_backwards:
      inputs = array_ops.reverse_sequence_v2(inputs, sequence_length,
                                             seq_axis=0, batch_axis=1)
    outputs, h, c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv3(
        inputs, input_h=init_h, input_c=init_c, params=params, is_training=True,
        rnn_mode='lstm', sequence_lengths=sequence_length)
  else:
    # # Fill the array with shape [batch] with value of max timesteps.
    # sequence_length = array_ops.fill([array_ops.shape(inputs)[1]],
    #                                  array_ops.shape(inputs)[0])
    if go_backwards:
      # Reverse axis 0 since the input is already converted to time major.
      inputs = array_ops.reverse(inputs, axis=[0])
    outputs, h, c, _ = gen_cudnn_rnn_ops.cudnn_rnn(
        inputs, input_h=init_h, input_c=init_c, params=params, is_training=True,
        rnn_mode='lstm')

  last_output = outputs[-1]
  if not time_major:
    outputs = array_ops.transpose(outputs, perm=[1, 0, 2])
@ -1112,6 +1141,44 @@ def cudnn_lstm(inputs, init_h, init_c, kernel, recurrent_kernel, bias, mask,
  return last_output, outputs, h, c, _runtime(_RUNTIME_GPU)


def lstm_with_backend_selection(normal_lstm_params, cudnn_lstm_params):
  """Call the LSTM with optimized backend kernel selection.

  Under the hood, this function will create two TF functions: one with the
  most generic kernel, which can run on all devices, and a second one with the
  CuDNN-specific kernel, which can only run on GPU.

  The first function will be called with normal_lstm_params, while the second
  function is not called, but only registered in the graph. The Grappler will
  do the proper graph rewrite and swap the optimized TF function based on the
  device placement.

  Args:
    normal_lstm_params: Dict, parameters for the generic TF function.
    cudnn_lstm_params: Dict, parameters for the CuDNN specific TF function.

  Returns:
    List of output tensors, same as standard_lstm.
  """
  # Each time a `tf.function` is called, we will give it a unique
  # identifiable API name, so that Grappler won't get confused when it
  # sees multiple LSTM layers added into same graph, and it will be able
  # to pair up the different implementations across them.
  api_name = 'lstm_' + str(uuid.uuid4())
  defun_standard_lstm = _generate_defun_backend(
      api_name, _CPU_DEVICE_NAME, standard_lstm)
  defun_cudnn_lstm = _generate_defun_backend(
      api_name, _GPU_DEVICE_NAME, cudnn_lstm)

  # Call the normal LSTM impl and register the CuDNN impl function. The
  # grappler will kick in during session execution to optimize the graph.
  last_output, outputs, new_h, new_c, runtime = defun_standard_lstm(
      **normal_lstm_params)

  function.register(defun_cudnn_lstm, **cudnn_lstm_params)
  return last_output, outputs, new_h, new_c, runtime


def is_sequence_right_padded(mask, time_major):
  """Check the mask tensor and see if it is right padded.

@ -1186,6 +1253,9 @@ def _generate_defun_backend(unique_api_name, preferred_device, func):
  function_attributes = {
      _DEFUN_API_NAME_ATTRIBUTE: unique_api_name,
      _DEFUN_DEVICE_ATTRIBUTE: preferred_device,
      # TODO(b/133178886): The function is auto inlined in eager context, which
      # makes grappler fail to do the optimization. Force it to not inline here.
      '_noinline': True,
  }
  return function.defun_with_attributes(func=func,
                                        attributes=function_attributes)
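The device-based swapping described in the docstrings above comes down to registering two implementations under one API name; a rough, self-contained sketch of that pattern follows (not the layer's actual code path; the attribute key strings below are assumed to match the _DEFUN_API_NAME_ATTRIBUTE and _DEFUN_DEVICE_ATTRIBUTE constants used in this file):

from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import math_ops


def _generic_impl(x):
  return math_ops.reduce_sum(x * 2.0)


def _gpu_impl(x):
  return math_ops.reduce_sum(x + x)


# Both implementations share one API name; the key names are assumptions.
attrs_cpu = {'api_implements': 'double_sum_demo',
             'api_preferred_device': 'CPU',
             '_noinline': True}
attrs_gpu = {'api_implements': 'double_sum_demo',
             'api_preferred_device': 'GPU',
             '_noinline': True}

defun_generic = function.defun_with_attributes(func=_generic_impl,
                                               attributes=attrs_cpu)
defun_gpu = function.defun_with_attributes(func=_gpu_impl,
                                           attributes=attrs_gpu)

x = constant_op.constant([1.0, 2.0, 3.0])

# Call the generic implementation and only register the GPU one; Grappler's
# implementation selector can then swap them based on device placement.
result = defun_generic(x=x)
function.register(defun_gpu, x=x)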
@ -18,6 +18,8 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.framework import composite_tensor
|
||||
from tensorflow.python.framework import constant_op
|
||||
@ -27,6 +29,7 @@ from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.framework import type_spec
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
@ -241,6 +244,7 @@ class RaggedTensor(composite_tensor.CompositeTensor):
|
||||
"of the factory methods instead (e.g., "
|
||||
"RaggedTensor.from_row_lengths())")
|
||||
|
||||
# TODO(b/133606651) Remove is_tensor_spec code paths -- replaced by TypeSpec
|
||||
is_tensor_spec = isinstance(row_splits, tensor_spec.TensorSpec)
|
||||
if is_tensor_spec:
|
||||
if not (isinstance(values, tensor_spec.TensorSpec) or
|
||||
@ -1916,6 +1920,161 @@ def match_row_splits_dtypes(*tensors, **kwargs):
return tensors


#===============================================================================
# RaggedTensorSpec
#===============================================================================
# TODO(b/133606651) Export this as tf.RaggedTensorSpec.
class RaggedTensorSpec(type_spec.BatchableTypeSpec):
"""Type specification for a `tf.RaggedTensor`."""

__slots__ = ["_shape", "_dtype", "_ragged_rank", "_row_splits_dtype"]

value_type = property(lambda self: RaggedTensor)

def __init__(self, shape=None, dtype=dtypes.float32, ragged_rank=None,
row_splits_dtype=dtypes.int64):
"""Constructs a type specification for a `tf.RaggedTensor`.

Args:
shape: The shape of the RaggedTensor, or `None` to allow any shape. If
a shape is specified, then all ragged dimensions must have size `None`.
dtype: `tf.DType` of values in the RaggedTensor.
ragged_rank: Python integer, the ragged rank of the RaggedTensor
to be described. Defaults to `shape.ndims - 1`.
row_splits_dtype: `dtype` for the RaggedTensor's `row_splits` tensor.
One of `tf.int32` or `tf.int64`.
"""
self._shape = tensor_shape.as_shape(shape)
self._dtype = dtypes.as_dtype(dtype)
self._row_splits_dtype = dtypes.as_dtype(row_splits_dtype)

rank = self._shape.ndims
if ragged_rank is None:
if rank is None:
raise ValueError("Must specify ragged_rank or "
"a shape with a known rank.")
ragged_rank = rank - 1
self._ragged_rank = ragged_rank
if not isinstance(self._ragged_rank, int):
raise TypeError("ragged_rank must be an int")

if rank is not None:
if ragged_rank >= rank:
raise ValueError("ragged_rank must be less than rank.")

def _serialize(self):
return (self._shape, self._dtype, self._ragged_rank, self._row_splits_dtype)

@property
def _component_specs(self):
if self._ragged_rank == 0:
return [tensor_spec.TensorSpec(self._shape, self._dtype)]

flat_values_shape = tensor_shape.TensorShape([None]).concatenate(
self._shape[self._ragged_rank + 1:])
outer_dim = tensor_shape.dimension_at_index(self._shape, 0)
outer_splits_shape = [None if outer_dim is None else outer_dim + 1]
inner_splits_spec = tensor_spec.TensorSpec([None], self._row_splits_dtype)

specs = (
[tensor_spec.TensorSpec(flat_values_shape, self._dtype),
tensor_spec.TensorSpec(outer_splits_shape, self._row_splits_dtype)] +
[inner_splits_spec for _ in range(self._ragged_rank - 1)])
return specs

def _to_components(self, value):
if is_ragged(value):
return [value.flat_values] + list(value.nested_row_splits)
else:
return [value]

def _from_components(self, tensor_list):
# Currently, Keras converts tensors to numpy and then calls from_components
# with those np.arrays. So if we see np.ndarrays, convert them to tensors.
# TODO(b/133606651) Update Keras to do something different here. Consider
# adding something like TypeSpec.from_numpy_components?
if isinstance(tensor_list[0], np.ndarray):
tensor_list = [ops.convert_to_tensor(t) for t in tensor_list]

result = tensor_list[0]
for row_splits in reversed(tensor_list[1:]):
result = RaggedTensor(result, row_splits, internal=True)
return result
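
For intuition, here is a small sketch (not part of the diff) of the component encoding that _to_components and _from_components implement for a rank-2, ragged-rank-1 tensor: the components are the flat values followed by one row_splits vector per ragged dimension. All names below come from the code above; the concrete values are made up.

# Sketch only: assumes eager mode and the RaggedTensorSpec defined above.
rt = RaggedTensor.from_row_splits(values=[1, 2, 3, 4], row_splits=[0, 2, 2, 4])
spec = RaggedTensorSpec(shape=[3, None], dtype=dtypes.int32, ragged_rank=1)
components = spec._to_components(rt)          # [flat_values, row_splits]
rebuilt = spec._from_components(components)   # == [[1, 2], [], [3, 4]]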

# The RaggedTensorSpec tensor_list encoding uses to/from_variant ops
# to (un)box the component tensors in a way that allows for batching &
# unbatching.
@property
def _flat_tensor_specs(self):
# NOTE(mishragaurav): The default flat shape of a boxed `RaggedTensor` is
# `[]` (scalar), but a `RaggedTensorSpec` can also represent a batch of
# boxed `RaggedTensor` objects with shape `(...)` (and batches of batches,
# etc.), so the flat shape must be unknown.
return [tensor_spec.TensorSpec(None, dtypes.variant)]

def _to_tensor_list(self, value):
# pylint: disable=protected-access
return [value._to_variant(batched_input=False)]

def _to_batched_tensor_list(self, value):
# pylint: disable=protected-access
return [value._to_variant(batched_input=True)]

def _from_compatible_tensor_list(self, tensor_list):
if self._ragged_rank <= 0:
raise ValueError(
"ragged_rank must be non-negative; got %s." % self._ragged_rank)
result = RaggedTensor._from_variant( # pylint: disable=protected-access
tensor_list[0], dtype=self._dtype,
output_ragged_rank=self._ragged_rank)
if self._shape.ndims is not None:
outer_dim = tensor_shape.dimension_value(self._shape[0])
if outer_dim is not None:
result.row_splits.set_shape([outer_dim + 1])
result.flat_values.set_shape(
tensor_shape.TensorShape([None]).concatenate(
self._shape[1 + self._ragged_rank:]))
return result

def _batch(self, batch_size):
return RaggedTensorSpec(
tensor_shape.TensorShape([batch_size]).concatenate(self._shape),
self._dtype,
self._ragged_rank + 1)

def _unbatch(self):
# Note: Negative ragged_rank is allowed here because the dataset could
# be subsequently batched again. Errors are handled in
# RaggedTensorSpec._from_compatible_tensor_list()
return RaggedTensorSpec(self._shape[1:], self._dtype,
self._ragged_rank - 1)
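
A brief sketch (again, not part of the change) of how _batch and _unbatch adjust the spec: batching prepends the batch dimension and raises the ragged rank by one, and unbatching undoes both.

# Sketch only.
spec = RaggedTensorSpec(shape=[None, None], dtype=dtypes.float32, ragged_rank=1)
batched = spec._batch(8)        # shape [8, None, None], ragged_rank == 2
restored = batched._unbatch()   # shape [None, None],    ragged_rank == 1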

def _to_legacy_output_types(self):
return self._dtype

def _to_legacy_output_shapes(self):
return self._shape

def _to_legacy_output_classes(self):
return self

@classmethod
def from_value(cls, value):
return cls(shape=value.shape,
dtype=value.values.dtype,
ragged_rank=value.ragged_rank,
row_splits_dtype=value.row_splits.dtype)


# TODO(b/133606651) Delete the RaggedTensor registration when CompositeTensor
# is updated to define a _type_spec field (since registration will be
# automatic). Do *not* delete the RaggedTensorValue registration.
type_spec.register_type_spec_from_value_converter(
RaggedTensor, RaggedTensorSpec.from_value)
type_spec.register_type_spec_from_value_converter(
ragged_tensor_value.RaggedTensorValue, RaggedTensorSpec.from_value)


#===============================================================================
# Convert value -> tensor
#===============================================================================

@ -29,7 +29,9 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import gen_control_flow_ops
@ -1342,3 +1344,115 @@ def _check_dtypes(value, dtype):
"in future versions of TensorFlow. Traceback:\n{}".format(
value, str(value.dtype), str(dtype),
"".join(traceback.format_stack())))


# TODO(b/133606651) Export this as tf.TensorArraySpec.
class TensorArraySpec(type_spec.TypeSpec):
"""Type specification for a `tf.TensorArray`."""

__slots__ = ["_element_shape", "_dtype", "_dynamic_size", "_infer_shape"]

value_type = property(lambda self: TensorArray)

def __init__(self, element_shape=None, dtype=dtypes.float32,
dynamic_size=False, infer_shape=True):
"""Constructs a type specification for a `tf.TensorArray`.

Args:
element_shape: The shape of each element in the `TensorArray`.
dtype: Data type of the `TensorArray`.
dynamic_size: Whether the `TensorArray` can grow past its initial size.
infer_shape: Whether shape inference is enabled.
"""
self._element_shape = tensor_shape.as_shape(element_shape)
self._dtype = dtypes.as_dtype(dtype)
self._dynamic_size = dynamic_size
self._infer_shape = infer_shape

def is_compatible_with(self, other):
# We check all fields *except* infer_shape.
# TODO(b/133606651) Verify that this is the correct behavior.

# pylint: disable=protected-access
if not isinstance(other, type_spec.TypeSpec):
other = type_spec.type_spec_from_value(other)
return (isinstance(other, TensorArraySpec) and
self._dtype.is_compatible_with(other._dtype) and
self._element_shape.is_compatible_with(other._element_shape) and
self._dynamic_size == other._dynamic_size)

def most_specific_compatible_type(self, other):
# TODO(b/133606651) Verify that this is the correct behavior for combining
# infer_shape values.

# pylint: disable=protected-access
if not self.is_compatible_with(other):
raise ValueError("Types are not compatible")
infer_shape = self._infer_shape or other._infer_shape
return TensorArraySpec(
self._element_shape.most_specific_compatible_shape(
other._element_shape),
self._dtype, self._dynamic_size, infer_shape)

def _serialize(self):
return (self._element_shape, self._dtype, self._dynamic_size,
self._infer_shape)

@property
def _component_specs(self):
return [tensor_spec.TensorSpec([], dtypes.variant)]

def _to_components(self, value):
if not isinstance(value, TensorArray):
raise TypeError("value must be a TensorArray, but saw: {}"
.format(type(value)))
if value.flow is not None and value.flow.dtype == dtypes.variant:
return [value.flow]
else:
# Convert to a TF2-style TensorArray.
# TODO(ebrevdo): Add an "_as_variant" method to TensorArray class, or
# "implementation / as_variant" arg to TensorArray constructor.
with ops.name_scope("convert_tensor_array"):
flow = list_ops.tensor_list_from_tensor(
tensor=value.stack(), element_shape=value.element_shape)
return [flow]

def _from_components(self, tensor_list):
# This will return a TF2 Graph-style TensorArray because tensor_list[0] is
# a variant object. size == -1 implies unknown size.
ret = TensorArray(
dtype=self._dtype,
flow=tensor_list[0],
dynamic_size=self._dynamic_size,
infer_shape=self._infer_shape)
ret._element_shape = [self._element_shape] # pylint: disable=protected-access
return ret

@staticmethod
def from_value(value):
if not isinstance(value, TensorArray):
raise TypeError("Expected value to be a TensorArray, but saw: {}".
format(type(value)))

return TensorArraySpec(
dtype=value.dtype,
element_shape=value.element_shape,
dynamic_size=value.dynamic_size,
infer_shape=value._infer_shape) # pylint: disable=protected-access

def _to_legacy_output_types(self):
return self._dtype

def _to_legacy_output_shapes(self):
# Sneak the dynamic_size and infer_shape values into the legacy shape.
return (tensor_shape.matrix(self._dynamic_size, self._infer_shape)
.concatenate(self._element_shape))
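
As a sketch of the encoding used just above (not part of the diff), the two booleans are carried as the first two dimensions of the legacy shape, followed by the element shape; the concrete spec below is made up.

# Sketch only: False/True show up as dimensions 0/1 of the legacy shape.
spec = TensorArraySpec(element_shape=[2, 3], dtype=dtypes.float32,
                       dynamic_size=False, infer_shape=True)
legacy_shape = spec._to_legacy_output_shapes()   # TensorShape([0, 1, 2, 3])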

def _to_legacy_output_classes(self):
return TensorArray


# Register the TypeSpec for TensorArray. If TensorArray is updated to be a
# CompositeTensor, then this registration can be deleted.
type_spec.register_type_spec_from_value_converter(
TensorArray, TensorArraySpec.from_value, allow_subclass=True)
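
To round this off, a usage sketch (not part of the diff) showing how a spec is derived from a TensorArray and how a value round-trips through its single variant component; the concrete TensorArray below is hypothetical and assumes eager mode.

# Sketch only.
ta = TensorArray(dtype=dtypes.float32, size=2, element_shape=[3])
ta = ta.write(0, [1.0, 2.0, 3.0]).write(1, [4.0, 5.0, 6.0])
spec = TensorArraySpec.from_value(ta)
components = spec._to_components(ta)             # a single variant tensor
ta_restored = spec._from_components(components)  # a TensorArray again
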
@ -1,6 +1,8 @@
path: "tensorflow.TensorSpec"
tf_class {
is_instance: "<class \'tensorflow.python.framework.tensor_spec.TensorSpec\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.BatchableTypeSpec\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
is_instance: "<type \'object\'>"
member {
name: "dtype"
@ -14,6 +16,10 @@ tf_class {
name: "shape"
mtype: "<type \'property\'>"
}
member {
name: "value_type"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'shape\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\'], "
@ -30,4 +36,8 @@ tf_class {
name: "is_compatible_with"
argspec: "args=[\'self\', \'spec_or_tensor\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "most_specific_compatible_type"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}
@ -1,11 +1,15 @@
path: "tensorflow.data.experimental.DatasetStructure"
tf_class {
is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetStructure\'>"
is_instance: "<class \'tensorflow.python.data.util.structure.Structure\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
is_instance: "<type \'object\'>"
member {
name: "value_type"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'element_structure\'], varargs=None, keywords=None, defaults=None"
argspec: "args=[\'self\', \'element_spec\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_value"
@ -13,6 +17,10 @@ tf_class {
}
member_method {
name: "is_compatible_with"
argspec: "args=[\'self\', \'spec_or_value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "most_specific_compatible_type"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}
@ -1,8 +1,13 @@
path: "tensorflow.data.experimental.NestedStructure"
tf_class {
is_instance: "<class \'tensorflow.python.data.util.structure.NestedStructure\'>"
is_instance: "<class \'tensorflow.python.data.util.structure.Structure\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.BatchableTypeSpec\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
is_instance: "<type \'object\'>"
member {
name: "value_type"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'nested_structure\'], varargs=None, keywords=None, defaults=None"
@ -15,4 +20,8 @@ tf_class {
name: "is_compatible_with"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "most_specific_compatible_type"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}
@ -1,8 +1,12 @@
path: "tensorflow.data.experimental.OptionalStructure"
tf_class {
is_instance: "<class \'tensorflow.python.data.ops.optional_ops.OptionalStructure\'>"
is_instance: "<class \'tensorflow.python.data.util.structure.Structure\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
is_instance: "<type \'object\'>"
member {
name: "value_type"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'value_structure\'], varargs=None, keywords=None, defaults=None"
@ -13,6 +17,10 @@ tf_class {
}
member_method {
name: "is_compatible_with"
argspec: "args=[\'self\', \'spec_or_value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "most_specific_compatible_type"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}
@ -1,18 +0,0 @@
path: "tensorflow.data.experimental.RaggedTensorStructure"
tf_class {
is_instance: "<class \'tensorflow.python.data.util.structure.RaggedTensorStructure\'>"
is_instance: "<class \'tensorflow.python.data.util.structure.Structure\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'dtype\', \'shape\', \'ragged_rank\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_value"
argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "is_compatible_with"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}
@ -1,18 +0,0 @@
path: "tensorflow.data.experimental.SparseTensorStructure"
tf_class {
is_instance: "<class \'tensorflow.python.data.util.structure.SparseTensorStructure\'>"
is_instance: "<class \'tensorflow.python.data.util.structure.Structure\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'dtype\', \'dense_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_value"
argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "is_compatible_with"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}
@ -1,16 +1,20 @@
path: "tensorflow.data.experimental.Structure"
tf_class {
is_instance: "<class \'tensorflow.python.data.util.structure.Structure\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
is_instance: "<type \'object\'>"
member {
name: "value_type"
mtype: "<class \'abc.abstractproperty\'>"
}
member_method {
name: "__init__"
}
member_method {
name: "from_value"
argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
name: "is_compatible_with"
argspec: "args=[\'self\', \'spec_or_value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "is_compatible_with"
name: "most_specific_compatible_type"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}
@ -1,18 +0,0 @@
path: "tensorflow.data.experimental.TensorArrayStructure"
tf_class {
is_instance: "<class \'tensorflow.python.data.util.structure.TensorArrayStructure\'>"
is_instance: "<class \'tensorflow.python.data.util.structure.Structure\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'dtype\', \'element_shape\', \'dynamic_size\', \'infer_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_value"
argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "is_compatible_with"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}
@ -1,18 +0,0 @@
path: "tensorflow.data.experimental.TensorStructure"
tf_class {
is_instance: "<class \'tensorflow.python.data.util.structure.TensorStructure\'>"
is_instance: "<class \'tensorflow.python.data.util.structure.Structure\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'dtype\', \'shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "from_value"
argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "is_compatible_with"
argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
}
}
@ -44,10 +44,6 @@ tf_module {
name: "OptionalStructure"
mtype: "<type \'type\'>"
}
member {
name: "RaggedTensorStructure"
mtype: "<type \'type\'>"
}
member {
name: "RandomDataset"
mtype: "<type \'type\'>"
@ -56,10 +52,6 @@ tf_module {
name: "Reducer"
mtype: "<type \'type\'>"
}
member {
name: "SparseTensorStructure"
mtype: "<type \'type\'>"
}
member {
name: "SqlDataset"
mtype: "<type \'type\'>"
@ -80,14 +72,6 @@ tf_module {
name: "TFRecordWriter"
mtype: "<type \'type\'>"
}
member {
name: "TensorArrayStructure"
mtype: "<type \'type\'>"
}
member {
name: "TensorStructure"
mtype: "<type \'type\'>"
}
member {
name: "ThreadingOptions"
mtype: "<type \'type\'>"
@ -100,6 +84,22 @@ tf_module {
name: "Counter"
argspec: "args=[\'start\', \'step\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\"], "
}
member_method {
name: "RaggedTensorStructure"
argspec: "args=[\'dtype\', \'shape\', \'ragged_rank\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "SparseTensorStructure"
argspec: "args=[\'dtype\', \'shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "TensorArrayStructure"
argspec: "args=[\'dtype\', \'element_shape\', \'dynamic_size\', \'infer_shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "TensorStructure"
argspec: "args=[\'dtype\', \'shape\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "bucket_by_sequence_length"
argspec: "args=[\'element_length_func\', \'bucket_boundaries\', \'bucket_batch_sizes\', \'padded_shapes\', \'padding_values\', \'pad_to_bucket_boundary\', \'no_padding\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'False\', \'False\'], "