768 lines
32 KiB
Python
768 lines
32 KiB
Python
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Tests for RaggedTensor operator dispatch."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
from absl.testing import parameterized
|
|
import numpy as np
|
|
|
|
from tensorflow.python.eager import context
|
|
from tensorflow.python.framework import dtypes
|
|
from tensorflow.python.framework import errors
|
|
from tensorflow.python.framework import ops
|
|
from tensorflow.python.framework import sparse_tensor
|
|
from tensorflow.python.framework import test_util
|
|
from tensorflow.python.ops import array_ops
|
|
from tensorflow.python.ops import clip_ops
|
|
from tensorflow.python.ops import data_flow_ops
|
|
from tensorflow.python.ops import math_ops
|
|
from tensorflow.python.ops import nn_ops
|
|
from tensorflow.python.ops import string_ops
|
|
from tensorflow.python.ops.ragged import ragged_dispatch
|
|
from tensorflow.python.ops.ragged import ragged_factory_ops
|
|
from tensorflow.python.ops.ragged import ragged_tensor
|
|
from tensorflow.python.ops.ragged import ragged_tensor_test_ops as test_ops
|
|
from tensorflow.python.platform import googletest
|
|
|
|
|
|
# pylint: disable=g-complex-comprehension
|
|
@test_util.run_all_in_graph_and_eager_modes
|
|
class RaggedDispatchTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|
|
|
def assertSameShape(self, x, y):
  """Asserts that `x` and `y` have the same (possibly ragged) shape.

  Args:
    x: The reference value.  If `ragged_tensor.is_ragged(x)` is true, `y` is
      compared against its ragged rank, nested row splits and flat-values
      shape; otherwise only the dense shapes are compared.
    y: The value whose shape is checked against `x`.
  """
  if not ragged_tensor.is_ragged(x):
    # Dense reference: `y` must be a plain Tensor with an identical shape.
    self.assertIsInstance(y, ops.Tensor)
    self.assertAllEqual(array_ops.shape(x), array_ops.shape(y))
    return

  # Ragged reference: the partitioning at every ragged level must agree...
  self.assertTrue(ragged_tensor.is_ragged(y))
  self.assertEqual(x.ragged_rank, y.ragged_rank)
  paired_splits = zip(x.nested_row_splits, y.nested_row_splits)
  for reference_splits, candidate_splits in paired_splits:
    self.assertAllEqual(reference_splits, candidate_splits)

  # ...and so must the shape of the innermost (flat) values.
  reference_flat_shape = array_ops.shape(x.flat_values)
  candidate_flat_shape = array_ops.shape(y.flat_values)
  self.assertAllEqual(reference_flat_shape, candidate_flat_shape)
