Add support for RaggedTensor properties, instance methods, and class methods in the Keras functional API.
PiperOrigin-RevId: 337337772 Change-Id: I4b66330078049b13ef5c8eddfbb77d895a65a9d8
This commit is contained in:
parent
df6a423036
commit
9adbacfd41
tensorflow/python
@ -20,15 +20,20 @@ from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import func_graph
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.keras import keras_parameterized
|
||||
from tensorflow.python.keras import layers
|
||||
from tensorflow.python.keras import testing_utils
|
||||
from tensorflow.python.keras.engine import training
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
|
||||
class RaggedKerasTensorTest(keras_parameterized.TestCase):
|
||||
@ -89,6 +94,278 @@ class RaggedKerasTensorTest(keras_parameterized.TestCase):
|
||||
x = ragged_factory_ops.constant([[3, 4], [1, 2], [3, 5]])
|
||||
self.assertAllEqual(model(x), x / x)
|
||||
|
||||
@parameterized.parameters(
|
||||
{'property_name': 'values'},
|
||||
{'property_name': 'flat_values'},
|
||||
{'property_name': 'row_splits'},
|
||||
{'property_name': 'nested_row_splits'},
|
||||
)
|
||||
def test_instance_property(self, property_name):
|
||||
inp = layers.Input(shape=[None], ragged=True)
|
||||
out = getattr(inp, property_name)
|
||||
model = training.Model(inp, out)
|
||||
|
||||
x = ragged_factory_ops.constant([[3, 4], [1, 2], [3, 5]])
|
||||
expected_property = getattr(x, property_name)
|
||||
self.assertAllEqual(model(x), expected_property)
|
||||
|
||||
# Test that it works with serialization and deserialization as well
|
||||
model_config = model.get_config()
|
||||
model2 = training.Model.from_config(model_config)
|
||||
self.assertAllEqual(model2(x), expected_property)
|
||||
|
||||
@parameterized.parameters(
|
||||
{'name': 'value_rowids'},
|
||||
{'name': 'nested_value_rowids'},
|
||||
{'name': 'nrows'},
|
||||
{'name': 'row_starts'},
|
||||
{'name': 'row_limits'},
|
||||
{'name': 'row_lengths'},
|
||||
{'name': 'nested_row_lengths'},
|
||||
{'name': 'bounding_shape'},
|
||||
{
|
||||
'name': 'with_values',
|
||||
'args': [[1, 2, 3, 4, 5, 6]]
|
||||
},
|
||||
{
|
||||
'name': 'with_flat_values',
|
||||
'kwargs': {
|
||||
'new_values': [1, 2, 3, 4, 5, 6]
|
||||
}
|
||||
},
|
||||
{
|
||||
'name': 'with_row_splits_dtype',
|
||||
'kwargs': {
|
||||
'dtype': dtypes.int32
|
||||
}
|
||||
},
|
||||
{
|
||||
'name': 'merge_dims',
|
||||
'args': [0],
|
||||
'kwargs': {
|
||||
'inner_axis': 1
|
||||
}
|
||||
},
|
||||
{'name': 'to_tensor'},
|
||||
{'name': 'to_sparse'},
|
||||
)
|
||||
def test_instance_method(self, name, args=None, kwargs=None):
|
||||
if not args:
|
||||
args = []
|
||||
if not kwargs:
|
||||
kwargs = {}
|
||||
|
||||
inp = layers.Input(shape=[None], ragged=True)
|
||||
out = getattr(inp, name)(*args, **kwargs)
|
||||
model = training.Model(inp, out)
|
||||
|
||||
x = ragged_factory_ops.constant([[3, 4], [1, 2], [3, 5]])
|
||||
expected_property = getattr(x, name)(*args, **kwargs)
|
||||
# We expand composites before checking equality because
|
||||
# assertAllEqual otherwise wouldn't work for SparseTensor outputs
|
||||
for a, b in zip(nest.flatten(model(x), expand_composites=True),
|
||||
nest.flatten(expected_property, expand_composites=True)):
|
||||
self.assertAllEqual(a, b)
|
||||
|
||||
# Test that the model can serialize and deserialize as well
|
||||
model_config = model.get_config()
|
||||
model2 = training.Model.from_config(model_config)
|
||||
for a, b in zip(nest.flatten(model2(x), expand_composites=True),
|
||||
nest.flatten(expected_property, expand_composites=True)):
|
||||
self.assertAllEqual(a, b)
|
||||
|
||||
|
||||
class RaggedTensorClassMethodAsLayerTest(keras_parameterized.TestCase):
|
||||
|
||||
def test_from_value_rowids(self):
|
||||
inp = layers.Input(shape=[None])
|
||||
out = ragged_tensor.RaggedTensor.from_value_rowids(
|
||||
inp, value_rowids=[0, 0, 0, 0, 2, 2, 2, 3], nrows=5)
|
||||
model = training.Model(inp, out)
|
||||
|
||||
x = constant_op.constant([3, 1, 4, 1, 5, 9, 2, 6])
|
||||
expected = ragged_tensor.RaggedTensor.from_value_rowids(
|
||||
x, value_rowids=[0, 0, 0, 0, 2, 2, 2, 3], nrows=5)
|
||||
self.assertAllEqual(model(x), expected)
|
||||
|
||||
# Test that the model can serialize and deserialize as well
|
||||
model_config = model.get_config()
|
||||
model2 = training.Model.from_config(model_config)
|
||||
self.assertAllEqual(model2(x), expected)
|
||||
|
||||
def test_from_row_splits(self):
|
||||
inp = layers.Input(shape=[None])
|
||||
out = ragged_tensor.RaggedTensor.from_row_splits(
|
||||
inp, row_splits=[0, 4, 4, 7, 8, 8])
|
||||
model = training.Model(inp, out)
|
||||
|
||||
x = constant_op.constant([3, 1, 4, 1, 5, 9, 2, 6])
|
||||
expected = ragged_tensor.RaggedTensor.from_row_splits(
|
||||
x, row_splits=[0, 4, 4, 7, 8, 8])
|
||||
self.assertAllEqual(model(x), expected)
|
||||
|
||||
# Test that the model can serialize and deserialize as well
|
||||
model_config = model.get_config()
|
||||
model2 = training.Model.from_config(model_config)
|
||||
self.assertAllEqual(model2(x), expected)
|
||||
|
||||
def test_from_row_lengths(self):
|
||||
inp = layers.Input(shape=[None])
|
||||
out = ragged_tensor.RaggedTensor.from_row_lengths(
|
||||
inp, row_lengths=[4, 0, 3, 1, 0])
|
||||
model = training.Model(inp, out)
|
||||
|
||||
x = constant_op.constant([3, 1, 4, 1, 5, 9, 2, 6])
|
||||
expected = ragged_tensor.RaggedTensor.from_row_lengths(
|
||||
x, row_lengths=[4, 0, 3, 1, 0])
|
||||
self.assertAllEqual(model(x), expected)
|
||||
|
||||
# Test that the model can serialize and deserialize as well
|
||||
model_config = model.get_config()
|
||||
model2 = training.Model.from_config(model_config)
|
||||
self.assertAllEqual(model2(x), expected)
|
||||
|
||||
def test_from_row_starts(self):
|
||||
inp = layers.Input(shape=[None])
|
||||
out = ragged_tensor.RaggedTensor.from_row_starts(
|
||||
inp, row_starts=[0, 4, 4, 7, 8])
|
||||
model = training.Model(inp, out)
|
||||
|
||||
x = constant_op.constant([3, 1, 4, 1, 5, 9, 2, 6])
|
||||
expected = ragged_tensor.RaggedTensor.from_row_starts(
|
||||
x, row_starts=[0, 4, 4, 7, 8])
|
||||
self.assertAllEqual(model(x), expected)
|
||||
|
||||
# Test that the model can serialize and deserialize as well
|
||||
model_config = model.get_config()
|
||||
model2 = training.Model.from_config(model_config)
|
||||
self.assertAllEqual(model2(x), expected)
|
||||
|
||||
def test_from_row_limits(self):
|
||||
row_limits = constant_op.constant([2, 2, 5, 6, 7], dtypes.int64)
|
||||
|
||||
inp = layers.Input(shape=[None], dtype=dtypes.string)
|
||||
out = ragged_tensor.RaggedTensor.from_row_limits(
|
||||
inp, row_limits, validate=False)
|
||||
model = training.Model(inp, out)
|
||||
|
||||
x = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
|
||||
expected = ragged_tensor.RaggedTensor.from_row_limits(
|
||||
x, row_limits, validate=False)
|
||||
self.assertAllEqual(model(x), expected)
|
||||
|
||||
# Test that the model can serialize and deserialize as well
|
||||
model_config = model.get_config()
|
||||
model2 = training.Model.from_config(model_config)
|
||||
self.assertAllEqual(model2(x), expected)
|
||||
|
||||
def test_from_uniform_row_length(self):
|
||||
inp = layers.Input(shape=[None])
|
||||
out = ragged_tensor.RaggedTensor.from_uniform_row_length(inp, 2, 8)
|
||||
model = training.Model(inp, out)
|
||||
|
||||
x = constant_op.constant(
|
||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16])
|
||||
expected = ragged_tensor.RaggedTensor.from_uniform_row_length(x, 2, 8)
|
||||
self.assertAllEqual(model(x), expected)
|
||||
|
||||
# Test that the model can serialize and deserialize as well
|
||||
model_config = model.get_config()
|
||||
model2 = training.Model.from_config(model_config)
|
||||
self.assertAllEqual(model2(x), expected)
|
||||
|
||||
def test_from_nested_value_row_ids(self):
|
||||
nested_value_rowids = [
|
||||
constant_op.constant([0, 0, 1, 3, 3], dtypes.int64),
|
||||
constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
|
||||
]
|
||||
inp = layers.Input(shape=[None], dtype=dtypes.string)
|
||||
out = ragged_tensor.RaggedTensor.from_nested_value_rowids(
|
||||
inp, nested_value_rowids)
|
||||
model = training.Model(inp, out)
|
||||
|
||||
x = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
|
||||
expected = ragged_tensor.RaggedTensor.from_nested_value_rowids(
|
||||
x, nested_value_rowids)
|
||||
self.assertAllEqual(model(x), expected)
|
||||
|
||||
# Test that the model can serialize and deserialize as well
|
||||
model_config = model.get_config()
|
||||
model2 = training.Model.from_config(model_config)
|
||||
self.assertAllEqual(model2(x), expected)
|
||||
|
||||
def test_from_nested_row_splits(self):
|
||||
nested_row_splits = [
|
||||
constant_op.constant([0, 2, 3, 3, 5], dtypes.int64),
|
||||
constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
|
||||
]
|
||||
inp = layers.Input(shape=[None], dtype=dtypes.string)
|
||||
out = ragged_tensor.RaggedTensor.from_nested_row_splits(
|
||||
inp, nested_row_splits)
|
||||
model = training.Model(inp, out)
|
||||
|
||||
x = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
|
||||
expected = ragged_tensor.RaggedTensor.from_nested_row_splits(
|
||||
x, nested_row_splits)
|
||||
self.assertAllEqual(model(x), expected)
|
||||
|
||||
# Test that the model can serialize and deserialize as well
|
||||
model_config = model.get_config()
|
||||
model2 = training.Model.from_config(model_config)
|
||||
self.assertAllEqual(model2(x), expected)
|
||||
|
||||
def test_from_nested_row_lengths(self):
|
||||
nested_row_lengths = [
|
||||
constant_op.constant([2, 1, 0, 2], dtypes.int64),
|
||||
constant_op.constant([2, 0, 3, 1, 1], dtypes.int64)
|
||||
]
|
||||
inp = layers.Input(shape=[None], dtype=dtypes.string)
|
||||
out = ragged_tensor.RaggedTensor.from_nested_row_lengths(
|
||||
inp, nested_row_lengths)
|
||||
model = training.Model(inp, out)
|
||||
|
||||
x = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
|
||||
expected = ragged_tensor.RaggedTensor.from_nested_row_lengths(
|
||||
x, nested_row_lengths)
|
||||
self.assertAllEqual(model(x), expected)
|
||||
|
||||
# Test that the model can serialize and deserialize as well
|
||||
model_config = model.get_config()
|
||||
model2 = training.Model.from_config(model_config)
|
||||
self.assertAllEqual(model2(x), expected)
|
||||
|
||||
def test_from_tensor(self):
|
||||
inp = layers.Input(shape=[None], ragged=False)
|
||||
out = ragged_tensor.RaggedTensor.from_tensor(inp)
|
||||
model = training.Model(inp, out)
|
||||
|
||||
x = constant_op.constant([[3., 4.], [1., 2.], [3., 5.]])
|
||||
expected = ragged_tensor.RaggedTensor.from_tensor(x)
|
||||
self.assertAllEqual(model(x), expected)
|
||||
|
||||
# Test that the model can serialize and deserialize as well
|
||||
model_config = model.get_config()
|
||||
model2 = training.Model.from_config(model_config)
|
||||
self.assertAllEqual(model2(x), expected)
|
||||
|
||||
def test_from_sparse(self):
|
||||
inp = layers.Input(shape=[None], sparse=True, dtype=dtypes.string)
|
||||
out = ragged_tensor.RaggedTensor.from_sparse(inp)
|
||||
model = training.Model(inp, out)
|
||||
|
||||
indices = [[0, 0], [1, 0], [1, 1], [2, 0]]
|
||||
values = [b'a', b'b', b'c', b'd']
|
||||
shape = [4, 5]
|
||||
sp_value = sparse_tensor.SparseTensor(indices, values, shape)
|
||||
|
||||
expected = ragged_tensor.RaggedTensor.from_sparse(sp_value)
|
||||
self.assertAllEqual(model(sp_value), expected)
|
||||
|
||||
# Test that the model can serialize and deserialize as well
|
||||
model_config = model.get_config()
|
||||
model2 = training.Model.from_config(model_config)
|
||||
self.assertAllEqual(model2(sp_value), expected)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ops.enable_eager_execution()
|
||||
|
@ -54,6 +54,7 @@ from tensorflow.python.ops import gen_array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.platform import tf_logging
|
||||
from tensorflow.python.training.tracking import base as trackable
|
||||
from tensorflow.python.util import dispatch
|
||||
@ -1542,3 +1543,261 @@ for slicing_op in [array_ops._slice_helper, # pylint: disable=protected-access
|
||||
array_ops.boolean_mask,
|
||||
array_ops.boolean_mask_v2]:
|
||||
TFSlicingOpDispatcher(slicing_op).register(slicing_op)
|
||||
|
||||
|
||||
class InstanceProperty(Layer):
|
||||
"""Wraps an instance property access (e.g. `x.foo`) in a Keras Layer.
|
||||
|
||||
This layer takes an attribute name `attr_name` in the constructor and,
|
||||
when called on input tensor `obj` returns `obj.attr_name`.
|
||||
|
||||
KerasTensors specialized for specific extension types use it to
|
||||
represent instance property accesses on the represented object in the
|
||||
case where the property needs to be dynamically accessed as opposed to
|
||||
being statically computed from the typespec, e.g.
|
||||
|
||||
x = keras.Input(..., ragged=True)
|
||||
out = x.flat_values
|
||||
"""
|
||||
|
||||
@trackable.no_automatic_dependency_tracking
|
||||
def __init__(self, attr_name, **kwargs):
|
||||
self.attr_name = attr_name
|
||||
|
||||
if 'name' not in kwargs:
|
||||
kwargs['name'] = K.unique_object_name(
|
||||
'input.' + self.attr_name, zero_based=True, avoid_observed_names=True)
|
||||
kwargs['autocast'] = False
|
||||
|
||||
# Do not individually trace op layers in the SavedModel.
|
||||
self._must_restore_from_config = True
|
||||
|
||||
super(InstanceProperty, self).__init__(**kwargs)
|
||||
|
||||
# Preserve all argument data structures when saving/loading a config
|
||||
# (e.g., don't unnest lists that contain one element)
|
||||
self._preserve_input_structure_in_config = True
|
||||
|
||||
def call(self, obj):
|
||||
return getattr(obj, self.attr_name)
|
||||
|
||||
def get_config(self):
|
||||
config = {
|
||||
'attr_name': self.attr_name
|
||||
}
|
||||
base_config = super(InstanceProperty, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config, custom_objects=None):
|
||||
return cls(**config)
|
||||
|
||||
|
||||
class InstanceMethod(InstanceProperty):
|
||||
"""Wraps an instance method access (e.g. `x.foo(arg)` in a Keras Layer.
|
||||
|
||||
This layer takes an attribute name `attr_name` in the constructor and,
|
||||
when called on input tensor `obj` with additional arguments `args` and
|
||||
`kwargs` returns `obj.attr_name(*args, **kwargs)`.
|
||||
|
||||
KerasTensors specialized for specific extension types use it to
|
||||
represent dynamic instance method calls on the represented object, e.g.
|
||||
|
||||
x = keras.Input(..., ragged=True)
|
||||
new_values = keras.Input(...)
|
||||
out = x.with_values(new_values)
|
||||
"""
|
||||
|
||||
def call(self, obj, args, kwargs):
|
||||
method = getattr(obj, self.attr_name)
|
||||
return method(*args, **kwargs)
|
||||
|
||||
|
||||
def _delegate_property(keras_tensor_cls, property_name): # pylint: disable=invalid-name
|
||||
"""Register property on a KerasTensor class.
|
||||
|
||||
Calling this multiple times with the same arguments should be a no-op.
|
||||
|
||||
This method exposes a property on the KerasTensor class that will use an
|
||||
`InstanceProperty` layer to access the property on the represented
|
||||
intermediate values in the model.
|
||||
|
||||
Arguments:
|
||||
keras_tensor_cls: The KerasTensor subclass that should expose the property.
|
||||
property_name: The name of the property to expose and delegate to the
|
||||
represented (Composite)Tensor.
|
||||
"""
|
||||
# We use a lambda because we can't create a Keras layer at import time
|
||||
# due to dynamic layer class versioning.
|
||||
property_access = property(lambda self: InstanceProperty(property_name)(self)) # pylint: disable=unnecessary-lambda
|
||||
setattr(keras_tensor_cls, property_name, property_access)
|
||||
|
||||
|
||||
def _delegate_method(keras_tensor_cls, method_name): # pylint: disable=invalid-name
|
||||
"""Register method on a KerasTensor class.
|
||||
|
||||
Calling this function times with the same arguments should be a no-op.
|
||||
|
||||
This method exposes an instance method on the KerasTensor class that will use
|
||||
an `InstanceMethod` layer to run the desired method on the represented
|
||||
intermediate values in the model.
|
||||
|
||||
Arguments:
|
||||
keras_tensor_cls: The KerasTensor subclass that should expose the property.
|
||||
method_name: The name of the method to expose and delegate to the
|
||||
represented (Composite)Tensor.
|
||||
"""
|
||||
def delegate(self, *args, **kwargs):
|
||||
return InstanceMethod(method_name)(self, args, kwargs)
|
||||
setattr(keras_tensor_cls, method_name, delegate)
|
||||
|
||||
# We do not support the `uniform_row_length` property because it
|
||||
# returns either `None` or an int tensor, and code that relies on it tends
|
||||
# to check `is None` directly. Delegating it here would always return a
|
||||
# `KerasTensor`, regardless of what can be statically inferred. This would
|
||||
# never equal `None`, breaking code that expects it to be partially-static
|
||||
# in unpredictable ways.
|
||||
for ragged_property in [
|
||||
'values',
|
||||
'flat_values',
|
||||
'row_splits',
|
||||
'nested_row_splits'
|
||||
]:
|
||||
_delegate_property(keras_tensor.RaggedKerasTensor, ragged_property)
|
||||
|
||||
for ragged_method_name in [
|
||||
'value_rowids',
|
||||
'nested_value_rowids',
|
||||
'nrows',
|
||||
'row_starts',
|
||||
'row_limits',
|
||||
'row_lengths',
|
||||
'nested_row_lengths',
|
||||
'bounding_shape',
|
||||
'with_values',
|
||||
'with_flat_values',
|
||||
'with_row_splits_dtype',
|
||||
'merge_dims',
|
||||
'to_tensor',
|
||||
'to_sparse',
|
||||
]:
|
||||
_delegate_method(keras_tensor.RaggedKerasTensor, ragged_method_name)
|
||||
|
||||
for sparse_property in [
|
||||
'indices',
|
||||
'values',
|
||||
]:
|
||||
_delegate_property(keras_tensor.SparseKerasTensor, sparse_property)
|
||||
|
||||
for sparse_method in [
|
||||
'with_values',
|
||||
]:
|
||||
_delegate_method(keras_tensor.SparseKerasTensor, sparse_method)
|
||||
|
||||
|
||||
class ClassMethod(Layer):
|
||||
"""Wraps a TF API Class's class method in a `Layer` object.
|
||||
|
||||
It is inserted by the Functional API construction whenever users call
|
||||
a supported TF Class's class method on KerasTensors.
|
||||
|
||||
This is useful in the case where users do something like:
|
||||
x = keras.Input(...)
|
||||
y = keras.Input(...)
|
||||
out = tf.RaggedTensor.from_row_splits(x, y)
|
||||
"""
|
||||
|
||||
@trackable.no_automatic_dependency_tracking
|
||||
def __init__(self, cls_ref, method_name, **kwargs):
|
||||
self.cls_ref = cls_ref
|
||||
self.method_name = method_name
|
||||
self.cls_symbol = (
|
||||
get_canonical_name_for_symbol(self.cls_ref,
|
||||
add_prefix_to_v1_names=True) or
|
||||
get_canonical_name_for_symbol(self.cls_ref,
|
||||
api_name='keras',
|
||||
add_prefix_to_v1_names=True))
|
||||
if 'name' not in kwargs:
|
||||
kwargs['name'] = K.unique_object_name(
|
||||
'tf.' + self.cls_symbol + '.' + self.method_name, zero_based=True,
|
||||
avoid_observed_names=True)
|
||||
kwargs['autocast'] = False
|
||||
|
||||
# Do not individually trace op layers in the SavedModel.
|
||||
self._must_restore_from_config = True
|
||||
|
||||
super(ClassMethod, self).__init__(**kwargs)
|
||||
|
||||
# Preserve all argument data structures when saving/loading a config
|
||||
# (e.g., don't unnest lists that contain one element)
|
||||
self._preserve_input_structure_in_config = True
|
||||
|
||||
self._expects_training_arg = False
|
||||
self._expects_mask_arg = False
|
||||
|
||||
def call(self, args, kwargs):
|
||||
return getattr(self.cls_ref, self.method_name)(*args, **kwargs)
|
||||
|
||||
def get_config(self):
|
||||
if not self.cls_symbol:
|
||||
raise ValueError('This Keras class method conversion tried to convert '
|
||||
'a method belonging to class %s, a class '
|
||||
'that is not an exposed in the TensorFlow API. '
|
||||
'To ensure cross-version compatibility of Keras models '
|
||||
'that use op layers, only op layers produced from '
|
||||
'exported TF API symbols can be serialized.'
|
||||
% self.cls_symbol)
|
||||
config = {
|
||||
'cls_symbol': self.cls_symbol,
|
||||
'method_name': self.method_name
|
||||
}
|
||||
|
||||
base_config = super(ClassMethod, self).get_config()
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config, custom_objects=None):
|
||||
config = config.copy()
|
||||
symbol_name = config.pop('cls_symbol')
|
||||
cls_ref = get_symbol_from_name(symbol_name)
|
||||
if not cls_ref:
|
||||
raise ValueError(
|
||||
'TF symbol `tf.%s` could not be found.' % symbol_name)
|
||||
|
||||
config['cls_ref'] = cls_ref
|
||||
|
||||
return cls(**config)
|
||||
|
||||
|
||||
class TFClassMethodDispatcher(dispatch.OpDispatcher):
|
||||
"""A class method dispatcher that allows building a functional model with TF class methods."""
|
||||
|
||||
def __init__(self, cls, method_name):
|
||||
self.cls = cls
|
||||
self.method_name = method_name
|
||||
|
||||
def handle(self, args, kwargs):
|
||||
"""Handle the specified operation with the specified arguments."""
|
||||
if any(
|
||||
isinstance(x, keras_tensor.KerasTensor)
|
||||
for x in nest.flatten([args, kwargs])):
|
||||
return ClassMethod(self.cls, self.method_name)(args[1:], kwargs)
|
||||
else:
|
||||
return self.NOT_SUPPORTED
|
||||
|
||||
for ragged_class_method in [
|
||||
'from_value_rowids',
|
||||
'from_row_splits',
|
||||
'from_row_lengths',
|
||||
'from_row_starts',
|
||||
'from_row_limits',
|
||||
'from_uniform_row_length',
|
||||
'from_nested_value_rowids',
|
||||
'from_nested_row_splits',
|
||||
'from_nested_row_lengths',
|
||||
'from_tensor',
|
||||
'from_sparse',
|
||||
]:
|
||||
TFClassMethodDispatcher(
|
||||
ragged_tensor.RaggedTensor, ragged_class_method).register(
|
||||
getattr(ragged_tensor.RaggedTensor, ragged_class_method))
|
||||
|
@ -44,6 +44,7 @@ from tensorflow.python.ops.ragged import ragged_tensor_value
|
||||
from tensorflow.python.ops.ragged import ragged_util
|
||||
from tensorflow.python.ops.ragged.row_partition import RowPartition
|
||||
from tensorflow.python.types import internal as internal_types
|
||||
from tensorflow.python.util import dispatch
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
from tensorflow.tools.docs import doc_controls
|
||||
|
||||
@ -341,6 +342,7 @@ class RaggedTensor(composite_tensor.CompositeTensor,
|
||||
row_partition=row_partition)
|
||||
|
||||
@classmethod
|
||||
@dispatch.add_dispatch_support
|
||||
def from_value_rowids(cls,
|
||||
values,
|
||||
value_rowids,
|
||||
@ -399,6 +401,7 @@ class RaggedTensor(composite_tensor.CompositeTensor,
|
||||
return cls._from_row_partition(values, row_partition, validate=validate)
|
||||
|
||||
@classmethod
|
||||
@dispatch.add_dispatch_support
|
||||
def from_row_splits(cls, values, row_splits, name=None, validate=True):
|
||||
"""Creates a `RaggedTensor` with rows partitioned by `row_splits`.
|
||||
|
||||
@ -445,6 +448,7 @@ class RaggedTensor(composite_tensor.CompositeTensor,
|
||||
return cls._from_row_partition(values, row_partition, validate=validate)
|
||||
|
||||
@classmethod
|
||||
@dispatch.add_dispatch_support
|
||||
def from_row_lengths(cls, values, row_lengths, name=None, validate=True):
|
||||
"""Creates a `RaggedTensor` with rows partitioned by `row_lengths`.
|
||||
|
||||
@ -487,6 +491,7 @@ class RaggedTensor(composite_tensor.CompositeTensor,
|
||||
return cls._from_row_partition(values, row_partition, validate=validate)
|
||||
|
||||
@classmethod
|
||||
@dispatch.add_dispatch_support
|
||||
def from_row_starts(cls, values, row_starts, name=None, validate=True):
|
||||
"""Creates a `RaggedTensor` with rows partitioned by `row_starts`.
|
||||
|
||||
@ -526,6 +531,7 @@ class RaggedTensor(composite_tensor.CompositeTensor,
|
||||
return cls._from_row_partition(values, row_partition, validate=validate)
|
||||
|
||||
@classmethod
|
||||
@dispatch.add_dispatch_support
|
||||
def from_row_limits(cls, values, row_limits, name=None, validate=True):
|
||||
"""Creates a `RaggedTensor` with rows partitioned by `row_limits`.
|
||||
|
||||
@ -562,6 +568,7 @@ class RaggedTensor(composite_tensor.CompositeTensor,
|
||||
return cls._from_row_partition(values, row_partition, validate=validate)
|
||||
|
||||
@classmethod
|
||||
@dispatch.add_dispatch_support
|
||||
def from_uniform_row_length(cls,
|
||||
values,
|
||||
uniform_row_length,
|
||||
@ -636,6 +643,7 @@ class RaggedTensor(composite_tensor.CompositeTensor,
|
||||
return cls._from_row_partition(values, row_partition, validate=validate)
|
||||
|
||||
@classmethod
|
||||
@dispatch.add_dispatch_support
|
||||
def from_nested_value_rowids(cls,
|
||||
flat_values,
|
||||
nested_value_rowids,
|
||||
@ -692,6 +700,7 @@ class RaggedTensor(composite_tensor.CompositeTensor,
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
@dispatch.add_dispatch_support
|
||||
def from_nested_row_splits(cls,
|
||||
flat_values,
|
||||
nested_row_splits,
|
||||
@ -731,6 +740,7 @@ class RaggedTensor(composite_tensor.CompositeTensor,
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
@dispatch.add_dispatch_support
|
||||
def from_nested_row_lengths(cls,
|
||||
flat_values,
|
||||
nested_row_lengths,
|
||||
@ -1307,6 +1317,7 @@ class RaggedTensor(composite_tensor.CompositeTensor,
|
||||
A `RaggedTensor`. `result.rank = 1 + new_values.rank`.
|
||||
`result.ragged_rank = 1 + new_values.ragged_rank`
|
||||
"""
|
||||
new_values = _convert_to_ragged_tensor_values(new_values)
|
||||
new_values.shape.with_rank_at_least(1)
|
||||
self.values.shape[:1].assert_is_compatible_with(new_values.shape[:1])
|
||||
if (isinstance(new_values, RaggedTensor) and
|
||||
@ -1339,8 +1350,8 @@ class RaggedTensor(composite_tensor.CompositeTensor,
|
||||
if isinstance(self._values, RaggedTensor):
|
||||
return self.with_values(self.values.with_flat_values(new_values))
|
||||
else:
|
||||
_assert_is_supported_ragged_values_type(new_values)
|
||||
return self.with_values(new_values)
|
||||
new_values = _convert_to_ragged_tensor_values(new_values)
|
||||
return self.with_values(new_values)
|
||||
|
||||
def with_row_splits_dtype(self, dtype):
|
||||
"""Returns a copy of this RaggedTensor with the given `row_splits` dtype.
|
||||
@ -1479,6 +1490,7 @@ class RaggedTensor(composite_tensor.CompositeTensor,
|
||||
#=============================================================================
|
||||
|
||||
@classmethod
|
||||
@dispatch.add_dispatch_support
|
||||
def from_tensor(cls,
|
||||
tensor,
|
||||
lengths=None,
|
||||
@ -1751,6 +1763,7 @@ class RaggedTensor(composite_tensor.CompositeTensor,
|
||||
return tensor
|
||||
|
||||
@classmethod
|
||||
@dispatch.add_dispatch_support
|
||||
def from_sparse(cls, st_input, name=None, row_splits_dtype=dtypes.int64):
|
||||
"""Converts a 2D `tf.sparse.SparseTensor` to a `RaggedTensor`.
|
||||
|
||||
@ -2521,8 +2534,8 @@ def convert_to_tensor_or_ragged_tensor(value,
|
||||
return RaggedTensor.from_nested_row_splits(
|
||||
flat_values, value.nested_row_splits, validate=False)
|
||||
else:
|
||||
return ops.convert_to_tensor(
|
||||
value=value, dtype=dtype, preferred_dtype=preferred_dtype, name=name)
|
||||
return ops.convert_to_tensor_v2_with_dispatch(
|
||||
value=value, dtype=dtype, dtype_hint=preferred_dtype, name=name)
|
||||
|
||||
|
||||
def _convert_to_ragged_tensor_values(value):
|
||||
|
Loading…
Reference in New Issue
Block a user