Experimental support for _add_supported_value_type in RaggedTensor.

This allows RaggedTensor.values to have types other than just RaggedTensor and Tensor. Based on a draft implementation by edloper@.

PiperOrigin-RevId: 329733234
Change-Id: Iad61506d34a72e97a9611008cdc20258e1f2bd8e
parent d830e92434
commit 7921c4d517
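
In practice, registering a custom value type looks like the sketch below. This is illustrative only, modeled on the WrappedTensor/WrappedTensorSpec helpers this commit adds in ragged_tensor_supported_values_test.py; the MaskedValues names are hypothetical, and _add_supported_value_type remains a private, experimental API. Validation is skipped at construction because this sketch registers no op dispatcher for the wrapper (the test file shows how to add one).

    from tensorflow.python.framework import composite_tensor
    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import type_spec
    from tensorflow.python.ops.ragged import ragged_tensor


    class MaskedValues(composite_tensor.CompositeTensor):
      """Hypothetical wrapper around a tf.Tensor, used as RaggedTensor values."""

      def __init__(self, value):
        self.value = value

      # The registry requires `shape` and `dtype` properties and `__getitem__`.
      @property
      def shape(self):
        return self.value.shape

      @property
      def dtype(self):
        return self.value.dtype

      def __getitem__(self, idx):
        return MaskedValues(self.value[idx])

      @property
      def _type_spec(self):
        return MaskedValuesSpec(type_spec.type_spec_from_value(self.value))


    class MaskedValuesSpec(type_spec.TypeSpec):
      """Minimal TypeSpec, mirroring WrappedTensorSpec in the test below."""

      def __init__(self, value_spec):
        self._value_spec = value_spec

      @property
      def value_type(self):
        return MaskedValues

      def _to_components(self, value):
        return value.value

      def _from_components(self, value):
        return MaskedValues(value)

      @property
      def _component_specs(self):
        return self._value_spec

      def _serialize(self):
        return (self._value_spec,)


    # Register the wrapper, then use it directly as RaggedTensor values.
    ragged_tensor._add_supported_value_type(MaskedValues)
    values = MaskedValues(constant_op.constant([1, 2, 3, 4]))
    # validate=False: validation ops would call tf.shape() on the wrapper,
    # which needs an op dispatcher that this sketch does not register.
    rt = ragged_tensor.RaggedTensor.from_row_splits(
        values, row_splits=[0, 2, 4], validate=False)
    assert isinstance(rt.values, MaskedValues)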
tensorflow/python/ops/ragged/BUILD
@@ -481,6 +481,19 @@ py_library(
     ],
 )
 
+py_library(
+    name = "ragged_tensor_test_ops",
+    srcs = ["ragged_tensor_test_ops.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:bitwise_ops",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:parsing_ops",
+        "//tensorflow/python:string_ops",
+    ],
+)
+
 #-------------------------------------------------------------------------------
 # RaggedTensor Tests
 #-------------------------------------------------------------------------------
@@ -1056,17 +1069,19 @@ py_test(
     srcs_version = "PY2AND3",
     deps = [
         ":ragged",  # fixdeps: keep
         ":ragged_dispatch",
         ":ragged_factory_ops",
         ":ragged_tensor",
+        ":ragged_tensor_test_ops",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:bitwise_ops",
         "//tensorflow/python:clip_ops",
         "//tensorflow/python:data_flow_ops",
         "//tensorflow/python:dtypes",
         "//tensorflow/python:errors",
         "//tensorflow/python:framework_ops",
         "//tensorflow/python:framework_test_lib",
         "//tensorflow/python:math_ops",
         "//tensorflow/python:parsing_ops",
         "//tensorflow/python:nn_ops",
         "//tensorflow/python:platform_test",
         "//tensorflow/python:sparse_tensor",
         "//tensorflow/python:string_ops",
@@ -1296,3 +1311,32 @@ py_test(
         "@absl_py//absl/testing:parameterized",
     ],
 )
+
+py_test(
+    name = "ragged_tensor_supported_values_test",
+    srcs = ["ragged_tensor_supported_values_test.py"],
+    python_version = "PY3",
+    srcs_version = "PY2AND3",
+    deps = [
+        ":ragged_factory_ops",
+        ":ragged_tensor",
+        ":ragged_tensor_test_ops",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:check_ops",
+        "//tensorflow/python:clip_ops",
+        "//tensorflow/python:composite_tensor",
+        "//tensorflow/python:constant_op",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:math_ops",
+        "//tensorflow/python:nn_ops",
+        "//tensorflow/python:platform_test",
+        "//tensorflow/python:string_ops",
+        "//tensorflow/python:tensor_shape",
+        "//tensorflow/python:tensor_spec",
+        "//tensorflow/python:type_spec",
+        "//tensorflow/python:util",
+        "@absl_py//absl/testing:parameterized",
+    ],
+)
tensorflow/python/ops/ragged/ragged_dispatch_test.py
@@ -30,115 +30,15 @@ from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import clip_ops
 from tensorflow.python.ops import data_flow_ops
 from tensorflow.python.ops import gen_bitwise_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import parsing_ops
 from tensorflow.python.ops import string_ops
 from tensorflow.python.ops.ragged import ragged_dispatch
 from tensorflow.python.ops.ragged import ragged_factory_ops
 from tensorflow.python.ops.ragged import ragged_tensor
+from tensorflow.python.ops.ragged import ragged_tensor_test_ops as test_ops
 from tensorflow.python.platform import googletest
 
-# Constants listing various op types to test. Each operation
-# should be included in at least one list below, or tested separately if
-# necessary (e.g., because it expects additional arguments).
-UNARY_FLOAT_OPS = [
-    math_ops.abs,
-    math_ops.acos,
-    math_ops.acosh,
-    math_ops.angle,
-    math_ops.asin,
-    math_ops.asinh,
-    math_ops.atan,
-    math_ops.atanh,
-    math_ops.ceil,
-    math_ops.conj,
-    math_ops.cos,
-    math_ops.cosh,
-    math_ops.digamma,
-    math_ops.erf,
-    math_ops.erfc,
-    math_ops.erfinv,
-    math_ops.exp,
-    math_ops.expm1,
-    math_ops.floor,
-    math_ops.imag,
-    math_ops.is_finite,
-    math_ops.is_inf,
-    math_ops.is_nan,
-    math_ops.lgamma,
-    math_ops.log,
-    math_ops.log1p,
-    math_ops.log_sigmoid,
-    math_ops.ndtri,
-    math_ops.negative,
-    math_ops.real,
-    math_ops.reciprocal,
-    math_ops.rint,
-    math_ops.round,
-    math_ops.rsqrt,
-    math_ops.sign,
-    math_ops.sin,
-    math_ops.sinh,
-    math_ops.sqrt,
-    math_ops.square,
-    math_ops.tan,
-    array_ops.identity,
-    array_ops.ones_like,
-    array_ops.zeros_like,
-]
-UNARY_BOOL_OPS = [
-    math_ops.logical_not,
-]
-UNARY_STRING_OPS = [
-    string_ops.decode_base64,
-    string_ops.encode_base64,
-    string_ops.string_strip,
-    parsing_ops.decode_compressed,
-]
-BINARY_FLOAT_OPS = [
-    math_ops.add,
-    math_ops.atan2,
-    math_ops.complex,
-    math_ops.div_no_nan,
-    math_ops.divide,
-    math_ops.equal,
-    math_ops.floordiv,
-    math_ops.floormod,
-    math_ops.greater,
-    math_ops.greater_equal,
-    math_ops.less,
-    math_ops.less_equal,
-    math_ops.maximum,
-    math_ops.minimum,
-    math_ops.multiply,
-    math_ops.not_equal,
-    math_ops.pow,
-    math_ops.realdiv,
-    math_ops.squared_difference,
-    math_ops.subtract,
-    math_ops.truediv,
-]
-BINARY_BOOL_OPS = [
-    math_ops.logical_and,
-    math_ops.logical_or,
-    math_ops.logical_xor,
-]
-UNARY_INT_OPS = [
-    gen_bitwise_ops.invert,
-    string_ops.unicode_script,
-]
-BINARY_INT_OPS = [
-    gen_bitwise_ops.bitwise_and,
-    gen_bitwise_ops.bitwise_or,
-    gen_bitwise_ops.bitwise_xor,
-    gen_bitwise_ops.left_shift,
-    gen_bitwise_ops.right_shift,
-    math_ops.truncatediv,
-    math_ops.truncatemod,
-]
 
 
 # pylint: disable=g-complex-comprehension
 @test_util.run_all_in_graph_and_eager_modes
@@ -183,17 +83,17 @@ class RaggedDispatchTest(test_util.TensorFlowTestCase, parameterized.TestCase):
       # Test each unary op.
       #=========================================================================
       [{'x': ragged_factory_ops.constant_value([[-2.0, 3.0], [-3.0]]), 'op': op}
-       for op in UNARY_FLOAT_OPS] +
+       for op in test_ops.UNARY_FLOAT_OPS] +
       [{'x': ragged_factory_ops.constant_value([[True, False], [True]]),
         'op': op}
-       for op in UNARY_BOOL_OPS] +
+       for op in test_ops.UNARY_BOOL_OPS] +
       [{'x': ragged_factory_ops.constant_value([[18, 512], [12412]], np.int32),
         'op': op}
-       for op in UNARY_INT_OPS] +
+       for op in test_ops.UNARY_INT_OPS] +
       [{'x': ragged_factory_ops.constant_value([['abcd', 'efgh'],
                                                 ['aabbccdd']]),
         'op': op}
-       for op in UNARY_STRING_OPS] +
+       for op in test_ops.UNARY_STRING_OPS] +
       [
           {'op': clip_ops.clip_by_value,
           'x': ragged_factory_ops.constant_value([[-2.0, 3.0], [-3.0]]),
@@ -337,20 +237,20 @@ class RaggedDispatchTest(test_util.TensorFlowTestCase, parameterized.TestCase):
           'use_kwargs': ('x',)},
       ] +
       #=========================================================================
-      # Test each unary op.
+      # Test each binary op.
       #=========================================================================
       [{'x': ragged_factory_ops.constant_value([[-2.0, 3.0], [-3.0]]),
         'y': ragged_factory_ops.constant_value([[5.0, 1.0], [12.0]]),
         'op': op}
-       for op in BINARY_FLOAT_OPS] +
+       for op in test_ops.BINARY_FLOAT_OPS] +
       [{'x': ragged_factory_ops.constant_value([[-2, 3], [-3]]),
         'y': ragged_factory_ops.constant_value([[5, 1], [12]]),
         'op': op}
-       for op in BINARY_INT_OPS] +
+       for op in test_ops.BINARY_INT_OPS] +
       [{'x': ragged_factory_ops.constant_value([[True, True], [False]]),
         'y': ragged_factory_ops.constant_value([[False, True], [False]]),
         'op': op}
-       for op in BINARY_BOOL_OPS]
+       for op in test_ops.BINARY_BOOL_OPS]
   )  # pyformat: disable
   def testBinaryElementwiseOp(self, x, y, op=math_ops.add, **extra_args):
     use_kwargs = extra_args.pop('use_kwargs', ())
tensorflow/python/ops/ragged/ragged_tensor.py
@@ -261,15 +261,12 @@ class RaggedTensor(composite_tensor.CompositeTensor,
       raise ValueError("RaggedTensor constructor is private; please use one "
                        "of the factory methods instead (e.g., "
                        "RaggedTensor.from_row_lengths())")
-    if not isinstance(values, (RaggedTensor, ops.Tensor)):
-      raise TypeError("values must be a Tensor or RaggedTensor, got %r" %
-                      values)
+    _assert_is_supported_ragged_values_type(values)
     if not isinstance(row_partition, RowPartition):
       raise TypeError("row_partition must be a RowPartition, got %r" %
                       row_partition)
 
     # Validate shapes.
-    values = convert_to_tensor_or_ragged_tensor(values)
     values.shape.with_rank_at_least(1)
     if isinstance(values, RaggedTensor):
       # pylint: disable=protected-access
@@ -506,7 +503,7 @@ class RaggedTensor(composite_tensor.CompositeTensor,
     if not isinstance(validate, bool):
       raise TypeError("validate must have type bool")
     with ops.name_scope(name, "RaggedFromRowStarts", [values, row_starts]):
-      values = convert_to_tensor_or_ragged_tensor(values)
+      values = _convert_to_ragged_tensor_values(values)
       row_partition = RowPartition.from_row_starts(
           row_starts=row_starts,
           nvals=_nrows(values),
@@ -611,7 +608,7 @@ class RaggedTensor(composite_tensor.CompositeTensor,
       raise TypeError("validate must have type bool")
     with ops.name_scope(name, "RaggedFromUniformRowLength",
                         [values, uniform_row_length, nrows]):
-      values = convert_to_tensor_or_ragged_tensor(values)
+      values = _convert_to_ragged_tensor_values(values)
       uniform_row_length = _convert_row_partition(
           uniform_row_length, "UniformRowLength",
           _get_optional_partition_dtype(values))
@@ -790,7 +787,7 @@ class RaggedTensor(composite_tensor.CompositeTensor,
                          (name, row_partition.dtype, values._row_partition.dtype))
         values = values.with_row_splits_dtype(row_partition.dtype)
     else:
-      values = ops.convert_to_tensor(values, name="values")
+      values = _convert_to_ragged_tensor_values(values)
 
     return (values, row_partition)
 
@@ -1325,10 +1322,11 @@ class RaggedTensor(composite_tensor.CompositeTensor,
     `result.rank = self.ragged_rank + new_values.rank`.
     `result.ragged_rank = self.ragged_rank + new_values.ragged_rank`.
     """
-    if isinstance(self._values, ops.Tensor):
-      return self.with_values(new_values)
-    else:
+    if isinstance(self._values, RaggedTensor):
       return self.with_values(self.values.with_flat_values(new_values))
+    else:
+      _assert_is_supported_ragged_values_type(new_values)
+      return self.with_values(new_values)
 
   def with_row_splits_dtype(self, dtype):
     """Returns a copy of this RaggedTensor with the given `row_splits` dtype.
@@ -2149,7 +2147,10 @@ def match_row_splits_dtypes(*tensors, **kwargs):
 class RaggedTensorSpec(type_spec.BatchableTypeSpec):
   """Type specification for a `tf.RaggedTensor`."""
 
-  __slots__ = ["_shape", "_dtype", "_ragged_rank", "_row_splits_dtype"]
+  __slots__ = [
+      "_shape", "_dtype", "_ragged_rank", "_row_splits_dtype",
+      "_flat_values_spec"
+  ]
 
   @property
   def dtype(self):
@@ -2211,7 +2212,7 @@ class RaggedTensorSpec(type_spec.BatchableTypeSpec):
 
   @property
   def row_splits_dtype(self):
-    """The `tf.dtypes.DType` of the the RaggedTensor's `row_splits`.
+    """The `tf.dtypes.DType` of the RaggedTensor's `row_splits`.
 
     Examples:
 
@@ -2225,6 +2226,16 @@ class RaggedTensorSpec(type_spec.BatchableTypeSpec):
     """
     return self._row_splits_dtype
 
+  @property
+  def flat_values_spec(self):
+    """The `TypeSpec` of the RaggedTensor's `flat_values`.
+
+    Returns:
+      - The TypeSpec of the flat_values.
+      - None when the flat_values is a Tensor.
+    """
+    return self._flat_values_spec
+
   @property
   def value_type(self):
     return RaggedTensor if self._ragged_rank > 0 else ops.Tensor
@@ -2233,7 +2244,8 @@ class RaggedTensorSpec(type_spec.BatchableTypeSpec):
                shape=None,
                dtype=dtypes.float32,
                ragged_rank=None,
-               row_splits_dtype=dtypes.int64):
+               row_splits_dtype=dtypes.int64,
+               flat_values_spec=None):
     """Constructs a type specification for a `tf.RaggedTensor`.
 
     Args:
@@ -2244,10 +2256,23 @@ class RaggedTensorSpec(type_spec.BatchableTypeSpec):
         flat_values is partitioned. Defaults to `shape.ndims - 1`.
       row_splits_dtype: `dtype` for the RaggedTensor's `row_splits` tensor. One
         of `tf.int32` or `tf.int64`.
+      flat_values_spec: TypeSpec for the flat_values of the RaggedTensor. It
+        must be provided when the flat_values is a CompositeTensor rather than
+        a Tensor. If both `dtype` and `flat_values_spec` are provided, `dtype`
+        must be the same as `flat_values_spec.dtype`. (experimental)
     """
     self._shape = tensor_shape.as_shape(shape)
-    self._dtype = dtypes.as_dtype(dtype)
     self._row_splits_dtype = dtypes.as_dtype(row_splits_dtype)
+    if flat_values_spec is not None:
+      if dtype is None:
+        dtype = flat_values_spec.dtype
+      elif dtype != flat_values_spec.dtype:
+        raise ValueError("dtype must be the same as flat_values_spec.dtype")
+    elif dtype is None:
+      raise ValueError(
+          "At least one of dtype or flat_values_spec must be provided")
+    self._dtype = dtypes.as_dtype(dtype)
+    self._flat_values_spec = flat_values_spec
 
     rank = self._shape.ndims
     if ragged_rank is None:
@@ -2264,29 +2289,43 @@ class RaggedTensorSpec(type_spec.BatchableTypeSpec):
       raise ValueError("ragged_rank must be less than rank.")
 
   def is_compatible_with(self, spec_or_value):
-    if (self._ragged_rank == 0 and
-        isinstance(spec_or_value, (ops.Tensor, tensor_spec.TensorSpec))):
-      return tensor_spec.TensorSpec(
-          self._shape, self._dtype).is_compatible_with(spec_or_value)
-    else:
-      return super(RaggedTensorSpec, self).is_compatible_with(spec_or_value)
+    # RaggedTensor with ragged_rank 0 can be compatible with raw flat_values.
+    if self._ragged_rank == 0:
+      if self._flat_values_spec is None:
+        if isinstance(spec_or_value, (ops.Tensor, tensor_spec.TensorSpec)):
+          return tensor_spec.TensorSpec(
+              self._shape, self._dtype).is_compatible_with(spec_or_value)
+      elif not isinstance(spec_or_value, (RaggedTensor, RaggedTensorSpec)):
+        return self._flat_values_spec.is_compatible_with(spec_or_value)
+    return super(RaggedTensorSpec, self).is_compatible_with(spec_or_value)
 
   def _serialize(self):
-    return (self._shape, self._dtype, self._ragged_rank, self._row_splits_dtype)
+    if self._flat_values_spec is None:
+      return (self._shape, self._dtype, self._ragged_rank,
+              self._row_splits_dtype)
+    else:
+      return (self._shape, self._dtype, self._ragged_rank,
+              self._row_splits_dtype, self._flat_values_spec)
 
   @property
   def _component_specs(self):
     if self._ragged_rank == 0:
-      return [tensor_spec.TensorSpec(self._shape, self._dtype)]
+      if self._flat_values_spec is not None:
+        return [self._flat_values_spec]
+      else:
+        return [tensor_spec.TensorSpec(self._shape, self._dtype)]
 
-    flat_values_shape = tensor_shape.TensorShape([None]).concatenate(
-        self._shape[self._ragged_rank + 1:])
+    flat_values_spec = self._flat_values_spec
+    if flat_values_spec is None:
+      flat_values_shape = tensor_shape.TensorShape([None]).concatenate(
+          self._shape[self._ragged_rank + 1:])
+      flat_values_spec = tensor_spec.TensorSpec(flat_values_shape, self._dtype)
     outer_dim = tensor_shape.dimension_at_index(self._shape, 0)
     outer_splits_shape = [None if outer_dim is None else outer_dim + 1]
     inner_splits_spec = tensor_spec.TensorSpec([None], self._row_splits_dtype)
 
     specs = ([
-        tensor_spec.TensorSpec(flat_values_shape, self._dtype),
+        flat_values_spec,
         tensor_spec.TensorSpec(outer_splits_shape, self._row_splits_dtype)
     ] + [inner_splits_spec for _ in range(self._ragged_rank - 1)])
     return specs
@@ -2328,6 +2367,8 @@ class RaggedTensorSpec(type_spec.BatchableTypeSpec):
   def _to_tensor_list(self, value):
     # TODO(edloper): Update gen_ragged_conversion_ops that convert to and
     # from variant to include all of the row-partitioning tensors.
+    if self._flat_values_spec is not None:
+      raise ValueError("Customized value_type is not supported")
     ragged_rank = value.ragged_rank if isinstance(value, RaggedTensor) else 0
     if ragged_rank != self._ragged_rank:
       raise ValueError("Ragged rank of value (%d) does not match ragged "
@@ -2341,6 +2382,8 @@ class RaggedTensorSpec(type_spec.BatchableTypeSpec):
     return [value._to_variant(batched_input=False)]
 
   def _to_batched_tensor_list(self, value):
+    if self._flat_values_spec is not None:
+      raise ValueError("Customized value_type is not supported")
     ragged_rank = value.ragged_rank if isinstance(value, RaggedTensor) else 0
     if ragged_rank != self._ragged_rank:
       raise ValueError("Ragged rank of value (%d) does not match ragged "
@@ -2353,6 +2396,8 @@ class RaggedTensorSpec(type_spec.BatchableTypeSpec):
     return [value._to_variant(batched_input=True)]
 
   def _from_compatible_tensor_list(self, tensor_list):
+    if self._flat_values_spec is not None:
+      raise ValueError("Customized value_type is not supported")
     if self._ragged_rank < 0:
       raise ValueError("ragged_rank must be non-negative; got %s." %
                        self._ragged_rank)
@@ -2372,11 +2417,15 @@ class RaggedTensorSpec(type_spec.BatchableTypeSpec):
     return result
 
   def _batch(self, batch_size):
+    if self._flat_values_spec is not None:
+      raise ValueError("Customized value_type is not supported")
     return RaggedTensorSpec(
         tensor_shape.TensorShape([batch_size]).concatenate(self._shape),
         self._dtype, self._ragged_rank + 1, self._row_splits_dtype)
 
   def _unbatch(self):
+    if self._flat_values_spec is not None:
+      raise ValueError("Customized value_type is not supported")
     # Note: Negative ragged_rank is allowed here because the dataset could be
     # subsequently batched again. If ragged_rank > 1, assume row_splits_dtype is
     # consistent. Errors are handled in
@@ -2395,11 +2444,20 @@ class RaggedTensorSpec(type_spec.BatchableTypeSpec):
 
   @classmethod
   def from_value(cls, value):
-    return cls(
-        shape=value.shape,
-        dtype=value.values.dtype,
-        ragged_rank=value.ragged_rank,
-        row_splits_dtype=value.row_splits.dtype)
+    if (isinstance(value, ragged_tensor_value.RaggedTensorValue) or
+        isinstance(value.flat_values, ops.Tensor)):
+      return cls(
+          shape=value.shape,
+          dtype=value.values.dtype,
+          ragged_rank=value.ragged_rank,
+          row_splits_dtype=value.row_splits.dtype)
+    else:
+      return cls(
+          shape=value.shape,
+          dtype=value.values.dtype,
+          ragged_rank=value.ragged_rank,
+          row_splits_dtype=value.row_splits.dtype,
+          flat_values_spec=type_spec.type_spec_from_value(value.flat_values))
 
 
 type_spec.register_type_spec_from_value_converter(
@@ -2453,6 +2511,27 @@ def convert_to_tensor_or_ragged_tensor(value,
       value=value, dtype=dtype, preferred_dtype=preferred_dtype, name=name)
 
 
+def _convert_to_ragged_tensor_values(value):
+  """Converts value to a supported RaggedTensor value.
+
+  * If `value` is an object of a supported value type, then return it as-is.
+  * Otherwise, convert it to a Tensor or RaggedTensor.
+
+  Args:
+    value: An object of `Tensor`, `RaggedTensor` or registered RaggedTensor
+      value types, or an object whose type has a registered `Tensor`
+      conversion function.
+
+  Returns:
+    An object of `Tensor`, `RaggedTensor` or registered RaggedTensor
+    value types.
+  """
+  if _is_supported_ragged_values_type(value):
+    return value
+  else:
+    return convert_to_tensor_or_ragged_tensor(value, name="values")
+
+
 #===============================================================================
 # Register RaggedTensor for use with session.run.
 #===============================================================================
@@ -2771,3 +2850,67 @@ def _get_optional_partition_dtype(values):
 
 
 ops.no_gradient("RaggedTensorToVariant")
+
+
+_SUPPORTED_RAGGED_VALUE_TYPES = (ops.Tensor, RaggedTensor)
+
+
+# TODO(edloper): Consider whether we should change the registry to be on
+# TypeSpecs rather than ValueTypes.
+def _add_supported_value_type(cls):
+  """Register `cls` as a supported value type of RaggedTensor.
+
+  The cls must be a subclass of CompositeTensor, and must support:
+   - Properties:
+     - x.shape
+     - x.dtype
+   - Methods:
+     - x.__getitem__(idx) (method: returns a supported value type)
+   - Ops:
+     - tf.shape(x) -- tf.shape(x)[0] must be a tf.Tensor.
+     - tf.tile(x)
+     - assert_rank_at_least(x)
+     - tf.ones_like(x)
+     - tf.gather(params=x, indices=Tensor)
+     - tf.add(x, y)
+     - tf.boolean_mask(x, ...)
+     - @TODO(edloper): Complete this list
+
+  Note: the following RaggedTensor, RaggedTensorSpec methods & ops are not
+  currently supported unless `rt.values` is a RaggedTensor or a tf.Tensor:
+    - rt.to_tensor()
+    - rt.to_sparse_tensor()
+    - rt._to_variant()
+    - rt._from_variant()
+    - tf.ragged.cross([rt])
+    - tf.gather(params=x, indices=rt)  # rt used for indices
+    - RaggedTensorSpec methods:
+      - _batch
+      - _unbatch
+      - _to_tensor_list
+      - _to_batched_tensor_list
+      - _from_compatible_tensor_list
+
+  Args:
+    cls: The type to be added to supported value types.
+  """
+  if not issubclass(cls, composite_tensor.CompositeTensor):
+    raise ValueError("cls(%s) must be a subclass of CompositeTensor" % cls)
+  if not hasattr(cls, "shape"):
+    raise ValueError("cls must support the `shape` property")
+  if not hasattr(cls, "dtype"):
+    raise ValueError("cls must support the `dtype` property")
+  global _SUPPORTED_RAGGED_VALUE_TYPES
+  _SUPPORTED_RAGGED_VALUE_TYPES += (cls,)
+
+
+def _is_supported_ragged_values_type(value):
+  return isinstance(value, _SUPPORTED_RAGGED_VALUE_TYPES)
+
+
+def _assert_is_supported_ragged_values_type(value):
+  if not _is_supported_ragged_values_type(value):
+    ok_types = ", ".join(cls.__name__ for cls in
+                         _SUPPORTED_RAGGED_VALUE_TYPES)
+    raise TypeError("type(values) must be one of: %r, got %r" %
+                    (ok_types, value))
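
As a small illustration of how the RaggedTensorSpec changes above behave once a custom value type is in play, the sketch below mirrors the testComponentSpecs case from the new test file that follows; it reuses the WrappedTensorSpec helper defined there and adds no behavior beyond what the diff shows.

    from tensorflow.python.framework import dtypes
    from tensorflow.python.framework import tensor_spec
    from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensorSpec
    # WrappedTensorSpec is defined in ragged_tensor_supported_values_test.py below.

    spec = RaggedTensorSpec(
        ragged_rank=1,
        flat_values_spec=WrappedTensorSpec(
            tensor_spec.TensorSpec([None, 3], dtypes.float32)))

    # With flat_values_spec set, _component_specs substitutes the custom spec
    # for the default TensorSpec, followed by one row_splits spec per ragged
    # dimension:
    #   [WrappedTensorSpec(TensorSpec([None, 3], float32)),
    #    TensorSpec([None], int64)]
    # and _serialize() appends flat_values_spec as a fifth tuple element.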
tensorflow/python/ops/ragged/ragged_tensor_supported_values_test.py (new file)
@@ -0,0 +1,500 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for RaggedTensor supported value types."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized

from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_tensor_test_ops as test_ops
from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensor
from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensorSpec
from tensorflow.python.platform import googletest
from tensorflow.python.util import dispatch


class WrappedTensor(composite_tensor.CompositeTensor):
  """A class used to test extending RaggedTensor value type support.

  Simply wraps a `tf.Tensor` value.
  """

  def __init__(self, value):
    if not isinstance(value, ops.Tensor):
      raise ValueError("Expected a Tensor object, but got '%s'" % value)
    self.value = value

  @property
  def shape(self):
    return self.value.shape

  @property
  def dtype(self):
    return self.value.dtype

  def __getitem__(self, idx):
    return WrappedTensor(self.value.__getitem__(idx))

  @property
  def _type_spec(self):
    return WrappedTensorSpec(type_spec.type_spec_from_value(self.value))


class WrappedTensorSpec(type_spec.TypeSpec):

  def __init__(self, value_spec):
    self._value_spec = value_spec

  @property
  def dtype(self):
    return self._value_spec.dtype

  @property
  def value_type(self):
    return WrappedTensor

  def _to_components(self, value):
    return value.value

  def _from_components(self, value):
    return WrappedTensor(value)

  @property
  def _component_specs(self):
    return self._value_spec

  def _serialize(self):
    return (self._value_spec,)


class WrappedTensorOpDispatcher(dispatch.GlobalOpDispatcher):
  """Global op dispatcher for WrappedTensor."""

  # For these ops, just return plain Tensors (not WrappedTensors).
  OPS_THAT_RETURN_UNTRACED_RESULTS = (array_ops.shape, array_ops.shape_v2,
                                      check_ops.assert_rank_at_least)

  def call_op(self, op, *args, **kwargs):
    return op(*args, **kwargs)

  def handle(self, op, args, kwargs):
    # Dispatcher only applies if at least one arg is a WrappedTensor.
    if not (any(self.is_wrapped_tensor_arg(x) for x in args) or
            any(self.is_wrapped_tensor_arg(x) for x in kwargs.values())):
      return self.NOT_SUPPORTED

    args = [self.unwrap(v) for v in args]
    kwargs = dict([(k, self.unwrap(v)) for (k, v) in kwargs.items()])
    value = self.call_op(op, *args, **kwargs)
    if op in self.OPS_THAT_RETURN_UNTRACED_RESULTS:
      return value
    else:
      return WrappedTensor(value)

  def unwrap(self, value):
    if isinstance(value, WrappedTensor):
      return value.value
    elif isinstance(value, (list, tuple)):
      return type(value)([self.unwrap(v) for v in value])
    else:
      return value

  def is_wrapped_tensor_arg(self, value):
    if isinstance(value, WrappedTensor):
      return True
    if isinstance(value, (list, tuple)):
      if any(isinstance(x, WrappedTensor) for x in value):
        return True
    return False


WrappedTensorOpDispatcher().register()
ragged_tensor._add_supported_value_type(WrappedTensor)


# pylint: disable=g-complex-comprehension
@test_util.run_all_in_graph_and_eager_modes
class RaggedTensorSupportedValuesTest(test_util.TensorFlowTestCase,
                                      parameterized.TestCase):

  def assertAllTensorsEqual(self, list1, list2):
    self.assertLen(list1, len(list2))
    for (t1, t2) in zip(list1, list2):
      self.assertAllEqual(t1, t2)

  def testConstruction(self):
    tensor_values = constant_op.constant(
        ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h'])
    values = WrappedTensor(tensor_values)

    row_splits = constant_op.constant([0, 2, 2, 5, 6, 8], dtypes.int64)
    rt = RaggedTensor.from_row_splits(values, row_splits)
    self.assertIsInstance(rt.values, WrappedTensor)
    self.assertAllEqual(rt.values.value, tensor_values)
    self.assertAllEqual(rt.row_splits, row_splits)

    row_starts = constant_op.constant([0, 2, 2, 5, 6], dtypes.int64)
    rt = RaggedTensor.from_row_starts(values, row_starts)
    self.assertIsInstance(rt.values, WrappedTensor)
    self.assertAllEqual(rt.values.value, tensor_values)
    self.assertAllEqual(rt.row_starts(), row_starts)

    row_limits = constant_op.constant([2, 2, 5, 6, 8], dtypes.int64)
    rt = RaggedTensor.from_row_limits(values, row_limits)
    self.assertIsInstance(rt.values, WrappedTensor)
    self.assertAllEqual(rt.values.value, tensor_values)
    self.assertAllEqual(rt.row_limits(), row_limits)

    row_lengths = constant_op.constant([2, 0, 3, 1, 2], dtypes.int64)
    rt = RaggedTensor.from_row_lengths(values, row_lengths)
    self.assertIsInstance(rt.values, WrappedTensor)
    self.assertAllEqual(rt.values.value, tensor_values)
    self.assertAllEqual(rt.row_lengths(), row_lengths)

    rt = RaggedTensor.from_uniform_row_length(values, 4)
    self.assertIsInstance(rt.values, WrappedTensor)
    self.assertAllEqual(rt.values.value, tensor_values)
    self.assertAllEqual(rt.uniform_row_length, 4)

  def testWithValues(self):
    tensor_values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
    values = WrappedTensor(tensor_values)
    nested_row_splits = [[0, 2, 5], [0, 2, 2, 5, 6, 7]]
    rt = RaggedTensor.from_nested_row_splits(values, nested_row_splits)

    tensor_int = constant_op.constant([1, 2, 3, 4, 5])
    rt_int = rt.with_values(tensor_int)
    self.assertAllEqual(rt_int.values, tensor_int)

    rt_wrapped_int = rt.with_values(WrappedTensor(tensor_int))
    self.assertIsInstance(rt_wrapped_int.values, WrappedTensor)
    self.assertAllEqual(rt_wrapped_int.values.value, tensor_int)

  def testWithFlatValues(self):
    tensor_values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
    values = WrappedTensor(tensor_values)
    nested_row_splits = [[0, 2, 5], [0, 2, 2, 5, 6, 7]]
    rt = RaggedTensor.from_nested_row_splits(values, nested_row_splits)

    tensor_int = constant_op.constant([1, 2, 3, 4, 5, 6, 7])
    rt_int = rt.with_flat_values(tensor_int)
    self.assertAllEqual(rt_int.flat_values, tensor_int)

    rt_wrapped_int = rt.with_flat_values(WrappedTensor(tensor_int))
    self.assertIsInstance(rt_wrapped_int.flat_values, WrappedTensor)
    self.assertAllEqual(rt_wrapped_int.flat_values.value, tensor_int)

  @parameterized.parameters(
      #=========================================================================
      # Test each unary op.
      #=========================================================================
      [{'x': ([[-2.0, 3.0], [-3.0]]), 'op': op}
       for op in test_ops.UNARY_FLOAT_OPS] +
      [{'x': ([[True, False], [True]]),
        'op': op}
       for op in test_ops.UNARY_BOOL_OPS] +
      [{'x': [[18, 512], [12412]],
        'x_dtype': dtypes.int32,
        'op': op}
       for op in test_ops.UNARY_INT_OPS] +
      [{'x': ([['abcd', 'efgh'], ['aabbccdd']]),
        'op': op}
       for op in test_ops.UNARY_STRING_OPS] +
      [
          {'op': clip_ops.clip_by_value,
           'x': ([[-2.0, 3.0], [-3.0]]),
           'clip_value_min': 0.1, 'clip_value_max': 4.0},
          {'op': math_ops.cast,
           'x': ([[-2.0, 3.0], [-3.0]]),
           'dtype': dtypes.int32},
          {'op': math_ops.saturate_cast,
           'x': ([[-2.0, 3.0], [-3.0]]),
           'dtype': dtypes.int32},
          {'op': string_ops.string_to_hash_bucket,
           'x': (
               [['abcd', 'efgh'], ['aabbccdd']]),
           'num_buckets': 1000},
          {'op': string_ops.string_to_hash_bucket_fast,
           'x': (
               [['abcd', 'efgh'], ['aabbccdd']]),
           'num_buckets': 1000},
          {'op': string_ops.string_to_hash_bucket_strong,
           'x': (
               [['abcd', 'efgh'], ['aabbccdd']]),
           'num_buckets': 1000,
           'key': [1231, 12512]},
          {'op': string_ops.string_to_number,
           'x': ([['-2.0', '3.0'], ['-3.0']])},
          {'op': string_ops.regex_full_match,
           'x': ([['hello', '123'], ['1+1']]),
           'pattern': r'\w+'},
          {'op': string_ops.regex_replace,
           'x': ([['hello', '123'], ['1+1']]),
           'pattern': r'\d',
           'rewrite': '#'},
          {'op': string_ops.substr,
           'x': ([['hello', '123'], ['1+1']]),
           'pos': 2, 'len': 3},
          {'op': array_ops.check_numerics,
           'x': ([[-2.0, 3.0], [-3.0]]),
           'message': 'check-numerics'},
          {'op': nn_ops.dropout,
           'x': ([[-2.0, 3.0], [-3.0]]),
           'rate': 0.5,
           'seed': 1},
      ])  # pyformat: disable
  def testUnaryElementwiseOp(self,
                             x,
                             x_dtype=None,
                             op=math_ops.abs,
                             **extra_args):
    x = ragged_factory_ops.constant(x, x_dtype)
    wrapped_x = ragged_tensor.RaggedTensor.from_nested_row_splits(
        WrappedTensor(x.flat_values), x.nested_row_splits)
    test_util.random_seed.set_seed(1234)
    res = op(x, **extra_args)
    test_util.random_seed.set_seed(1234)
    wrapped_res = op(wrapped_x, **extra_args)
    self.assertIsInstance(wrapped_res.flat_values, WrappedTensor)
    self.assertAllEqual(wrapped_res.flat_values.value, res.flat_values)
    self.assertAllTensorsEqual(wrapped_res.nested_row_splits,
                               res.nested_row_splits)

  @parameterized.parameters(
      #=========================================================================
      # Test each binary op.
      #=========================================================================
      [{'x': [[-2.0, 3.0], [-3.0]],
        'y': [[5.0, 1.0], [12.0]],
        'op': op}
       for op in test_ops.BINARY_FLOAT_OPS] +
      [{'x': [[-2, 3], [-3]],
        'y': [[5, 1], [12]],
        'op': op}
       for op in test_ops.BINARY_INT_OPS] +
      [{'x': [[True, True], [False]],
        'y': [[False, True], [False]],
        'op': op}
       for op in test_ops.BINARY_BOOL_OPS]
  )  # pyformat: disable
  def testBinaryElementwiseOp(self, x, y, op=math_ops.add):
    x = ragged_factory_ops.constant(x)
    y = ragged_factory_ops.constant(y)
    wrapped_x = ragged_tensor.RaggedTensor.from_nested_row_splits(
        WrappedTensor(x.flat_values), x.nested_row_splits)
    wrapped_y = ragged_tensor.RaggedTensor.from_nested_row_splits(
        WrappedTensor(y.flat_values), y.nested_row_splits)
    res = op(x, y)
    wrapped_res = op(wrapped_x, wrapped_y)
    self.assertIsInstance(wrapped_res.flat_values, WrappedTensor)
    self.assertAllEqual(wrapped_res.flat_values.value, res.flat_values)
    self.assertAllTensorsEqual(wrapped_res.nested_row_splits,
                               res.nested_row_splits)


@test_util.run_all_in_graph_and_eager_modes
class RaggedTensorSpecSupportedValuesTest(test_util.TensorFlowTestCase,
                                          parameterized.TestCase):

  def assertAllTensorsEqual(self, list1, list2):
    self.assertLen(list1, len(list2))
    for (t1, t2) in zip(list1, list2):
      self.assertAllEqual(t1, t2)

  def testConstruction(self):
    flat_values_spec = WrappedTensorSpec(
        tensor_spec.TensorSpec(shape=(None, 5), dtype=dtypes.float32))
    spec1 = RaggedTensorSpec(
        shape=None,
        dtype=dtypes.float32,
        ragged_rank=1,
        row_splits_dtype=dtypes.int64,
        flat_values_spec=flat_values_spec)
    self.assertIsNone(spec1._shape.rank)
    self.assertEqual(spec1._dtype, dtypes.float32)
    self.assertEqual(spec1._row_splits_dtype, dtypes.int64)
    self.assertEqual(spec1._ragged_rank, 1)
    self.assertEqual(spec1._flat_values_spec, flat_values_spec)

    self.assertIsNone(spec1.shape.rank)
    self.assertEqual(spec1.dtype, dtypes.float32)
    self.assertEqual(spec1.row_splits_dtype, dtypes.int64)
    self.assertEqual(spec1.ragged_rank, 1)
    self.assertEqual(spec1.flat_values_spec, flat_values_spec)

    with self.assertRaisesRegex(
        ValueError, 'dtype must be the same as flat_values_spec.dtype'):
      spec1 = RaggedTensorSpec(
          shape=None,
          dtype=dtypes.float64,
          ragged_rank=1,
          row_splits_dtype=dtypes.int64,
          flat_values_spec=flat_values_spec)

  @parameterized.parameters([
      (RaggedTensorSpec(
          ragged_rank=1,
          flat_values_spec=tensor_spec.TensorSpec(None, dtypes.float32)),
       (tensor_shape.TensorShape(None), dtypes.float32, 1, dtypes.int64,
        tensor_spec.TensorSpec(None, dtypes.float32))),
      (RaggedTensorSpec(
          shape=(5, None, 5),
          ragged_rank=1,
          dtype=dtypes.float64,
          flat_values_spec=tensor_spec.TensorSpec((5,), dtypes.float64)),
       (tensor_shape.TensorShape((5, None, 5)), dtypes.float64, 1,
        dtypes.int64, tensor_spec.TensorSpec((5,), dtypes.float64))),
  ])
  def testSerialize(self, rt_spec, expected):
    serialization = rt_spec._serialize()
    # TensorShape has an unconventional definition of equality, so we can't use
    # assertEqual directly here. But repr() is deterministic and lossless for
    # the expected values, so we can use that instead.
    self.assertEqual(repr(serialization), repr(expected))

  @parameterized.parameters([
      (RaggedTensorSpec(
          ragged_rank=0,
          shape=[5, 3],
          flat_values_spec=WrappedTensorSpec(
              tensor_spec.TensorSpec([5, 3], dtypes.float32))),
       [WrappedTensorSpec(tensor_spec.TensorSpec([5, 3], dtypes.float32))]),
      (RaggedTensorSpec(
          ragged_rank=1,
          flat_values_spec=WrappedTensorSpec(
              tensor_spec.TensorSpec([None, 3], dtypes.float32))),
       [
           WrappedTensorSpec(tensor_spec.TensorSpec([None, 3], dtypes.float32)),
           tensor_spec.TensorSpec([None], dtypes.int64),
       ]),
      (RaggedTensorSpec(
          ragged_rank=2,
          dtype=dtypes.float64,
          flat_values_spec=WrappedTensorSpec(
              tensor_spec.TensorSpec([None, 3], dtypes.float64))),
       [
           WrappedTensorSpec(tensor_spec.TensorSpec([None, 3], dtypes.float64)),
           tensor_spec.TensorSpec([None], dtypes.int64),
           tensor_spec.TensorSpec([None], dtypes.int64),
       ]),
      (RaggedTensorSpec(
          shape=[5, None, None],
          dtype=dtypes.string,
          flat_values_spec=WrappedTensorSpec(
              tensor_spec.TensorSpec([None, 3], dtypes.string))),
       [
           WrappedTensorSpec(tensor_spec.TensorSpec([None, 3], dtypes.string)),
           tensor_spec.TensorSpec([6], dtypes.int64),
           tensor_spec.TensorSpec([None], dtypes.int64),
       ]),
  ])
  def testComponentSpecs(self, rt_spec, expected):
    self.assertEqual(rt_spec._component_specs, expected)

  @parameterized.parameters([
      {
          'rt_spec':
              RaggedTensorSpec(
                  shape=[2, None, None],
                  ragged_rank=1,
                  flat_values_spec=WrappedTensorSpec(
                      tensor_spec.TensorSpec(None, dtype=dtypes.float32))),
          'flat_values': [[1.0, 2.0], [3.0, 4.0]],
          'nested_row_splits': [[0, 1, 1, 2]],
      },
      {
          'rt_spec':
              RaggedTensorSpec(
                  shape=[2, None, None],
                  flat_values_spec=WrappedTensorSpec(
                      tensor_spec.TensorSpec(None, dtype=dtypes.float32))),
          'flat_values': [1.0, 2.0, 3.0, 4.0],
          'nested_row_splits': [[0, 2, 4], [0, 2, 3, 3, 4]],
      },
  ])
  def testToFromComponents(self, rt_spec, flat_values, nested_row_splits):
    wrapped_tensor = WrappedTensor(constant_op.constant(flat_values))
    rt = RaggedTensor.from_nested_row_splits(wrapped_tensor, nested_row_splits)
    components = rt_spec._to_components(rt)
    self.assertIsInstance(components[0], WrappedTensor)
    self.assertAllEqual(components[0].value, wrapped_tensor.value)
    self.assertAllTensorsEqual(components[1:], nested_row_splits)
    rt_reconstructed = rt_spec._from_components(components)
    self.assertIsInstance(rt_reconstructed.flat_values, WrappedTensor)
    self.assertAllEqual(rt_reconstructed.flat_values.value,
                        wrapped_tensor.value)
    self.assertAllTensorsEqual(rt_reconstructed.nested_row_splits,
                               rt.nested_row_splits)
    self.assertEqual(rt_reconstructed.dtype, rt.dtype)

  def testIsCompatibleWith(self):
    spec1 = RaggedTensorSpec([32, None, None],
                             dtypes.float32,
                             2,
                             flat_values_spec=WrappedTensorSpec(
                                 tensor_spec.TensorSpec([None, None],
                                                        dtypes.float32)))
    spec2 = RaggedTensorSpec(
        None,
        dtypes.float32,
        2,
        flat_values_spec=WrappedTensorSpec(
            tensor_spec.TensorSpec(None, dtypes.float32)))
    spec3 = RaggedTensorSpec(
        None,
        dtypes.int32,
        1,
        flat_values_spec=WrappedTensorSpec(
            tensor_spec.TensorSpec(None, dtypes.int32)))
    spec4 = RaggedTensorSpec([None],
                             dtypes.int32,
                             0,
                             flat_values_spec=WrappedTensorSpec(
                                 tensor_spec.TensorSpec(None, dtypes.int32)))
    spec5 = RaggedTensorSpec([None], dtypes.int32, 0)

    self.assertTrue(spec1.is_compatible_with(spec2))
    self.assertFalse(spec1.is_compatible_with(spec3))
    self.assertFalse(spec1.is_compatible_with(spec4))
    self.assertFalse(spec2.is_compatible_with(spec3))
    self.assertFalse(spec2.is_compatible_with(spec4))
    self.assertFalse(spec3.is_compatible_with(spec4))
    self.assertFalse(spec4.is_compatible_with(spec5))
    value = constant_op.constant([1, 2, 3])
    self.assertFalse(spec4.is_compatible_with(value))
    self.assertTrue(spec4.is_compatible_with(WrappedTensor(value)))


if __name__ == '__main__':
  googletest.main()
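
The elementwise tests above rely on two cooperating pieces: ragged dispatch decomposes the op onto flat_values, and WrappedTensorOpDispatcher handles the resulting Tensor-level call. A condensed sketch of that round trip, using only names already defined and imported in the test file:

    x = ragged_factory_ops.constant([[1.0, -2.0], [3.0]])
    wrapped_x = RaggedTensor.from_nested_row_splits(
        WrappedTensor(x.flat_values), x.nested_row_splits)

    y = math_ops.abs(wrapped_x)
    # Ragged dispatch applies abs() to the flat_values; the global dispatcher
    # unwraps the WrappedTensor, runs the op, and wraps the result again, so
    # the wrapper type survives the computation:
    assert isinstance(y.flat_values, WrappedTensor)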
tensorflow/python/ops/ragged/ragged_tensor_test.py
@@ -155,8 +155,9 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
                                 'RaggedTensor constructor is private'):
       RaggedTensor(values=values, row_partition=rp)
 
-    with self.assertRaisesRegex(TypeError,
-                                'values must be a Tensor or RaggedTensor'):
+    with self.assertRaisesRegex(
+        TypeError,
+        r"""type\(values\) must be one of: 'Tensor, RaggedTensor.*"""):
       RaggedTensor(values=range(7), row_partition=rp, internal=True)
 
     with self.assertRaisesRegex(TypeError,
tensorflow/python/ops/ragged/ragged_tensor_test_ops.py (new file, 125 lines)
@@ -0,0 +1,125 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Lists of RaggedTensor ops for use in tests."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_bitwise_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import string_ops


# Constants listing various op types to test. Each operation
# should be included in at least one list below, or tested separately if
# necessary (e.g., because it expects additional arguments).
UNARY_FLOAT_OPS = [
    math_ops.abs,
    math_ops.acos,
    math_ops.acosh,
    math_ops.angle,
    math_ops.asin,
    math_ops.asinh,
    math_ops.atan,
    math_ops.atanh,
    math_ops.ceil,
    math_ops.conj,
    math_ops.cos,
    math_ops.cosh,
    math_ops.digamma,
    math_ops.erf,
    math_ops.erfc,
    math_ops.erfinv,
    math_ops.exp,
    math_ops.expm1,
    math_ops.floor,
    math_ops.imag,
    math_ops.is_finite,
    math_ops.is_inf,
    math_ops.is_nan,
    math_ops.lgamma,
    math_ops.log,
    math_ops.log1p,
    math_ops.log_sigmoid,
    math_ops.ndtri,
    math_ops.negative,
    math_ops.real,
    math_ops.reciprocal,
    math_ops.rint,
    math_ops.round,
    math_ops.rsqrt,
    math_ops.sign,
    math_ops.sin,
    math_ops.sinh,
    math_ops.sqrt,
    math_ops.square,
    math_ops.tan,
    array_ops.identity,
    array_ops.ones_like,
    array_ops.zeros_like,
]
UNARY_BOOL_OPS = [
    math_ops.logical_not,
]
UNARY_STRING_OPS = [
    string_ops.decode_base64,
    string_ops.encode_base64,
    string_ops.string_strip,
    parsing_ops.decode_compressed,
]
BINARY_FLOAT_OPS = [
    math_ops.add,
    math_ops.atan2,
    math_ops.complex,
    math_ops.div_no_nan,
    math_ops.divide,
    math_ops.equal,
    math_ops.floordiv,
    math_ops.floormod,
    math_ops.greater,
    math_ops.greater_equal,
    math_ops.less,
    math_ops.less_equal,
    math_ops.maximum,
    math_ops.minimum,
    math_ops.multiply,
    math_ops.not_equal,
    math_ops.pow,
    math_ops.realdiv,
    math_ops.squared_difference,
    math_ops.subtract,
    math_ops.truediv,
]
BINARY_BOOL_OPS = [
    math_ops.logical_and,
    math_ops.logical_or,
    math_ops.logical_xor,
]
UNARY_INT_OPS = [
    gen_bitwise_ops.invert,
    string_ops.unicode_script,
]
BINARY_INT_OPS = [
    gen_bitwise_ops.bitwise_and,
    gen_bitwise_ops.bitwise_or,
    gen_bitwise_ops.bitwise_xor,
    gen_bitwise_ops.left_shift,
    gen_bitwise_ops.right_shift,
    math_ops.truncatediv,
    math_ops.truncatemod,
]
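
These lists are consumed by parameterized test cases such as the ragged_dispatch_test.py changes above. A minimal sketch of that consumption pattern (the test class and method names here are illustrative, not part of the commit):

    from absl.testing import parameterized
    from tensorflow.python.framework import test_util
    from tensorflow.python.ops.ragged import ragged_factory_ops
    from tensorflow.python.ops.ragged import ragged_tensor_test_ops as test_ops
    from tensorflow.python.platform import googletest


    class ExampleOpListTest(test_util.TensorFlowTestCase,
                            parameterized.TestCase):

      @parameterized.parameters(
          [{'op': op} for op in test_ops.UNARY_FLOAT_OPS])
      def testUnaryOpPreservesRowSplits(self, op):
        # Each elementwise unary op should map over flat_values and leave the
        # ragged row partition unchanged.
        x = ragged_factory_ops.constant([[1.0, 2.0], [3.0]])
        y = op(x)
        self.assertAllEqual(y.row_splits, x.row_splits)


    if __name__ == '__main__':
      googletest.main()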
tensorflow/tools/api/golden/v1/tensorflow.-ragged-tensor-spec.pbtxt
@@ -8,6 +8,10 @@ tf_class {
     name: "dtype"
     mtype: "<type \'property\'>"
   }
+  member {
+    name: "flat_values_spec"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "ragged_rank"
     mtype: "<type \'property\'>"
@@ -26,7 +30,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'shape\', \'dtype\', \'ragged_rank\', \'row_splits_dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\", \'None\', \"<dtype: \'int64\'>\"], "
+    argspec: "args=[\'self\', \'shape\', \'dtype\', \'ragged_rank\', \'row_splits_dtype\', \'flat_values_spec\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\", \'None\', \"<dtype: \'int64\'>\", \'None\'], "
   }
   member_method {
     name: "from_value"
tensorflow/tools/api/golden/v2/tensorflow.-ragged-tensor-spec.pbtxt
@@ -8,6 +8,10 @@ tf_class {
     name: "dtype"
     mtype: "<type \'property\'>"
   }
+  member {
+    name: "flat_values_spec"
+    mtype: "<type \'property\'>"
+  }
   member {
     name: "ragged_rank"
     mtype: "<type \'property\'>"
@@ -26,7 +30,7 @@ tf_class {
   }
   member_method {
     name: "__init__"
-    argspec: "args=[\'self\', \'shape\', \'dtype\', \'ragged_rank\', \'row_splits_dtype\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\", \'None\', \"<dtype: \'int64\'>\"], "
+    argspec: "args=[\'self\', \'shape\', \'dtype\', \'ragged_rank\', \'row_splits_dtype\', \'flat_values_spec\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\", \'None\', \"<dtype: \'int64\'>\", \'None\'], "
   }
   member_method {
     name: "from_value"
tensorflow/tools/pip_package/BUILD
@@ -120,6 +120,7 @@ COMMON_PIP_DEPS = [
     "//tensorflow/python/data/experimental/kernel_tests:stats_dataset_test_base",
    "//tensorflow/python/data/experimental/ops:testing",
     "//tensorflow/python/data/experimental/service:server_lib",
+    "//tensorflow/python/ops/ragged:ragged_tensor_test_ops",
     "//tensorflow/python/data/kernel_tests:test_base",
     "//tensorflow/python/debug:debug_pip",
     "//tensorflow/python/distribute:combinations",