Add support for RaggedTensor properties, instance methods, and class methods in the Keras functional API.

PiperOrigin-RevId: 337337772
Change-Id: I4b66330078049b13ef5c8eddfbb77d895a65a9d8
This commit is contained in:
Tomer Kaftan 2020-10-15 10:34:54 -07:00 committed by TensorFlower Gardener
parent df6a423036
commit 9adbacfd41
3 changed files with 553 additions and 4 deletions
tensorflow/python

View File

@ -20,15 +20,20 @@ from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import layers
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.engine import training
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import test
from tensorflow.python.util import nest
class RaggedKerasTensorTest(keras_parameterized.TestCase):
@ -89,6 +94,278 @@ class RaggedKerasTensorTest(keras_parameterized.TestCase):
x = ragged_factory_ops.constant([[3, 4], [1, 2], [3, 5]])
self.assertAllEqual(model(x), x / x)
@parameterized.parameters(
{'property_name': 'values'},
{'property_name': 'flat_values'},
{'property_name': 'row_splits'},
{'property_name': 'nested_row_splits'},
)
def test_instance_property(self, property_name):
inp = layers.Input(shape=[None], ragged=True)
out = getattr(inp, property_name)
model = training.Model(inp, out)
x = ragged_factory_ops.constant([[3, 4], [1, 2], [3, 5]])
expected_property = getattr(x, property_name)
self.assertAllEqual(model(x), expected_property)
# Test that it works with serialization and deserialization as well
model_config = model.get_config()
model2 = training.Model.from_config(model_config)
self.assertAllEqual(model2(x), expected_property)
@parameterized.parameters(
{'name': 'value_rowids'},
{'name': 'nested_value_rowids'},
{'name': 'nrows'},
{'name': 'row_starts'},
{'name': 'row_limits'},
{'name': 'row_lengths'},
{'name': 'nested_row_lengths'},
{'name': 'bounding_shape'},
{
'name': 'with_values',
'args': [[1, 2, 3, 4, 5, 6]]
},
{
'name': 'with_flat_values',
'kwargs': {
'new_values': [1, 2, 3, 4, 5, 6]
}
},
{
'name': 'with_row_splits_dtype',
'kwargs': {
'dtype': dtypes.int32
}
},
{
'name': 'merge_dims',
'args': [0],
'kwargs': {
'inner_axis': 1
}
},
{'name': 'to_tensor'},
{'name': 'to_sparse'},
)
def test_instance_method(self, name, args=None, kwargs=None):
if not args:
args = []
if not kwargs:
kwargs = {}
inp = layers.Input(shape=[None], ragged=True)
out = getattr(inp, name)(*args, **kwargs)
model = training.Model(inp, out)
x = ragged_factory_ops.constant([[3, 4], [1, 2], [3, 5]])
expected_property = getattr(x, name)(*args, **kwargs)
# We expand composites before checking equality because
# assertAllEqual otherwise wouldn't work for SparseTensor outputs
for a, b in zip(nest.flatten(model(x), expand_composites=True),
nest.flatten(expected_property, expand_composites=True)):
self.assertAllEqual(a, b)
# Test that the model can serialize and deserialize as well
model_config = model.get_config()
model2 = training.Model.from_config(model_config)
for a, b in zip(nest.flatten(model2(x), expand_composites=True),
nest.flatten(expected_property, expand_composites=True)):
self.assertAllEqual(a, b)
class RaggedTensorClassMethodAsLayerTest(keras_parameterized.TestCase):
def test_from_value_rowids(self):
inp = layers.Input(shape=[None])
out = ragged_tensor.RaggedTensor.from_value_rowids(
inp, value_rowids=[0, 0, 0, 0, 2, 2, 2, 3], nrows=5)
model = training.Model(inp, out)
x = constant_op.constant([3, 1, 4, 1, 5, 9, 2, 6])
expected = ragged_tensor.RaggedTensor.from_value_rowids(
x, value_rowids=[0, 0, 0, 0, 2, 2, 2, 3], nrows=5)
self.assertAllEqual(model(x), expected)
# Test that the model can serialize and deserialize as well
model_config = model.get_config()
model2 = training.Model.from_config(model_config)
self.assertAllEqual(model2(x), expected)
def test_from_row_splits(self):
inp = layers.Input(shape=[None])
out = ragged_tensor.RaggedTensor.from_row_splits(
inp, row_splits=[0, 4, 4, 7, 8, 8])
model = training.Model(inp, out)
x = constant_op.constant([3, 1, 4, 1, 5, 9, 2, 6])
expected = ragged_tensor.RaggedTensor.from_row_splits(
x, row_splits=[0, 4, 4, 7, 8, 8])
self.assertAllEqual(model(x), expected)
# Test that the model can serialize and deserialize as well
model_config = model.get_config()
model2 = training.Model.from_config(model_config)
self.assertAllEqual(model2(x), expected)
def test_from_row_lengths(self):
inp = layers.Input(shape=[None])
out = ragged_tensor.RaggedTensor.from_row_lengths(
inp, row_lengths=[4, 0, 3, 1, 0])
model = training.Model(inp, out)
x = constant_op.constant([3, 1, 4, 1, 5, 9, 2, 6])
expected = ragged_tensor.RaggedTensor.from_row_lengths(
x, row_lengths=[4, 0, 3, 1, 0])
self.assertAllEqual(model(x), expected)
# Test that the model can serialize and deserialize as well
model_config = model.get_config()
model2 = training.Model.from_config(model_config)
self.assertAllEqual(model2(x), expected)
def test_from_row_starts(self):
inp = layers.Input(shape=[None])
out = ragged_tensor.RaggedTensor.from_row_starts(
inp, row_starts=[0, 4, 4, 7, 8])
model = training.Model(inp, out)
x = constant_op.constant([3, 1, 4, 1, 5, 9, 2, 6])
expected = ragged_tensor.RaggedTensor.from_row_starts(
x, row_starts=[0, 4, 4, 7, 8])
self.assertAllEqual(model(x), expected)
# Test that the model can serialize and deserialize as well
model_config = model.get_config()
model2 = training.Model.from_config(model_config)
self.assertAllEqual(model2(x), expected)
def test_from_row_limits(self):
row_limits = constant_op.constant([2, 2, 5, 6, 7], dtypes.int64)
inp = layers.Input(shape=[None], dtype=dtypes.string)
out = ragged_tensor.RaggedTensor.from_row_limits(
inp, row_limits, validate=False)
model = training.Model(inp, out)
x = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
expected = ragged_tensor.RaggedTensor.from_row_limits(
x, row_limits, validate=False)
self.assertAllEqual(model(x), expected)
# Test that the model can serialize and deserialize as well
model_config = model.get_config()
model2 = training.Model.from_config(model_config)
self.assertAllEqual(model2(x), expected)
def test_from_uniform_row_length(self):
inp = layers.Input(shape=[None])
out = ragged_tensor.RaggedTensor.from_uniform_row_length(inp, 2, 8)
model = training.Model(inp, out)
x = constant_op.constant(
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16])
expected = ragged_tensor.RaggedTensor.from_uniform_row_length(x, 2, 8)
self.assertAllEqual(model(x), expected)
# Test that the model can serialize and deserialize as well
model_config = model.get_config()
model2 = training.Model.from_config(model_config)
self.assertAllEqual(model2(x), expected)
def test_from_nested_value_row_ids(self):
nested_value_rowids = [
constant_op.constant([0, 0, 1, 3, 3], dtypes.int64),
constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
]
inp = layers.Input(shape=[None], dtype=dtypes.string)
out = ragged_tensor.RaggedTensor.from_nested_value_rowids(
inp, nested_value_rowids)
model = training.Model(inp, out)
x = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
expected = ragged_tensor.RaggedTensor.from_nested_value_rowids(
x, nested_value_rowids)
self.assertAllEqual(model(x), expected)
# Test that the model can serialize and deserialize as well
model_config = model.get_config()
model2 = training.Model.from_config(model_config)
self.assertAllEqual(model2(x), expected)
def test_from_nested_row_splits(self):
nested_row_splits = [
constant_op.constant([0, 2, 3, 3, 5], dtypes.int64),
constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
]
inp = layers.Input(shape=[None], dtype=dtypes.string)
out = ragged_tensor.RaggedTensor.from_nested_row_splits(
inp, nested_row_splits)
model = training.Model(inp, out)
x = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
expected = ragged_tensor.RaggedTensor.from_nested_row_splits(
x, nested_row_splits)
self.assertAllEqual(model(x), expected)
# Test that the model can serialize and deserialize as well
model_config = model.get_config()
model2 = training.Model.from_config(model_config)
self.assertAllEqual(model2(x), expected)
def test_from_nested_row_lengths(self):
nested_row_lengths = [
constant_op.constant([2, 1, 0, 2], dtypes.int64),
constant_op.constant([2, 0, 3, 1, 1], dtypes.int64)
]
inp = layers.Input(shape=[None], dtype=dtypes.string)
out = ragged_tensor.RaggedTensor.from_nested_row_lengths(
inp, nested_row_lengths)
model = training.Model(inp, out)
x = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
expected = ragged_tensor.RaggedTensor.from_nested_row_lengths(
x, nested_row_lengths)
self.assertAllEqual(model(x), expected)
# Test that the model can serialize and deserialize as well
model_config = model.get_config()
model2 = training.Model.from_config(model_config)
self.assertAllEqual(model2(x), expected)
def test_from_tensor(self):
inp = layers.Input(shape=[None], ragged=False)
out = ragged_tensor.RaggedTensor.from_tensor(inp)
model = training.Model(inp, out)
x = constant_op.constant([[3., 4.], [1., 2.], [3., 5.]])
expected = ragged_tensor.RaggedTensor.from_tensor(x)
self.assertAllEqual(model(x), expected)
# Test that the model can serialize and deserialize as well
model_config = model.get_config()
model2 = training.Model.from_config(model_config)
self.assertAllEqual(model2(x), expected)
def test_from_sparse(self):
inp = layers.Input(shape=[None], sparse=True, dtype=dtypes.string)
out = ragged_tensor.RaggedTensor.from_sparse(inp)
model = training.Model(inp, out)
indices = [[0, 0], [1, 0], [1, 1], [2, 0]]
values = [b'a', b'b', b'c', b'd']
shape = [4, 5]
sp_value = sparse_tensor.SparseTensor(indices, values, shape)
expected = ragged_tensor.RaggedTensor.from_sparse(sp_value)
self.assertAllEqual(model(sp_value), expected)
# Test that the model can serialize and deserialize as well
model_config = model.get_config()
model2 = training.Model.from_config(model_config)
self.assertAllEqual(model2(sp_value), expected)
if __name__ == '__main__':
ops.enable_eager_execution()