|
|
|
|
@parameterized.parameters(
    #=========================================================================
    # Test different input shapes.
    #=========================================================================
    [
        # 0-dimensional input
        {'x': 12},
        # 1-dimensional input
        {'x': [1, -2, 3]},
        # 2-dimensional input
        {'x': [[-2, 3], [-3, 4]]},
        {'x': ragged_factory_ops.constant_value(
            [[-2, 3], [-3]], ragged_rank=1)},
        # 3-dimensional inputs
        {'x': [[[-2, 3], [3, 4]], [[7, 6], [5, 4]]]},
        {'x': ragged_factory_ops.constant_value(
            [[[-2, 3], [3, 4]], [[7, 6]]],
            ragged_rank=1)},
        {'x': ragged_factory_ops.constant_value(
            [[[-2, 3, 4], []], [[7, 6]], []],
            ragged_rank=2)},
    ] +
    #=========================================================================
    # Test each unary op.
    #=========================================================================
    [{'x': ragged_factory_ops.constant_value([[-2.0, 3.0], [-3.0]]), 'op': op}
     for op in test_ops.UNARY_FLOAT_OPS] +
    [{'x': ragged_factory_ops.constant_value([[True, False], [True]]),
      'op': op}
     for op in test_ops.UNARY_BOOL_OPS] +
    [{'x': ragged_factory_ops.constant_value([[18, 512], [12412]], np.int32),
      'op': op}
     for op in test_ops.UNARY_INT_OPS] +
    [{'x': ragged_factory_ops.constant_value([['abcd', 'efgh'],
                                              ['aabbccdd']]),
      'op': op}
     for op in test_ops.UNARY_STRING_OPS] +
    #=========================================================================
    # Test unary ops that take extra arguments.
    #=========================================================================
    [
        {'op': clip_ops.clip_by_value,
         'x': ragged_factory_ops.constant_value([[-2.0, 3.0], [-3.0]]),
         'clip_value_min': 0.1, 'clip_value_max': 4.0},
        {'op': math_ops.cast,
         'x': ragged_factory_ops.constant_value([[-2.0, 3.0], [-3.0]]),
         'dtype': dtypes.int32},
        {'op': math_ops.saturate_cast,
         'x': ragged_factory_ops.constant_value([[-2.0, 3.0], [-3.0]]),
         'dtype': dtypes.int32},
        {'op': string_ops.string_to_hash_bucket,
         'x': ragged_factory_ops.constant_value(
             [['abcd', 'efgh'], ['aabbccdd']]),
         'num_buckets': 1000},
        {'op': string_ops.string_to_hash_bucket_fast,
         'x': ragged_factory_ops.constant_value(
             [['abcd', 'efgh'], ['aabbccdd']]),
         'num_buckets': 1000},
        {'op': string_ops.string_to_hash_bucket_strong,
         'x': ragged_factory_ops.constant_value(
             [['abcd', 'efgh'], ['aabbccdd']]),
         'num_buckets': 1000,
         'key': [1231, 12512]},
        {'op': string_ops.string_to_number,
         'x': ragged_factory_ops.constant_value([['-2.0', '3.0'], ['-3.0']])},
        {'op': string_ops.regex_full_match,
         'x': ragged_factory_ops.constant_value([['hello', '123'], ['1+1']]),
         'pattern': r'\w+'},
        {'op': string_ops.regex_replace,
         'x': ragged_factory_ops.constant_value([['hello', '123'], ['1+1']]),
         'pattern': r'\d',
         'rewrite': '#'},
        {'op': string_ops.substr,
         'x': ragged_factory_ops.constant_value([['hello', '123'], ['1+1']]),
         'pos': 2, 'len': 3},
        {'op': array_ops.check_numerics,
         'x': ragged_factory_ops.constant_value([[-2.0, 3.0], [-3.0]]),
         'message': 'check-numerics'},
        {'op': nn_ops.dropout,
         'x': ragged_factory_ops.constant_value([[-2.0, 3.0], [-3.0]]),
         'rate': 0.5,
         'seed': 1},
    ]
)  # pyformat: disable
def testUnaryElementwiseOp(self, x, op=math_ops.abs, **extra_args):
  """Compares `op(x)` against `op` applied to the flat values of `x`.

  Args:
    x: The op input; a ragged value, or anything `op` accepts as a dense
      input.
    op: The unary elementwise op under test.
    **extra_args: Additional keyword arguments forwarded unchanged to `op`.
  """
  if test_util.IsBuiltWithROCm():
    # TODO(rocm):
    # This fails on ROCm...see JIRA ticket 236756
    self.skipTest('Fails on ROCM')

  result = op(x, **extra_args)

  # Reference result: apply the same op directly to the dense values.
  if ragged_tensor.is_ragged(x):
    dense_input = x.flat_values
  else:
    dense_input = x
  dense_output = op(dense_input, **extra_args)
  expected_flat_values = array_ops.reshape(dense_output, [-1])

  # The op must not change the (ragged) shape of its input.
  self.assertSameShape(x, result)

  # The flattened values must match the dense reference.
  result_values = (
      result.flat_values if ragged_tensor.is_ragged(result) else result)
  result_flat_values = array_ops.reshape(result_values, [-1])
  self.assertAllEqual(expected_flat_values, result_flat_values)
