diff --git a/tensorflow/python/keras/engine/ragged_keras_tensor_test.py b/tensorflow/python/keras/engine/ragged_keras_tensor_test.py index 92abdc82240..fc85fef29bf 100644 --- a/tensorflow/python/keras/engine/ragged_keras_tensor_test.py +++ b/tensorflow/python/keras/engine/ragged_keras_tensor_test.py @@ -20,15 +20,20 @@ from __future__ import print_function from absl.testing import parameterized +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import func_graph from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import layers from tensorflow.python.keras import testing_utils from tensorflow.python.keras.engine import training from tensorflow.python.ops.ragged import ragged_factory_ops +from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.platform import test +from tensorflow.python.util import nest class RaggedKerasTensorTest(keras_parameterized.TestCase): @@ -89,6 +94,278 @@ class RaggedKerasTensorTest(keras_parameterized.TestCase): x = ragged_factory_ops.constant([[3, 4], [1, 2], [3, 5]]) self.assertAllEqual(model(x), x / x) + @parameterized.parameters( + {'property_name': 'values'}, + {'property_name': 'flat_values'}, + {'property_name': 'row_splits'}, + {'property_name': 'nested_row_splits'}, + ) + def test_instance_property(self, property_name): + inp = layers.Input(shape=[None], ragged=True) + out = getattr(inp, property_name) + model = training.Model(inp, out) + + x = ragged_factory_ops.constant([[3, 4], [1, 2], [3, 5]]) + expected_property = getattr(x, property_name) + self.assertAllEqual(model(x), expected_property) + + # Test that it works with serialization and deserialization as well + model_config = model.get_config() + model2 = training.Model.from_config(model_config) + self.assertAllEqual(model2(x), expected_property) + + @parameterized.parameters( + {'name': 'value_rowids'}, + {'name': 'nested_value_rowids'}, + {'name': 'nrows'}, + {'name': 'row_starts'}, + {'name': 'row_limits'}, + {'name': 'row_lengths'}, + {'name': 'nested_row_lengths'}, + {'name': 'bounding_shape'}, + { + 'name': 'with_values', + 'args': [[1, 2, 3, 4, 5, 6]] + }, + { + 'name': 'with_flat_values', + 'kwargs': { + 'new_values': [1, 2, 3, 4, 5, 6] + } + }, + { + 'name': 'with_row_splits_dtype', + 'kwargs': { + 'dtype': dtypes.int32 + } + }, + { + 'name': 'merge_dims', + 'args': [0], + 'kwargs': { + 'inner_axis': 1 + } + }, + {'name': 'to_tensor'}, + {'name': 'to_sparse'}, + ) + def test_instance_method(self, name, args=None, kwargs=None): + if not args: + args = [] + if not kwargs: + kwargs = {} + + inp = layers.Input(shape=[None], ragged=True) + out = getattr(inp, name)(*args, **kwargs) + model = training.Model(inp, out) + + x = ragged_factory_ops.constant([[3, 4], [1, 2], [3, 5]]) + expected_property = getattr(x, name)(*args, **kwargs) + # We expand composites before checking equality because + # assertAllEqual otherwise wouldn't work for SparseTensor outputs + for a, b in zip(nest.flatten(model(x), expand_composites=True), + nest.flatten(expected_property, expand_composites=True)): + self.assertAllEqual(a, b) + + # Test that the model can serialize and deserialize as well + model_config = model.get_config() + model2 = training.Model.from_config(model_config) + for a, b in zip(nest.flatten(model2(x), expand_composites=True), + nest.flatten(expected_property, expand_composites=True)): + self.assertAllEqual(a, b) + + +class RaggedTensorClassMethodAsLayerTest(keras_parameterized.TestCase): + + def test_from_value_rowids(self): + inp = layers.Input(shape=[None]) + out = ragged_tensor.RaggedTensor.from_value_rowids( + inp, value_rowids=[0, 0, 0, 0, 2, 2, 2, 3], nrows=5) + model = training.Model(inp, out) + + x = constant_op.constant([3, 1, 4, 1, 5, 9, 2, 6]) + expected = ragged_tensor.RaggedTensor.from_value_rowids( + x, value_rowids=[0, 0, 0, 0, 2, 2, 2, 3], nrows=5) + self.assertAllEqual(model(x), expected) + + # Test that the model can serialize and deserialize as well + model_config = model.get_config() + model2 = training.Model.from_config(model_config) + self.assertAllEqual(model2(x), expected) + + def test_from_row_splits(self): + inp = layers.Input(shape=[None]) + out = ragged_tensor.RaggedTensor.from_row_splits( + inp, row_splits=[0, 4, 4, 7, 8, 8]) + model = training.Model(inp, out) + + x = constant_op.constant([3, 1, 4, 1, 5, 9, 2, 6]) + expected = ragged_tensor.RaggedTensor.from_row_splits( + x, row_splits=[0, 4, 4, 7, 8, 8]) + self.assertAllEqual(model(x), expected) + + # Test that the model can serialize and deserialize as well + model_config = model.get_config() + model2 = training.Model.from_config(model_config) + self.assertAllEqual(model2(x), expected) + + def test_from_row_lengths(self): + inp = layers.Input(shape=[None]) + out = ragged_tensor.RaggedTensor.from_row_lengths( + inp, row_lengths=[4, 0, 3, 1, 0]) + model = training.Model(inp, out) + + x = constant_op.constant([3, 1, 4, 1, 5, 9, 2, 6]) + expected = ragged_tensor.RaggedTensor.from_row_lengths( + x, row_lengths=[4, 0, 3, 1, 0]) + self.assertAllEqual(model(x), expected) + + # Test that the model can serialize and deserialize as well + model_config = model.get_config() + model2 = training.Model.from_config(model_config) + self.assertAllEqual(model2(x), expected) + + def test_from_row_starts(self): + inp = layers.Input(shape=[None]) + out = ragged_tensor.RaggedTensor.from_row_starts( + inp, row_starts=[0, 4, 4, 7, 8]) + model = training.Model(inp, out) + + x = constant_op.constant([3, 1, 4, 1, 5, 9, 2, 6]) + expected = ragged_tensor.RaggedTensor.from_row_starts( + x, row_starts=[0, 4, 4, 7, 8]) + self.assertAllEqual(model(x), expected) + + # Test that the model can serialize and deserialize as well + model_config = model.get_config() + model2 = training.Model.from_config(model_config) + self.assertAllEqual(model2(x), expected) + + def test_from_row_limits(self): + row_limits = constant_op.constant([2, 2, 5, 6, 7], dtypes.int64) + + inp = layers.Input(shape=[None], dtype=dtypes.string) + out = ragged_tensor.RaggedTensor.from_row_limits( + inp, row_limits, validate=False) + model = training.Model(inp, out) + + x = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g']) + expected = ragged_tensor.RaggedTensor.from_row_limits( + x, row_limits, validate=False) + self.assertAllEqual(model(x), expected) + + # Test that the model can serialize and deserialize as well + model_config = model.get_config() + model2 = training.Model.from_config(model_config) + self.assertAllEqual(model2(x), expected) + + def test_from_uniform_row_length(self): + inp = layers.Input(shape=[None]) + out = ragged_tensor.RaggedTensor.from_uniform_row_length(inp, 2, 8) + model = training.Model(inp, out) + + x = constant_op.constant( + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]) + expected = ragged_tensor.RaggedTensor.from_uniform_row_length(x, 2, 8) + self.assertAllEqual(model(x), expected) + + # Test that the model can serialize and deserialize as well + model_config = model.get_config() + model2 = training.Model.from_config(model_config) + self.assertAllEqual(model2(x), expected) + + def test_from_nested_value_row_ids(self): + nested_value_rowids = [ + constant_op.constant([0, 0, 1, 3, 3], dtypes.int64), + constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64) + ] + inp = layers.Input(shape=[None], dtype=dtypes.string) + out = ragged_tensor.RaggedTensor.from_nested_value_rowids( + inp, nested_value_rowids) + model = training.Model(inp, out) + + x = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g']) + expected = ragged_tensor.RaggedTensor.from_nested_value_rowids( + x, nested_value_rowids) + self.assertAllEqual(model(x), expected) + + # Test that the model can serialize and deserialize as well + model_config = model.get_config() + model2 = training.Model.from_config(model_config) + self.assertAllEqual(model2(x), expected) + + def test_from_nested_row_splits(self): + nested_row_splits = [ + constant_op.constant([0, 2, 3, 3, 5], dtypes.int64), + constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64) + ] + inp = layers.Input(shape=[None], dtype=dtypes.string) + out = ragged_tensor.RaggedTensor.from_nested_row_splits( + inp, nested_row_splits) + model = training.Model(inp, out) + + x = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g']) + expected = ragged_tensor.RaggedTensor.from_nested_row_splits( + x, nested_row_splits) + self.assertAllEqual(model(x), expected) + + # Test that the model can serialize and deserialize as well + model_config = model.get_config() + model2 = training.Model.from_config(model_config) + self.assertAllEqual(model2(x), expected) + + def test_from_nested_row_lengths(self): + nested_row_lengths = [ + constant_op.constant([2, 1, 0, 2], dtypes.int64), + constant_op.constant([2, 0, 3, 1, 1], dtypes.int64) + ] + inp = layers.Input(shape=[None], dtype=dtypes.string) + out = ragged_tensor.RaggedTensor.from_nested_row_lengths( + inp, nested_row_lengths) + model = training.Model(inp, out) + + x = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g']) + expected = ragged_tensor.RaggedTensor.from_nested_row_lengths( + x, nested_row_lengths) + self.assertAllEqual(model(x), expected) + + # Test that the model can serialize and deserialize as well + model_config = model.get_config() + model2 = training.Model.from_config(model_config) + self.assertAllEqual(model2(x), expected) + + def test_from_tensor(self): + inp = layers.Input(shape=[None], ragged=False) + out = ragged_tensor.RaggedTensor.from_tensor(inp) + model = training.Model(inp, out) + + x = constant_op.constant([[3., 4.], [1., 2.], [3., 5.]]) + expected = ragged_tensor.RaggedTensor.from_tensor(x) + self.assertAllEqual(model(x), expected) + + # Test that the model can serialize and deserialize as well + model_config = model.get_config() + model2 = training.Model.from_config(model_config) + self.assertAllEqual(model2(x), expected) + + def test_from_sparse(self): + inp = layers.Input(shape=[None], sparse=True, dtype=dtypes.string) + out = ragged_tensor.RaggedTensor.from_sparse(inp) + model = training.Model(inp, out) + + indices = [[0, 0], [1, 0], [1, 1], [2, 0]] + values = [b'a', b'b', b'c', b'd'] + shape = [4, 5] + sp_value = sparse_tensor.SparseTensor(indices, values, shape) + + expected = ragged_tensor.RaggedTensor.from_sparse(sp_value) + self.assertAllEqual(model(sp_value), expected) + + # Test that the model can serialize and deserialize as well + model_config = model.get_config() + model2 = training.Model.from_config(model_config) + self.assertAllEqual(model2(sp_value), expected) + if __name__ == '__main__': ops.enable_eager_execution() diff --git a/tensorflow/python/keras/layers/core.py b/tensorflow/python/keras/layers/core.py index 40d5846745a..6772aba605e 100644 --- a/tensorflow/python/keras/layers/core.py +++ b/tensorflow/python/keras/layers/core.py @@ -54,6 +54,7 @@ from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import variable_scope +from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.platform import tf_logging from tensorflow.python.training.tracking import base as trackable from tensorflow.python.util import dispatch @@ -1542,3 +1543,261 @@ for slicing_op in [array_ops._slice_helper, # pylint: disable=protected-access array_ops.boolean_mask, array_ops.boolean_mask_v2]: TFSlicingOpDispatcher(slicing_op).register(slicing_op) + + +class InstanceProperty(Layer): + """Wraps an instance property access (e.g. `x.foo`) in a Keras Layer. + + This layer takes an attribute name `attr_name` in the constructor and, + when called on input tensor `obj` returns `obj.attr_name`. + + KerasTensors specialized for specific extension types use it to + represent instance property accesses on the represented object in the + case where the property needs to be dynamically accessed as opposed to + being statically computed from the typespec, e.g. + + x = keras.Input(..., ragged=True) + out = x.flat_values + """ + + @trackable.no_automatic_dependency_tracking + def __init__(self, attr_name, **kwargs): + self.attr_name = attr_name + + if 'name' not in kwargs: + kwargs['name'] = K.unique_object_name( + 'input.' + self.attr_name, zero_based=True, avoid_observed_names=True) + kwargs['autocast'] = False + + # Do not individually trace op layers in the SavedModel. + self._must_restore_from_config = True + + super(InstanceProperty, self).__init__(**kwargs) + + # Preserve all argument data structures when saving/loading a config + # (e.g., don't unnest lists that contain one element) + self._preserve_input_structure_in_config = True + + def call(self, obj): + return getattr(obj, self.attr_name) + + def get_config(self): + config = { + 'attr_name': self.attr_name + } + base_config = super(InstanceProperty, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + @classmethod + def from_config(cls, config, custom_objects=None): + return cls(**config) + + +class InstanceMethod(InstanceProperty): + """Wraps an instance method access (e.g. `x.foo(arg)` in a Keras Layer. + + This layer takes an attribute name `attr_name` in the constructor and, + when called on input tensor `obj` with additional arguments `args` and + `kwargs` returns `obj.attr_name(*args, **kwargs)`. + + KerasTensors specialized for specific extension types use it to + represent dynamic instance method calls on the represented object, e.g. + + x = keras.Input(..., ragged=True) + new_values = keras.Input(...) + out = x.with_values(new_values) + """ + + def call(self, obj, args, kwargs): + method = getattr(obj, self.attr_name) + return method(*args, **kwargs) + + +def _delegate_property(keras_tensor_cls, property_name): # pylint: disable=invalid-name + """Register property on a KerasTensor class. + + Calling this multiple times with the same arguments should be a no-op. + + This method exposes a property on the KerasTensor class that will use an + `InstanceProperty` layer to access the property on the represented + intermediate values in the model. + + Arguments: + keras_tensor_cls: The KerasTensor subclass that should expose the property. + property_name: The name of the property to expose and delegate to the + represented (Composite)Tensor. + """ + # We use a lambda because we can't create a Keras layer at import time + # due to dynamic layer class versioning. + property_access = property(lambda self: InstanceProperty(property_name)(self)) # pylint: disable=unnecessary-lambda + setattr(keras_tensor_cls, property_name, property_access) + + +def _delegate_method(keras_tensor_cls, method_name): # pylint: disable=invalid-name + """Register method on a KerasTensor class. + + Calling this function times with the same arguments should be a no-op. + + This method exposes an instance method on the KerasTensor class that will use + an `InstanceMethod` layer to run the desired method on the represented + intermediate values in the model. + + Arguments: + keras_tensor_cls: The KerasTensor subclass that should expose the property. + method_name: The name of the method to expose and delegate to the + represented (Composite)Tensor. + """ + def delegate(self, *args, **kwargs): + return InstanceMethod(method_name)(self, args, kwargs) + setattr(keras_tensor_cls, method_name, delegate) + +# We do not support the `uniform_row_length` property because it +# returns either `None` or an int tensor, and code that relies on it tends +# to check `is None` directly. Delegating it here would always return a +# `KerasTensor`, regardless of what can be statically inferred. This would +# never equal `None`, breaking code that expects it to be partially-static +# in unpredictable ways. +for ragged_property in [ + 'values', + 'flat_values', + 'row_splits', + 'nested_row_splits' +]: + _delegate_property(keras_tensor.RaggedKerasTensor, ragged_property) + +for ragged_method_name in [ + 'value_rowids', + 'nested_value_rowids', + 'nrows', + 'row_starts', + 'row_limits', + 'row_lengths', + 'nested_row_lengths', + 'bounding_shape', + 'with_values', + 'with_flat_values', + 'with_row_splits_dtype', + 'merge_dims', + 'to_tensor', + 'to_sparse', +]: + _delegate_method(keras_tensor.RaggedKerasTensor, ragged_method_name) + +for sparse_property in [ + 'indices', + 'values', +]: + _delegate_property(keras_tensor.SparseKerasTensor, sparse_property) + +for sparse_method in [ + 'with_values', +]: + _delegate_method(keras_tensor.SparseKerasTensor, sparse_method) + + +class ClassMethod(Layer): + """Wraps a TF API Class's class method in a `Layer` object. + + It is inserted by the Functional API construction whenever users call + a supported TF Class's class method on KerasTensors. + + This is useful in the case where users do something like: + x = keras.Input(...) + y = keras.Input(...) + out = tf.RaggedTensor.from_row_splits(x, y) + """ + + @trackable.no_automatic_dependency_tracking + def __init__(self, cls_ref, method_name, **kwargs): + self.cls_ref = cls_ref + self.method_name = method_name + self.cls_symbol = ( + get_canonical_name_for_symbol(self.cls_ref, + add_prefix_to_v1_names=True) or + get_canonical_name_for_symbol(self.cls_ref, + api_name='keras', + add_prefix_to_v1_names=True)) + if 'name' not in kwargs: + kwargs['name'] = K.unique_object_name( + 'tf.' + self.cls_symbol + '.' + self.method_name, zero_based=True, + avoid_observed_names=True) + kwargs['autocast'] = False + + # Do not individually trace op layers in the SavedModel. + self._must_restore_from_config = True + + super(ClassMethod, self).__init__(**kwargs) + + # Preserve all argument data structures when saving/loading a config + # (e.g., don't unnest lists that contain one element) + self._preserve_input_structure_in_config = True + + self._expects_training_arg = False + self._expects_mask_arg = False + + def call(self, args, kwargs): + return getattr(self.cls_ref, self.method_name)(*args, **kwargs) + + def get_config(self): + if not self.cls_symbol: + raise ValueError('This Keras class method conversion tried to convert ' + 'a method belonging to class %s, a class ' + 'that is not an exposed in the TensorFlow API. ' + 'To ensure cross-version compatibility of Keras models ' + 'that use op layers, only op layers produced from ' + 'exported TF API symbols can be serialized.' + % self.cls_symbol) + config = { + 'cls_symbol': self.cls_symbol, + 'method_name': self.method_name + } + + base_config = super(ClassMethod, self).get_config() + return dict(list(base_config.items()) + list(config.items())) + + @classmethod + def from_config(cls, config, custom_objects=None): + config = config.copy() + symbol_name = config.pop('cls_symbol') + cls_ref = get_symbol_from_name(symbol_name) + if not cls_ref: + raise ValueError( + 'TF symbol `tf.%s` could not be found.' % symbol_name) + + config['cls_ref'] = cls_ref + + return cls(**config) + + +class TFClassMethodDispatcher(dispatch.OpDispatcher): + """A class method dispatcher that allows building a functional model with TF class methods.""" + + def __init__(self, cls, method_name): + self.cls = cls + self.method_name = method_name + + def handle(self, args, kwargs): + """Handle the specified operation with the specified arguments.""" + if any( + isinstance(x, keras_tensor.KerasTensor) + for x in nest.flatten([args, kwargs])): + return ClassMethod(self.cls, self.method_name)(args[1:], kwargs) + else: + return self.NOT_SUPPORTED + +for ragged_class_method in [ + 'from_value_rowids', + 'from_row_splits', + 'from_row_lengths', + 'from_row_starts', + 'from_row_limits', + 'from_uniform_row_length', + 'from_nested_value_rowids', + 'from_nested_row_splits', + 'from_nested_row_lengths', + 'from_tensor', + 'from_sparse', +]: + TFClassMethodDispatcher( + ragged_tensor.RaggedTensor, ragged_class_method).register( + getattr(ragged_tensor.RaggedTensor, ragged_class_method)) diff --git a/tensorflow/python/ops/ragged/ragged_tensor.py b/tensorflow/python/ops/ragged/ragged_tensor.py index 800272d0dd9..2c2d2fdd3ad 100644 --- a/tensorflow/python/ops/ragged/ragged_tensor.py +++ b/tensorflow/python/ops/ragged/ragged_tensor.py @@ -44,6 +44,7 @@ from tensorflow.python.ops.ragged import ragged_tensor_value from tensorflow.python.ops.ragged import ragged_util from tensorflow.python.ops.ragged.row_partition import RowPartition from tensorflow.python.types import internal as internal_types +from tensorflow.python.util import dispatch from tensorflow.python.util.tf_export import tf_export from tensorflow.tools.docs import doc_controls @@ -341,6 +342,7 @@ class RaggedTensor(composite_tensor.CompositeTensor, row_partition=row_partition) @classmethod + @dispatch.add_dispatch_support def from_value_rowids(cls, values, value_rowids, @@ -399,6 +401,7 @@ class RaggedTensor(composite_tensor.CompositeTensor, return cls._from_row_partition(values, row_partition, validate=validate) @classmethod + @dispatch.add_dispatch_support def from_row_splits(cls, values, row_splits, name=None, validate=True): """Creates a `RaggedTensor` with rows partitioned by `row_splits`. @@ -445,6 +448,7 @@ class RaggedTensor(composite_tensor.CompositeTensor, return cls._from_row_partition(values, row_partition, validate=validate) @classmethod + @dispatch.add_dispatch_support def from_row_lengths(cls, values, row_lengths, name=None, validate=True): """Creates a `RaggedTensor` with rows partitioned by `row_lengths`. @@ -487,6 +491,7 @@ class RaggedTensor(composite_tensor.CompositeTensor, return cls._from_row_partition(values, row_partition, validate=validate) @classmethod + @dispatch.add_dispatch_support def from_row_starts(cls, values, row_starts, name=None, validate=True): """Creates a `RaggedTensor` with rows partitioned by `row_starts`. @@ -526,6 +531,7 @@ class RaggedTensor(composite_tensor.CompositeTensor, return cls._from_row_partition(values, row_partition, validate=validate) @classmethod + @dispatch.add_dispatch_support def from_row_limits(cls, values, row_limits, name=None, validate=True): """Creates a `RaggedTensor` with rows partitioned by `row_limits`. @@ -562,6 +568,7 @@ class RaggedTensor(composite_tensor.CompositeTensor, return cls._from_row_partition(values, row_partition, validate=validate) @classmethod + @dispatch.add_dispatch_support def from_uniform_row_length(cls, values, uniform_row_length, @@ -636,6 +643,7 @@ class RaggedTensor(composite_tensor.CompositeTensor, return cls._from_row_partition(values, row_partition, validate=validate) @classmethod + @dispatch.add_dispatch_support def from_nested_value_rowids(cls, flat_values, nested_value_rowids, @@ -692,6 +700,7 @@ class RaggedTensor(composite_tensor.CompositeTensor, return result @classmethod + @dispatch.add_dispatch_support def from_nested_row_splits(cls, flat_values, nested_row_splits, @@ -731,6 +740,7 @@ class RaggedTensor(composite_tensor.CompositeTensor, return result @classmethod + @dispatch.add_dispatch_support def from_nested_row_lengths(cls, flat_values, nested_row_lengths, @@ -1307,6 +1317,7 @@ class RaggedTensor(composite_tensor.CompositeTensor, A `RaggedTensor`. `result.rank = 1 + new_values.rank`. `result.ragged_rank = 1 + new_values.ragged_rank` """ + new_values = convert_to_tensor_or_ragged_tensor(new_values) new_values.shape.with_rank_at_least(1) self.values.shape[:1].assert_is_compatible_with(new_values.shape[:1]) if (isinstance(new_values, RaggedTensor) and @@ -1339,7 +1350,7 @@ class RaggedTensor(composite_tensor.CompositeTensor, if isinstance(self._values, RaggedTensor): return self.with_values(self.values.with_flat_values(new_values)) else: - _assert_is_supported_ragged_values_type(new_values) + new_values = convert_to_tensor_or_ragged_tensor(new_values) return self.with_values(new_values) def with_row_splits_dtype(self, dtype): @@ -1479,6 +1490,7 @@ class RaggedTensor(composite_tensor.CompositeTensor, #============================================================================= @classmethod + @dispatch.add_dispatch_support def from_tensor(cls, tensor, lengths=None, @@ -1751,6 +1763,7 @@ class RaggedTensor(composite_tensor.CompositeTensor, return tensor @classmethod + @dispatch.add_dispatch_support def from_sparse(cls, st_input, name=None, row_splits_dtype=dtypes.int64): """Converts a 2D `tf.sparse.SparseTensor` to a `RaggedTensor`. @@ -2521,8 +2534,8 @@ def convert_to_tensor_or_ragged_tensor(value, return RaggedTensor.from_nested_row_splits( flat_values, value.nested_row_splits, validate=False) else: - return ops.convert_to_tensor( - value=value, dtype=dtype, preferred_dtype=preferred_dtype, name=name) + return ops.convert_to_tensor_v2_with_dispatch( + value=value, dtype=dtype, dtype_hint=preferred_dtype, name=name) def _convert_to_ragged_tensor_values(value):