View File

@ -54,6 +54,7 @@ from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import tf_logging
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util import dispatch
@ -1542,3 +1543,261 @@ for slicing_op in [array_ops._slice_helper, # pylint: disable=protected-access
array_ops.boolean_mask,
array_ops.boolean_mask_v2]:
TFSlicingOpDispatcher(slicing_op).register(slicing_op)
class InstanceProperty(Layer):
"""Wraps an instance property access (e.g. `x.foo`) in a Keras Layer.
This layer takes an attribute name `attr_name` in the constructor and,
when called on input tensor `obj` returns `obj.attr_name`.
KerasTensors specialized for specific extension types use it to
represent instance property accesses on the represented object in the
case where the property needs to be dynamically accessed as opposed to
being statically computed from the typespec, e.g.
x = keras.Input(..., ragged=True)
out = x.flat_values
"""
@trackable.no_automatic_dependency_tracking
def __init__(self, attr_name, **kwargs):
self.attr_name = attr_name
if 'name' not in kwargs:
kwargs['name'] = K.unique_object_name(
'input.' + self.attr_name, zero_based=True, avoid_observed_names=True)
kwargs['autocast'] = False
# Do not individually trace op layers in the SavedModel.
self._must_restore_from_config = True
super(InstanceProperty, self).__init__(**kwargs)
# Preserve all argument data structures when saving/loading a config
# (e.g., don't unnest lists that contain one element)
self._preserve_input_structure_in_config = True
def call(self, obj):
return getattr(obj, self.attr_name)
def get_config(self):
config = {
'attr_name': self.attr_name
}
base_config = super(InstanceProperty, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
@classmethod
def from_config(cls, config, custom_objects=None):
return cls(**config)
class InstanceMethod(InstanceProperty):
"""Wraps an instance method access (e.g. `x.foo(arg)` in a Keras Layer.
This layer takes an attribute name `attr_name` in the constructor and,
when called on input tensor `obj` with additional arguments `args` and
`kwargs` returns `obj.attr_name(*args, **kwargs)`.
KerasTensors specialized for specific extension types use it to
represent dynamic instance method calls on the represented object, e.g.
x = keras.Input(..., ragged=True)
new_values = keras.Input(...)
out = x.with_values(new_values)
"""
def call(self, obj, args, kwargs):
method = getattr(obj, self.attr_name)
return method(*args, **kwargs)
def _delegate_property(keras_tensor_cls, property_name): # pylint: disable=invalid-name
"""Register property on a KerasTensor class.
Calling this multiple times with the same arguments should be a no-op.
This method exposes a property on the KerasTensor class that will use an
`InstanceProperty` layer to access the property on the represented
intermediate values in the model.
Arguments:
keras_tensor_cls: The KerasTensor subclass that should expose the property.
property_name: The name of the property to expose and delegate to the
represented (Composite)Tensor.
"""
# We use a lambda because we can't create a Keras layer at import time
# due to dynamic layer class versioning.
property_access = property(lambda self: InstanceProperty(property_name)(self)) # pylint: disable=unnecessary-lambda
setattr(keras_tensor_cls, property_name, property_access)
def _delegate_method(keras_tensor_cls, method_name): # pylint: disable=invalid-name
"""Register method on a KerasTensor class.
Calling this function times with the same arguments should be a no-op.
This method exposes an instance method on the KerasTensor class that will use
an `InstanceMethod` layer to run the desired method on the represented
intermediate values in the model.
Arguments:
keras_tensor_cls: The KerasTensor subclass that should expose the property.
method_name: The name of the method to expose and delegate to the
represented (Composite)Tensor.
"""
def delegate(self, *args, **kwargs):
return InstanceMethod(method_name)(self, args, kwargs)
setattr(keras_tensor_cls, method_name, delegate)
# We do not support the `uniform_row_length` property because it
# returns either `None` or an int tensor, and code that relies on it tends
# to check `is None` directly. Delegating it here would always return a
# `KerasTensor`, regardless of what can be statically inferred. This would
# never equal `None`, breaking code that expects it to be partially-static
# in unpredictable ways.
for ragged_property in [
'values',
'flat_values',
'row_splits',
'nested_row_splits'
]:
_delegate_property(keras_tensor.RaggedKerasTensor, ragged_property)
for ragged_method_name in [
'value_rowids',
'nested_value_rowids',
'nrows',
'row_starts',
'row_limits',
'row_lengths',
'nested_row_lengths',
'bounding_shape',
'with_values',
'with_flat_values',
'with_row_splits_dtype',
'merge_dims',
'to_tensor',
'to_sparse',
]:
_delegate_method(keras_tensor.RaggedKerasTensor, ragged_method_name)
for sparse_property in [
'indices',
'values',
]:
_delegate_property(keras_tensor.SparseKerasTensor, sparse_property)
for sparse_method in [
'with_values',
]:
_delegate_method(keras_tensor.SparseKerasTensor, sparse_method)
class ClassMethod(Layer):
"""Wraps a TF API Class's class method in a `Layer` object.
It is inserted by the Functional API construction whenever users call
a supported TF Class's class method on KerasTensors.
This is useful in the case where users do something like:
x = keras.Input(...)
y = keras.Input(...)
out = tf.RaggedTensor.from_row_splits(x, y)
"""
@trackable.no_automatic_dependency_tracking
def __init__(self, cls_ref, method_name, **kwargs):
self.cls_ref = cls_ref
self.method_name = method_name
self.cls_symbol = (
get_canonical_name_for_symbol(self.cls_ref,
add_prefix_to_v1_names=True) or
get_canonical_name_for_symbol(self.cls_ref,
api_name='keras',
add_prefix_to_v1_names=True))
if 'name' not in kwargs:
kwargs['name'] = K.unique_object_name(
'tf.' + self.cls_symbol + '.' + self.method_name, zero_based=True,
avoid_observed_names=True)
kwargs['autocast'] = False
# Do not individually trace op layers in the SavedModel.
self._must_restore_from_config = True
super(ClassMethod, self).__init__(**kwargs)
# Preserve all argument data structures when saving/loading a config
# (e.g., don't unnest lists that contain one element)
self._preserve_input_structure_in_config = True
self._expects_training_arg = False
self._expects_mask_arg = False
def call(self, args, kwargs):
return getattr(self.cls_ref, self.method_name)(*args, **kwargs)
def get_config(self):
if not self.cls_symbol:
raise ValueError('This Keras class method conversion tried to convert '
'a method belonging to class %s, a class '
'that is not an exposed in the TensorFlow API. '
'To ensure cross-version compatibility of Keras models '
'that use op layers, only op layers produced from '
'exported TF API symbols can be serialized.'
% self.cls_symbol)
config = {
'cls_symbol': self.cls_symbol,
'method_name': self.method_name
}
base_config = super(ClassMethod, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
@classmethod
def from_config(cls, config, custom_objects=None):
config = config.copy()
symbol_name = config.pop('cls_symbol')
cls_ref = get_symbol_from_name(symbol_name)
if not cls_ref:
raise ValueError(
'TF symbol `tf.%s` could not be found.' % symbol_name)
config['cls_ref'] = cls_ref
return cls(**config)
class TFClassMethodDispatcher(dispatch.OpDispatcher):
"""A class method dispatcher that allows building a functional model with TF class methods."""
def __init__(self, cls, method_name):
self.cls = cls
self.method_name = method_name
def handle(self, args, kwargs):
"""Handle the specified operation with the specified arguments."""
if any(
isinstance(x, keras_tensor.KerasTensor)
for x in nest.flatten([args, kwargs])):
return ClassMethod(self.cls, self.method_name)(args[1:], kwargs)
else:
return self.NOT_SUPPORTED
for ragged_class_method in [
'from_value_rowids',
'from_row_splits',
'from_row_lengths',
'from_row_starts',
'from_row_limits',
'from_uniform_row_length',
'from_nested_value_rowids',
'from_nested_row_splits',
'from_nested_row_lengths',
'from_tensor',
'from_sparse',
]:
TFClassMethodDispatcher(
ragged_tensor.RaggedTensor, ragged_class_method).register(
getattr(ragged_tensor.RaggedTensor, ragged_class_method))

View File

@ -44,6 +44,7 @@ from tensorflow.python.ops.ragged import ragged_tensor_value
from tensorflow.python.ops.ragged import ragged_util
from tensorflow.python.ops.ragged.row_partition import RowPartition
from tensorflow.python.types import internal as internal_types
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
from tensorflow.tools.docs import doc_controls
@ -341,6 +342,7 @@ class RaggedTensor(composite_tensor.CompositeTensor,
row_partition=row_partition)
@classmethod
@dispatch.add_dispatch_support
def from_value_rowids(cls,
values,
value_rowids,
@ -399,6 +401,7 @@ class RaggedTensor(composite_tensor.CompositeTensor,
return cls._from_row_partition(values, row_partition, validate=validate)
@classmethod
@dispatch.add_dispatch_support
def from_row_splits(cls, values, row_splits, name=None, validate=True):
"""Creates a `RaggedTensor` with rows partitioned by `row_splits`.
@ -445,6 +448,7 @@ class RaggedTensor(composite_tensor.CompositeTensor,
return cls._from_row_partition(values, row_partition, validate=validate)
@classmethod
@dispatch.add_dispatch_support
def from_row_lengths(cls, values, row_lengths, name=None, validate=True):
"""Creates a `RaggedTensor` with rows partitioned by `row_lengths`.
@ -487,6 +491,7 @@ class RaggedTensor(composite_tensor.CompositeTensor,
return cls._from_row_partition(values, row_partition, validate=validate)
@classmethod
@dispatch.add_dispatch_support
def from_row_starts(cls, values, row_starts, name=None, validate=True):
"""Creates a `RaggedTensor` with rows partitioned by `row_starts`.
@ -526,6 +531,7 @@ class RaggedTensor(composite_tensor.CompositeTensor,
return cls._from_row_partition(values, row_partition, validate=validate)
@classmethod
@dispatch.add_dispatch_support
def from_row_limits(cls, values, row_limits, name=None, validate=True):
"""Creates a `RaggedTensor` with rows partitioned by `row_limits`.
@ -562,6 +568,7 @@ class RaggedTensor(composite_tensor.CompositeTensor,
return cls._from_row_partition(values, row_partition, validate=validate)
@classmethod
@dispatch.add_dispatch_support
def from_uniform_row_length(cls,
values,
uniform_row_length,
@ -636,6 +643,7 @@ class RaggedTensor(composite_tensor.CompositeTensor,
return cls._from_row_partition(values, row_partition, validate=validate)
@classmethod
@dispatch.add_dispatch_support
def from_nested_value_rowids(cls,
flat_values,
nested_value_rowids,
@ -692,6 +700,7 @@ class RaggedTensor(composite_tensor.CompositeTensor,
return result
@classmethod
@dispatch.add_dispatch_support
def from_nested_row_splits(cls,
flat_values,
nested_row_splits,
@ -731,6 +740,7 @@ class RaggedTensor(composite_tensor.CompositeTensor,
return result
@classmethod
@dispatch.add_dispatch_support
def from_nested_row_lengths(cls,
flat_values,
nested_row_lengths,
@ -1307,6 +1317,7 @@ class RaggedTensor(composite_tensor.CompositeTensor,
A `RaggedTensor`. `result.rank = 1 + new_values.rank`.
`result.ragged_rank = 1 + new_values.ragged_rank`
"""
new_values = _convert_to_ragged_tensor_values(new_values)
new_values.shape.with_rank_at_least(1)
self.values.shape[:1].assert_is_compatible_with(new_values.shape[:1])
if (isinstance(new_values, RaggedTensor) and
@ -1339,8 +1350,8 @@ class RaggedTensor(composite_tensor.CompositeTensor,
if isinstance(self._values, RaggedTensor):
return self.with_values(self.values.with_flat_values(new_values))
else:
_assert_is_supported_ragged_values_type(new_values)
return self.with_values(new_values)
new_values = _convert_to_ragged_tensor_values(new_values)
return self.with_values(new_values)
def with_row_splits_dtype(self, dtype):
"""Returns a copy of this RaggedTensor with the given `row_splits` dtype.
@ -1479,6 +1490,7 @@ class RaggedTensor(composite_tensor.CompositeTensor,
#=============================================================================
@classmethod
@dispatch.add_dispatch_support
def from_tensor(cls,
tensor,
lengths=None,
@ -1751,6 +1763,7 @@ class RaggedTensor(composite_tensor.CompositeTensor,
return tensor
@classmethod
@dispatch.add_dispatch_support
def from_sparse(cls, st_input, name=None, row_splits_dtype=dtypes.int64):
"""Converts a 2D `tf.sparse.SparseTensor` to a `RaggedTensor`.
@ -2521,8 +2534,8 @@ def convert_to_tensor_or_ragged_tensor(value,
return RaggedTensor.from_nested_row_splits(
flat_values, value.nested_row_splits, validate=False)
else:
return ops.convert_to_tensor(
value=value, dtype=dtype, preferred_dtype=preferred_dtype, name=name)
return ops.convert_to_tensor_v2_with_dispatch(
value=value, dtype=dtype, dtype_hint=preferred_dtype, name=name)
def _convert_to_ragged_tensor_values(value):