|
|
|
|
@parameterized.parameters(
    [
        #=====================================================================
        # Without broadcasting -- i.e., shapes match exactly.
        #=====================================================================
        # Shapes: x:(), y:()
        {'x': 12,
         'y': 8},
        # Shapes: x:(3,), y:(3,)
        {'x': [7, 8, 9],
         'y': [1, -2, 3]},
        # Shapes: x:(2, 2), y:(2, 2)
        {'x': [[-2, 3], [-3, -4]],
         'y': [[1, 2], [3, 4]]},
        # Shapes: x:(2, None), y:(2, None)
        {'x': ragged_factory_ops.constant_value([[-2, 3], [-3]]),
         'y': ragged_factory_ops.constant_value([[5, 6], [7]])},
        # Shapes: x:(2, 2, 2), y:(2, 2, 2)
        {'x': [[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
         'y': [[[9, 3], [3, 4]], [[5, 2], [7, 6]]]},
        # Shapes: x:(2, None, None), y: (2, None, None)
        {'x': ragged_factory_ops.constant_value(
            [[[1, 2], [3], [4]], [[], [5, 7, 8]]]),
         'y': ragged_factory_ops.constant_value(
             [[[3, 8], [2], [5]], [[], [1, 9, 8]]])},
        # Shapes: x:(2, None, 2), y: (2, None, 2)
        {'x': ragged_factory_ops.constant_value(
            [[[1, 2]], [[3, 4], [5, 6], [7, 8]]],
            ragged_rank=1),
         'y': ragged_factory_ops.constant_value(
             [[[9, 3]], [[5, 2], [3, 4], [7, 6]]],
             ragged_rank=1)},

        #=====================================================================
        # With broadcasting
        #=====================================================================
        # Shapes: x:(), y:(3,)
        {'x': 12,                                 # Broadcast () -> (3,)
         'y': [1, -2, 3]},
        # Shapes: x:(1,), y:(3,)
        {'x': [12],                               # Broadcast (1,) -> (3,)
         'y': [1, -2, 3]},
        # Shapes: x:(), y:(2, 2)
        {'x': 12,                                 # Broadcast () -> (2, 2)
         'y': [[1, 2], [3, 4]]},
        # Shapes: x:(1,), y:(2, 2)
        # (x must be the one-element list `[12]`; a bare scalar would just
        # repeat the `()` case above.)
        {'x': [12],                               # Broadcast (1,) -> (2, 2)
         'y': [[1, 2], [3, 4]]},
        # Shapes: x:(2, 1), y:(2, 2)
        {'x': [[10], [20]],                       # Broadcast (2, 1) -> (2, 2)
         'y': [[1, 2], [3, 4]]},
        # Shapes: x:(), y:(2, None)
        {'x': 10,                                 # Broadcast () -> (2, None)
         'y': ragged_factory_ops.constant_value(
             [[1, 2], [3]], dtype=np.int32)},
        # TODO(edloper): Add tests for more advanced broadcasting, once we add
        # support for it.

        #=====================================================================
        # Keyword Args
        #=====================================================================
        {'x': ragged_factory_ops.constant_value(
            [[[1, 2], [3], [4]], [[], [5, 7, 8]]]),
         'y': ragged_factory_ops.constant_value(
             [[[3, 8], [2], [5]], [[], [1, 9, 8]]]),
         'use_kwargs': ('x', 'y')},
        {'x': ragged_factory_ops.constant_value(
            [[[1, 2]], [[3, 4], [5, 6], [7, 8]]],
            ragged_rank=1),
         'y': ragged_factory_ops.constant_value(
             [[[9, 3]], [[5, 2], [3, 4], [7, 6]]],
             ragged_rank=1),
         'use_kwargs': ('x', 'y')},
        # Positional `x`, keyword `y`.  This is the only mixed form Python
        # allows, and the one the `op(x, y=y)` branch below handles.
        {'x': ragged_factory_ops.constant_value(
            [[[1, 2]], [[3, 4], [5, 6], [7, 8]]],
            ragged_rank=1),
         'y': ragged_factory_ops.constant_value(
             [[[9, 3]], [[5, 2], [3, 4], [7, 6]]],
             ragged_rank=1),
         'use_kwargs': ('y',)},
    ] +
    #=========================================================================
    # Test each binary op.
    #=========================================================================
    [{'x': ragged_factory_ops.constant_value([[-2.0, 3.0], [-3.0]]),
      'y': ragged_factory_ops.constant_value([[5.0, 1.0], [12.0]]),
      'op': op}
     for op in test_ops.BINARY_FLOAT_OPS] +
    [{'x': ragged_factory_ops.constant_value([[-2, 3], [-3]]),
      'y': ragged_factory_ops.constant_value([[5, 1], [12]]),
      'op': op}
     for op in test_ops.BINARY_INT_OPS] +
    [{'x': ragged_factory_ops.constant_value([[True, True], [False]]),
      'y': ragged_factory_ops.constant_value([[False, True], [False]]),
      'op': op}
     for op in test_ops.BINARY_BOOL_OPS]
)  # pyformat: disable
def testBinaryElementwiseOp(self, x, y, op=math_ops.add, **extra_args):
  """Compares `op(x, y)` against `op` applied to the operands' flat values.

  Args:
    x: First operand; a ragged value or a dense input.
    y: Second operand; a ragged value or a dense input.  The result is
      required to have the same shape as `y`.
    op: The binary elementwise op under test.
    **extra_args: Additional keyword arguments forwarded to `op`.  The
      special key `use_kwargs` is removed first: it is a tuple naming which
      of `'x'` / `'y'` are passed by keyword.  Supported values are `()`,
      `('y',)` and a tuple containing both names.
  """
  use_kwargs = extra_args.pop('use_kwargs', ())
  if 'y' not in use_kwargs:
    # A positional `y` cannot follow a keyword `x`, so "x only" cannot be
    # honoured.  Fail loudly rather than silently testing the positional path.
    self.assertNotIn(
        'x', use_kwargs,
        "use_kwargs=('x',) is not supported: pass ('y',) or ('x', 'y').")
    result = op(x, y, **extra_args)
  elif 'x' in use_kwargs:
    result = op(x=x, y=y, **extra_args)
  else:
    result = op(x, y=y, **extra_args)

  # Run the wrapped op on the dense values, for comparison.
  dense_x = x.flat_values if ragged_tensor.is_ragged(x) else x
  dense_y = y.flat_values if ragged_tensor.is_ragged(y) else y
  expected_flat_values = array_ops.reshape(
      op(dense_x, dense_y, **extra_args), [-1])

  # Check that the result has the expected shape.  `y` is the reference
  # because every broadcasting case above broadcasts `x` up to `y`.
  self.assertSameShape(y, result)

  # Check that the result has the expected (flattened) values.
  if ragged_tensor.is_ragged(result):
    result_flat_values = array_ops.reshape(result.flat_values, [-1])
  else:
    result_flat_values = array_ops.reshape(result, [-1])
  self.assertAllEqual(expected_flat_values, result_flat_values)
|
|
|
|
@parameterized.parameters(
    [
        {'inputs': (12, 8, 3)},
        {'inputs': ([1, 2, 3], [7, 8, 9], [3, 6, 9])},
        {'inputs': ([[1, 2]], [[3, 4]], [[5, 6]])},
        {'inputs': (
            ragged_factory_ops.constant_value([[1, 3], [-3]]),
            ragged_factory_ops.constant_value([[4, 7], [88]]),
            ragged_factory_ops.constant_value([[2, 9], [12]]))},
        {'inputs': (
            ragged_factory_ops.constant_value([[[1, 3], [-3]], [[1]]]),
            ragged_factory_ops.constant_value([[[4, 7], [88]], [[2]]]),
            ragged_factory_ops.constant_value([[[2, 9], [12]], [[8]]]))},
        {'inputs': (
            ragged_factory_ops.constant_value([[[1, 3], [3, 4]], [[1, 5]]],
                                              ragged_rank=1),
            ragged_factory_ops.constant_value([[[4, 7], [1, 2]], [[2, 2]]],
                                              ragged_rank=1),
            ragged_factory_ops.constant_value([[[2, 9], [5, 2]], [[8, 0]]],
                                              ragged_rank=1))},
        {'inputs': (
            ragged_factory_ops.constant_value([[[1, 3], [-3]], [[1]]]),
            ragged_factory_ops.constant_value([[[4, 7], [88]], [[2]]]),
            ragged_factory_ops.constant_value([[[2, 9], [12]], [[8]]])),
         'use_kwargs': True},
    ] + [
        {'op': math_ops.add_n,
         'inputs': (
             ragged_factory_ops.constant_value([[1, 3], [-3]]),
             ragged_factory_ops.constant_value([[4, 7], [88]]),
             ragged_factory_ops.constant_value([[2, 9], [12]]))},
        {'op': string_ops.string_join,
         'inputs': (
             ragged_factory_ops.constant_value([['a', 'b'], ['c']]),
             ragged_factory_ops.constant_value([['foo', 'bar'], ['baz']]),
             ragged_factory_ops.constant_value([['2', '9'], ['12']]))},
    ])  # pyformat: disable
def testListValuedElementwiseOp(self, inputs, op=math_ops.add_n,
                                **extra_args):
  """Compares `op(inputs)` against `op` applied to each input's flat values.

  Args:
    inputs: Sequence of operands (ragged values or dense inputs) handed to
      `op` as a single list-valued argument.
    op: The list-valued elementwise op under test.
    **extra_args: Additional keyword arguments forwarded to `op`.  The
      special key `use_kwargs` is removed first; when truthy, `inputs` is
      passed as the keyword argument `inputs=` instead of positionally.
  """
  pass_by_keyword = extra_args.pop('use_kwargs', False)
  result = (
      op(inputs=inputs, **extra_args) if pass_by_keyword
      else op(inputs, **extra_args))

  # Reference result: apply the same op directly to the dense values.
  dense_inputs = []
  for operand in inputs:
    if ragged_tensor.is_ragged(operand):
      dense_inputs.append(operand.flat_values)
    else:
      dense_inputs.append(operand)
  dense_output = op(dense_inputs, **extra_args)
  expected_flat_values = array_ops.reshape(dense_output, [-1])

  # The result must be shaped like the first input.
  self.assertSameShape(inputs[0], result)

  # The flattened values must match the dense reference.
  result_values = (
      result.flat_values if ragged_tensor.is_ragged(result) else result)
  result_flat_values = array_ops.reshape(result_values, [-1])
  self.assertAllEqual(expected_flat_values, result_flat_values)
|
|
|
|
def testElementwiseOpUnknownRankError(self):
  """Adding a ragged tensor whose values have unknown rank raises ValueError."""
  if not context.executing_eagerly():
    known_rank = ragged_factory_ops.constant([[1, 2], [3]])
    # `shape=None` gives the placeholder (and so the values) no static shape.
    shapeless_values = array_ops.placeholder_with_default(
        [1, 2, 3], shape=None)
    unknown_rank = ragged_tensor.RaggedTensor.from_row_splits(
        shapeless_values, known_rank.row_splits)
    with self.assertRaisesRegex(ValueError,
                                r'Unable to broadcast: unknown rank'):
      math_ops.add(known_rank, unknown_rank)
|
|
|
|
@parameterized.parameters([
    {
        'x': ragged_factory_ops.constant_value([[1, 2], [3]]),
        'y': [[10]],
        'expected': [[11, 12], [13]],
    },
    {
        'x': ragged_factory_ops.constant_value(
            [[[1, 2], [3, 4]], [[5]]], ragged_rank=2),
        'y': ragged_factory_ops.constant_value(
            [[[10], [20]], [[30]]], ragged_rank=1),
        'expected': [[[11, 12], [23, 24]], [[35]]],
    },
    {
        'x': ragged_factory_ops.constant_value([[[1]]]),
        'y': ragged_factory_ops.constant_value([[1]]),
        'expected': [[[2]]],
    },
])
def testElementwiseOpBroadcast(self, x, y, expected):
  """Checks `x + y` (via the `+` operator) against `expected`.

  Args:
    x: Left operand; converted to an int32 tensor or ragged tensor.
    y: Right operand; converted to an int32 tensor or ragged tensor.
    expected: The nested-list value the sum must equal.
  """
  lhs = ragged_tensor.convert_to_tensor_or_ragged_tensor(
      x, dtype=dtypes.int32)
  rhs = ragged_tensor.convert_to_tensor_or_ragged_tensor(
      y, dtype=dtypes.int32)
  # Deliberately use the overloaded operator rather than `math_ops.add`.
  total = lhs + rhs
  self.assertAllEqual(total, expected)
|
|
|
|
def testElementwiseOpShapeMismatch(self):
  """Adding ragged tensors whose second rows differ in length must fail."""
  shorter_second_row = ragged_factory_ops.constant([[1, 2, 3], [4, 5]])
  longer_second_row = ragged_factory_ops.constant([[1, 2, 3], [4, 5, 6]])
  # Either error type is accepted by this test.
  accepted_errors = (ValueError, errors.InvalidArgumentError)
  with self.assertRaises(accepted_errors):
    self.evaluate(math_ops.add(shorter_second_row, longer_second_row))
|
|
|
|
def testBinaryOpSparseAndRagged(self):
  """Mixing a RaggedTensor with a SparseTensor in add / add_n must fail."""
  ragged_operand = ragged_factory_ops.constant([[1, 2, 3], [4, 5]])
  sparse_operand = sparse_tensor.SparseTensor(
      [[0, 0], [0, 1], [2, 0]], [1, 2, 3], [3, 2])
  accepted_errors = (TypeError, ValueError)

  # Binary form.
  with self.assertRaises(accepted_errors):
    self.evaluate(math_ops.add(ragged_operand, sparse_operand))

  # List-valued form.
  with self.assertRaises(accepted_errors):
    self.evaluate(math_ops.add_n([ragged_operand, sparse_operand]))
|
|
|
|
@parameterized.parameters([
    dict(
        op=array_ops.batch_gather,
        args=(ragged_factory_ops.constant_value([[5, 6, 7], [8, 9]]),
              ragged_factory_ops.constant_value([[2, 1, 0], [1]])),
        expected=ragged_factory_ops.constant_value([[7, 6, 5], [9]])),
    dict(
        op=array_ops.concat,
        args=([
            ragged_factory_ops.constant_value([[1, 2, 3], [4]],
                                              dtype=np.int32),
            np.array([[5, 6]], dtype=np.int32)
        ],),
        kwargs={'axis': 0},
        expected=ragged_factory_ops.constant_value([[1, 2, 3], [4], [5, 6]])),
    dict(
        op=array_ops.expand_dims,
        kwargs={
            'input': ragged_factory_ops.constant_value([[1, 2], [3]]),
            'axis': 0
        },
        expected=ragged_factory_ops.constant_value([[[1, 2], [3]]])),
    dict(
        op=array_ops.expand_dims_v2,
        kwargs={
            'input': ragged_factory_ops.constant_value([[1, 2], [3]]),
            'axis': -1
        },
        expected=ragged_factory_ops.constant_value([[[1], [2]], [[3]]],
                                                   ragged_rank=1),
    ),
    dict(
        op=array_ops.gather,
        kwargs={
            'params': ragged_factory_ops.constant_value([[1, 2], [3]]),
            'indices': [1, 0, 1]
        },
        expected=ragged_factory_ops.constant_value([[3], [1, 2], [3]])),
    dict(
        op=array_ops.gather_v2,
        kwargs={
            'params': ragged_factory_ops.constant_value([[1, 2], [3]]),
            'indices': ragged_factory_ops.constant_value([[1, 0], [1]])
        },
        expected=ragged_factory_ops.constant_value([[[3], [1, 2]], [[3]]])),
    dict(
        op=array_ops.gather_nd,
        kwargs={
            'params': ragged_factory_ops.constant_value([[7, 8], [9]]),
            'indices': [[0, 1], [1, 0], [0, 0]]
        },
        expected=ragged_factory_ops.constant_value([8, 9, 7])),
    dict(
        op=array_ops.one_hot,
        kwargs={
            'indices': ragged_factory_ops.constant_value([[1, 2, 3], [0]],
                                                         dtype=np.int32),
            'depth': 4,
            'axis': -1
        },
        expected=ragged_factory_ops.constant_value(
            [[[0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]], [[1, 0, 0, 0]]],
            ragged_rank=1)),
    dict(
        op=array_ops.stack,
        args=([
            ragged_factory_ops.constant_value([[1, 2, 3], [4]],
                                              dtype=np.int32),
            np.array([[5, 6]], dtype=np.int32)
        ],),
        expected=ragged_factory_ops.constant_value([[[1, 2, 3], [4]],
                                                    [[5, 6]]])),
    dict(
        op=array_ops.tile,
        # Here `args` is itself the list of the two positional arguments.
        args=[
            ragged_factory_ops.constant_value([[1, 2], [3]], dtype=np.int32),
            [2, 3]
        ],
        expected=ragged_factory_ops.constant_value([[1, 2, 1, 2, 1, 2],
                                                    [3, 3, 3],
                                                    [1, 2, 1, 2, 1, 2],
                                                    [3, 3, 3]])),
    dict(
        op=array_ops.where,
        args=(ragged_factory_ops.constant_value([[True, False], [True]]),
              ragged_factory_ops.constant_value([[b'A', b'B'], [b'C']]),
              ragged_factory_ops.constant_value([[b'a', b'b'], [b'c']])),
        expected=ragged_factory_ops.constant_value([[b'A', b'b'], [b'C']])),
    dict(
        op=array_ops.where,
        args=(ragged_factory_ops.constant_value([[True, False], [True]]),),
        expected=[[0, 0], [1, 0]]),
    dict(
        op=array_ops.where_v2,
        args=(ragged_factory_ops.constant_value([[True, False], [True]]),
              ragged_factory_ops.constant_value([[b'A', b'B'], [b'C']]),
              ragged_factory_ops.constant_value([[b'a', b'b'], [b'c']])),
        expected=ragged_factory_ops.constant_value([[b'A', b'b'], [b'C']])),
    dict(
        op=math_ops.unsorted_segment_sum,
        kwargs={
            'data': ragged_factory_ops.constant_value([[1, 2], [3]]),
            'segment_ids': ragged_factory_ops.constant_value([[0, 2], [0]]),
            'num_segments': 3
        },
        expected=[4, 0, 2]),
    dict(
        op=math_ops.unsorted_segment_prod,
        kwargs={
            'data': ragged_factory_ops.constant_value([[1, 2], [3]]),
            'segment_ids': ragged_factory_ops.constant_value([[0, 2], [0]]),
            'num_segments': 3
        },
        expected=[3, 1, 2]),
    dict(
        op=math_ops.unsorted_segment_min,
        kwargs={
            'data': ragged_factory_ops.constant_value([[1, 2], [3]]),
            'segment_ids': ragged_factory_ops.constant_value([[0, 1], [0]]),
            'num_segments': 2
        },
        expected=[1, 2]),
    dict(
        op=math_ops.unsorted_segment_max,
        kwargs={
            'data': ragged_factory_ops.constant_value([[1, 2], [3]]),
            'segment_ids': ragged_factory_ops.constant_value([[0, 1], [0]]),
            'num_segments': 2
        },
        expected=[3, 2]),
    dict(
        op=math_ops.unsorted_segment_mean,
        kwargs={
            'data': ragged_factory_ops.constant_value([[1, 2], [3]]),
            'segment_ids': ragged_factory_ops.constant_value([[0, 1], [0]]),
            'num_segments': 2
        },
        expected=[2, 2]),
    dict(
        op=math_ops.unsorted_segment_sqrt_n,
        kwargs={
            'data': ragged_factory_ops.constant_value([[1.0, 2.0],
                                                       [3.0, 4.0, 6.0]]),
            'segment_ids': ragged_factory_ops.constant_value([[0, 1],
                                                              [0, 0, 0]]),
            'num_segments': 2
        },
        expected=[7.0, 2.0]),
    dict(
        op=math_ops.reduce_sum,
        kwargs={
            'input_tensor': ragged_factory_ops.constant_value([[1, 2],
                                                               [3, 4, 5]]),
            'axis': 1
        },
        expected=[3, 12]),
    dict(
        op=math_ops.reduce_prod,
        kwargs={
            'input_tensor': ragged_factory_ops.constant_value([[1, 2],
                                                               [3, 4, 5]]),
            'axis': 1
        },
        expected=[2, 60]),
    dict(
        op=math_ops.reduce_min,
        kwargs={
            'input_tensor': ragged_factory_ops.constant_value([[1, 2],
                                                               [3, 4, 5]]),
            'axis': 1
        },
        expected=[1, 3]),
    dict(
        op=math_ops.reduce_max,
        kwargs={
            'input_tensor': ragged_factory_ops.constant_value([[1, 2],
                                                               [3, 4, 5]]),
            'axis': 1
        },
        expected=[2, 5]),
    dict(
        op=math_ops.reduce_mean,
        kwargs={
            'input_tensor': ragged_factory_ops.constant_value([[1, 3],
                                                               [3, 4, 5]]),
            'axis': 1
        },
        expected=[2, 4]),
    dict(
        op=math_ops.reduce_any,
        kwargs={
            'input_tensor': ragged_factory_ops.constant_value(
                [[True, False], [True, True, True]]),
            'axis': 1
        },
        expected=[True, True]),
    dict(
        op=string_ops.reduce_join,
        kwargs={
            'inputs': ragged_factory_ops.constant_value([[
                b'this', b'is', b'a', b'test', b'for', b'ragged',
                b'tensors'
            ], [b'please', b'do', b'not', b'panic', b'!']]),
            'axis': 0,
            'keepdims': False,
            'separator': ''
        },
        expected=[
            b'thisplease', b'isdo', b'anot', b'testpanic', b'for!', b'ragged',
            b'tensors'
        ]),
    dict(
        op=math_ops.reduce_all,
        kwargs={
            'input_tensor': ragged_factory_ops.constant_value(
                [[True, False], [True, True, True]]),
            'axis': 1
        },
        expected=[False, True]),
    dict(
        op=array_ops.rank,
        kwargs={'input': ragged_factory_ops.constant_value([[8, 3], [5]])},
        expected=2),
    dict(
        op=array_ops.size,
        kwargs={'input': ragged_factory_ops.constant_value([[8, 3], [5]])},
        expected=3),
    dict(
        op=array_ops.size_v2,
        kwargs={'input': ragged_factory_ops.constant_value([[8, 3], [5]])},
        expected=3),
    dict(
        op=array_ops.squeeze,
        kwargs={
            'input': ragged_factory_ops.constant_value([[[1, 2, 3], [4, 5]]]),
            'axis': [0]
        },
        expected=ragged_factory_ops.constant_value([[1, 2, 3], [4, 5]])),
    dict(
        op=array_ops.squeeze_v2,
        kwargs={
            'input': ragged_factory_ops.constant_value([[[1, 2, 3], [4, 5]]]),
            'axis': [0]
        },
        expected=ragged_factory_ops.constant_value([[1, 2, 3], [4, 5]])),
    dict(
        op=data_flow_ops.dynamic_partition,
        kwargs={
            'data': ragged_factory_ops.constant_value([[1], [2, 3, 4], [5]]),
            'partitions': [2, 1, 1],
            'num_partitions': 3
        },
        expected=[
            ragged_factory_ops.constant_value([], ragged_rank=1),
            ragged_factory_ops.constant_value([[2, 3, 4], [5]]),
            ragged_factory_ops.constant_value([[1]])
        ],
        result_is_list=True),
    dict(
        op=array_ops.reverse,
        kwargs={
            'tensor': ragged_factory_ops.constant_value([[1, 2, 3], [4, 5]]),
            'axis': [0, -1]
        },
        expected=ragged_factory_ops.constant_value([[5, 4], [3, 2, 1]])),
    dict(
        op=string_ops.string_format,
        kwargs={'template': 'Hi {}',
                'inputs': [ragged_factory_ops.constant_value([[1, 2], [3]])]},
        expected='Hi [[1, 2], [3]]'),
])
def testRaggedDispatch(self, op, expected, args=(), result_is_list=False,
                       kwargs=None):
  """Calls `op(*args, **kwargs)` and checks the result against `expected`.

  Args:
    op: The op under test.
    expected: The expected result; a sequence of expected results when
      `result_is_list` is true.
    args: Positional arguments for `op`.
    result_is_list: If true, `op` returns a sequence that is compared with
      `expected` element by element (after checking the lengths match).
    kwargs: Keyword arguments for `op`; `None` means no keyword arguments.
  """
  call_kwargs = {} if kwargs is None else kwargs
  result = op(*args, **call_kwargs)

  if not result_is_list:
    self.assertAllEqual(result, expected)
    return

  self.assertLen(result, len(expected))
  for actual_item, expected_item in zip(result, expected):
    self.assertAllEqual(actual_item, expected_item)
|
|
|
|
def testUnaryElementwiseOpsPreserveUniformRowLength(self):
  """Checks that elementwise ops keep the input's `uniform_row_length`."""
  ops_under_test = (
      # Unary elementwise op
      array_ops.zeros_like,
      # Unary-list elementwise op
      lambda ragged: math_ops.add_n([ragged, ragged]),
  )
  for apply_op in ops_under_test:
    # A fresh input is built for every op.
    rt = ragged_tensor.RaggedTensor.from_uniform_row_length(
        ragged_factory_ops.constant([[1, 2], [3]]),
        uniform_row_length=2)
    self.assertAllEqual(rt.uniform_row_length,
                        apply_op(rt).uniform_row_length)
|
|
|
|
def test_ragged_op_list(self):
  """Checks which op names `ragged_dispatch.ragged_op_list` mentions.

  Each op name is looked up in the returned value in the form
  "`tf.<name>`", for both `tf_version=1` and `tf_version=2`.
  """
  # Ops that should be listed as supported in both v1 and v2.
  supported_ops = [
      'bitwise.bitwise_and', 'bitwise.bitwise_or', 'bitwise.bitwise_xor',
      'bitwise.invert', 'bitwise.left_shift', 'bitwise.right_shift',
      'clip_by_value', 'concat', 'debugging.check_numerics', 'cast',
      'dtypes.complex', 'dtypes.saturate_cast', 'expand_dims', 'gather_nd',
      'gather', 'identity', 'io.decode_base64', 'io.decode_compressed',
      'io.encode_base64', 'math.abs', 'math.acos', 'math.acosh', 'math.add_n',
      'math.add', 'math.angle', 'math.asin', 'math.asinh', 'math.atan2',
      'math.atan', 'math.atanh', 'math.ceil', 'math.conj', 'math.cos',
      'math.cosh', 'math.digamma', 'math.divide_no_nan', 'math.divide',
      'math.equal', 'math.erf', 'math.erfc', 'math.exp', 'math.expm1',
      'math.floor', 'math.floordiv', 'math.floormod', 'math.greater_equal',
      'math.greater', 'math.imag', 'math.is_finite', 'math.is_inf',
      'math.is_nan', 'math.less_equal', 'math.less', 'math.lgamma',
      'math.log1p', 'math.log_sigmoid', 'math.log', 'math.logical_and',
      'math.logical_not', 'math.logical_or', 'math.logical_xor',
      'math.maximum', 'math.minimum', 'math.multiply', 'math.negative',
      'math.not_equal', 'math.pow', 'math.real', 'math.reciprocal',
      'math.reduce_any', 'math.reduce_max', 'math.reduce_mean',
      'math.reduce_min', 'math.reduce_prod', 'math.reduce_sum', 'math.rint',
      'math.round', 'math.rsqrt', 'math.sign', 'math.sin', 'math.sinh',
      'math.sqrt', 'math.square', 'math.squared_difference', 'math.subtract',
      'math.tan', 'math.truediv', 'math.unsorted_segment_max',
      'math.unsorted_segment_mean', 'math.unsorted_segment_min',
      'math.unsorted_segment_prod', 'math.unsorted_segment_sqrt_n',
      'math.unsorted_segment_sum', 'one_hot', 'ones_like', 'rank', 'realdiv',
      'math.reduce_all', 'size', 'squeeze', 'stack', 'strings.as_string',
      'strings.join', 'strings.length', 'strings.reduce_join',
      'strings.regex_full_match', 'strings.regex_replace', 'strings.strip',
      'strings.substr', 'strings.to_hash_bucket_fast',
      'strings.to_hash_bucket_strong', 'strings.to_hash_bucket',
      'strings.to_number', 'strings.unicode_script', 'tile', 'truncatediv',
      'truncatemod', 'zeros_like', 'dynamic_partition', 'reverse',
      'nn.dropout', 'strings.format', 'print'
  ]

  # Ops that should be listed as supported in v1 only.
  supported_ops_v1 = ['batch_gather']

  # Ops that should be listed as supported in v2 only.
  supported_ops_v2 = []

  # (tf_version, ops that must be listed, ops that must not be listed)
  version_checks = (
      (1, supported_ops + supported_ops_v1, supported_ops_v2),
      (2, supported_ops + supported_ops_v2, supported_ops_v1),
  )
  for tf_version, listed_ops, unlisted_ops in version_checks:
    ragged_ops = ragged_dispatch.ragged_op_list(tf_version=tf_version)
    for op_name in listed_ops:
      self.assertIn('`tf.' + op_name + '`', ragged_ops)
    for op_name in unlisted_ops:
      self.assertNotIn('`tf.' + op_name + '`', ragged_ops)
|
|
|
|
|
|
if __name__ == '__main__':
  # Only when this file is executed directly (not imported), hand control to
  # the googletest entry point.
  googletest.main()
|