Pull the partition (row_splits, value_rowids, et cetera) of a RaggedTensor into a separate object. This allows one RowPartition to be shared by multiple RaggedTensor objects.

Copied over tests from RaggedTensorTest to RowPartitionTest.

PiperOrigin-RevId: 298916432
Change-Id: Ibf05c1b5969c0f4d3d21b301b5ee99e7ca1350e1
This commit is contained in:
A. Unique TensorFlower 2020-03-04 12:48:00 -08:00 committed by TensorFlower Gardener
parent 0b0432ae2d
commit 6c26c995db
9 changed files with 1947 additions and 569 deletions

View File

@ -269,6 +269,27 @@ py_library(
srcs_version = "PY2AND3",
)
# Library for row_partition.py, which holds the RowPartition class: the
# row-partitioning of a RaggedTensor, split out so that it can be shared.
py_library(
name = "row_partition",
srcs = ["row_partition.py"],
srcs_version = "PY2AND3",
deps = [
":segment_id_ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:check_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:tensor_spec",
"//tensorflow/python:tensor_util",
"//tensorflow/python:util",
"//third_party/py/numpy",
],
)
py_library(
name = "ragged_tensor",
srcs = ["ragged_tensor.py"],
@ -277,7 +298,7 @@ py_library(
":ragged_config",
":ragged_tensor_value",
":ragged_util",
":segment_id_ops",
":row_partition",
"//tensorflow/python:array_ops",
"//tensorflow/python:check_ops",
"//tensorflow/python:composite_tensor",
@ -292,6 +313,7 @@ py_library(
"//tensorflow/python:tensor_shape",
"//tensorflow/python:tensor_spec",
"//tensorflow/python:tensor_util",
"//tensorflow/python:tf2",
"//tensorflow/python:type_spec",
"//tensorflow/python:util",
"//third_party/py/numpy",
@ -449,12 +471,44 @@ py_test(
":ragged_tensor_value",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:tensor_spec",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/eager:context",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
],
)
py_test(
name = "row_partition_test",
size = "medium",
timeout = "long",
srcs = ["row_partition_test.py"],
python_version = "PY3",
shard_count = 4,
srcs_version = "PY2AND3",
tags = [
"no_windows",
],
deps = [
":ragged", # fixdeps: keep
":row_partition",
"//tensorflow/python:array_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:tensor_util",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/eager:context",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",

View File

@ -342,5 +342,5 @@ def placeholder(dtype, ragged_rank, value_shape=None, name=None):
for i in reversed(range(ragged_rank)):
row_splits = array_ops.placeholder(dtypes.int64, [None],
"row_splits_%d" % i)
result = ragged_tensor.RaggedTensor(result, row_splits, internal=True)
result = ragged_tensor.RaggedTensor.from_row_splits(result, row_splits)
return result

View File

@ -37,22 +37,27 @@ class RaggedPlaceholderOpTest(test_util.TensorFlowTestCase,
(dtypes.int32, 1, [], 'ph',
'tf.RaggedTensor('
'values=Tensor("ph/flat_values:0", shape=(None,), dtype=int32), '
'row_splits=Tensor("ph/row_splits_0:0", shape=(None,), dtype=int64))'),
'row_splits=Tensor("ph/RaggedFromRowSplits/control_dependency:0", '
'shape=(None,), dtype=int64))'),
(dtypes.string, 1, [5], 'ph',
'tf.RaggedTensor('
'values=Tensor("ph/flat_values:0", shape=(None, 5), dtype=string), '
'row_splits=Tensor("ph/row_splits_0:0", shape=(None,), dtype=int64))'),
'row_splits=Tensor("ph/RaggedFromRowSplits/control_dependency:0", '
'shape=(None,), dtype=int64))'),
(dtypes.float32, 2, [], 'ph',
'tf.RaggedTensor(values=tf.RaggedTensor('
'values=Tensor("ph/flat_values:0", shape=(None,), dtype=float32), '
'row_splits=Tensor("ph/row_splits_1:0", shape=(None,), dtype=int64)), '
'row_splits=Tensor("ph/row_splits_0:0", shape=(None,), dtype=int64))'),
'row_splits=Tensor("ph/RaggedFromRowSplits/control_dependency:0", '
'shape=(None,), dtype=int64)), '
'row_splits=Tensor("ph/RaggedFromRowSplits_1/control_dependency:0", '
'shape=(None,), dtype=int64))'),
(dtypes.int32, 2, [3, 5], 'ph',
'tf.RaggedTensor(values=tf.RaggedTensor('
'values=Tensor("ph/flat_values:0", shape=(None, 3, 5), dtype=int32), '
'row_splits=Tensor("ph/row_splits_1:0", shape=(None,), dtype=int64)), '
'row_splits=Tensor("ph/row_splits_0:0", shape=(None,), dtype=int64))'),
'row_splits=Tensor("ph/RaggedFromRowSplits/control_dependency:0", '
'shape=(None,), dtype=int64)), '
'row_splits=Tensor("ph/RaggedFromRowSplits_1/control_dependency:0", '
'shape=(None,), dtype=int64))'),
])
def testRaggedPlaceholder(self, dtype, ragged_rank, value_shape, name,
expected):

File diff suppressed because it is too large Load Diff

View File

@ -40,6 +40,8 @@ from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_tensor_value
from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensor
from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensorSpec
from tensorflow.python.ops.ragged.row_partition import RowPartition
from tensorflow.python.platform import googletest
@ -128,8 +130,7 @@ def int32array(values):
@test_util.run_all_in_graph_and_eager_modes
class RaggedTensorTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
class RaggedTensorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
longMessage = True # Property in unittest.Testcase. pylint: disable=invalid-name
#=============================================================================
@ -161,25 +162,22 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
outer_rt = RaggedTensor.from_row_splits(
values=inner_rt, row_splits=[0, 3, 3, 5])
self.assertEqual(outer_rt.ragged_rank, 2)
self.assertAllEqual(
outer_rt,
[[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]])
self.assertAllEqual(outer_rt,
[[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]])
del inner_rt, outer_rt
# From section: "Multiple Ragged Dimensions"
rt = RaggedTensor.from_nested_row_splits(
flat_values=[3, 1, 4, 1, 5, 9, 2, 6],
nested_row_splits=([0, 3, 3, 5], [0, 4, 4, 7, 8, 8]))
self.assertAllEqual(
rt, [[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]])
self.assertAllEqual(rt, [[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]])
del rt
# From section: "Uniform Inner Dimensions"
rt = RaggedTensor.from_row_splits(
values=array_ops.ones([5, 3]), row_splits=[0, 2, 5])
self.assertAllEqual(
rt,
[[[1, 1, 1], [1, 1, 1]], [[1, 1, 1], [1, 1, 1], [1, 1, 1]]])
rt, [[[1, 1, 1], [1, 1, 1]], [[1, 1, 1], [1, 1, 1], [1, 1, 1]]])
self.assertEqual(rt.shape.as_list(), [2, None, 3])
del rt
@ -223,42 +221,29 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
def testRaggedTensorConstruction(self):
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
row_splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
rt = RaggedTensor(values=values, row_splits=row_splits, internal=True)
rp = RowPartition(row_splits=row_splits, internal=True)
rt = RaggedTensor(values=values, row_partition=rp, internal=True)
self.assertAllEqual(
rt,
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
self.assertAllEqual(rt,
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
def testRaggedTensorConstructionErrors(self):
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
row_splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
rp = RowPartition(row_splits=row_splits, internal=True)
with self.assertRaisesRegexp(ValueError,
'RaggedTensor constructor is private'):
RaggedTensor(values=values, row_splits=row_splits)
RaggedTensor(values=values, row_partition=rp)
with self.assertRaisesRegexp(TypeError,
'values must be a Tensor or RaggedTensor'):
RaggedTensor(values=range(7), row_splits=row_splits, internal=True)
RaggedTensor(values=range(7), row_partition=rp, internal=True)
with self.assertRaisesRegexp(TypeError,
'Row-partitioning argument must be a Tensor'):
RaggedTensor(values=values, row_splits=[0, 2, 2, 5, 6, 7], internal=True)
with self.assertRaisesRegexp(ValueError,
r'Shape \(6, 1\) must have rank 1'):
RaggedTensor(
values=values,
row_splits=array_ops.expand_dims(row_splits, 1),
internal=True)
with self.assertRaisesRegexp(TypeError,
'Cached value must be a Tensor or None.'):
RaggedTensor(
values=values,
row_splits=row_splits,
cached_row_lengths=[2, 3, 4],
internal=True)
'row_partition must be a RowPartition'):
RaggedTensor(values=values, row_partition=[0, 2, 2, 5, 6, 7],
internal=True)
#=============================================================================
# RaggedTensor Factory Ops
@ -282,9 +267,8 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
self.assertIs(rt_value_rowids, value_rowids) # cached_value_rowids
self.assertAllEqual(rt_value_rowids, value_rowids)
self.assertAllEqual(rt_nrows, 5)
self.assertAllEqual(
rt,
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
self.assertAllEqual(rt,
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
def testFromValueRowIdsWithDerivedNRowsDynamic(self):
# nrows is not known at graph creation time.
@ -308,17 +292,16 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
self.assertIs(rt_value_rowids, value_rowids) # cached_value_rowids
self.assertAllEqual(rt_value_rowids, value_rowids)
self.assertAllEqual(rt_nrows, 5)
self.assertAllEqual(
rt,
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
self.assertAllEqual(rt,
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
def testFromValueRowIdsWithExplicitNRows(self):
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
nrows = constant_op.constant(7, dtypes.int64)
rt = RaggedTensor.from_value_rowids(values, value_rowids, nrows,
validate=False)
rt = RaggedTensor.from_value_rowids(
values, value_rowids, nrows, validate=False)
self.assertEqual(rt.dtype, dtypes.string)
self.assertEqual(rt.shape.as_list(), [7, None])
self.assertEqual(rt.ragged_rank, 1)
@ -331,16 +314,15 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
self.assertIs(rt_value_rowids, value_rowids) # cached_value_rowids
self.assertIs(rt_nrows, nrows) # cached_nrows
self.assertAllEqual(
rt,
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g'], [], []])
rt, [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g'], [], []])
def testFromValueRowIdsWithExplicitNRowsEqualToDefault(self):
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
nrows = constant_op.constant(5, dtypes.int64)
rt = RaggedTensor.from_value_rowids(values, value_rowids, nrows,
validate=False)
rt = RaggedTensor.from_value_rowids(
values, value_rowids, nrows, validate=False)
self.assertEqual(rt.dtype, dtypes.string)
self.assertEqual(rt.shape.as_list(), [5, None])
self.assertEqual(rt.ragged_rank, 1)
@ -354,9 +336,8 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
self.assertIs(rt_nrows, nrows) # cached_nrows
self.assertAllEqual(rt_value_rowids, value_rowids)
self.assertAllEqual(rt_nrows, nrows)
self.assertAllEqual(
rt,
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
self.assertAllEqual(rt,
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
def testFromValueRowIdsWithEmptyValues(self):
rt = RaggedTensor.from_value_rowids([], [])
@ -385,9 +366,8 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
self.assertIs(rt_values, values)
self.assertIs(rt_row_splits, row_splits)
self.assertAllEqual(rt_nrows, 5)
self.assertAllEqual(
rt,
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
self.assertAllEqual(rt,
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
def testFromRowSplitsWithDifferentSplitTypes(self):
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
@ -428,9 +408,8 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
self.assertIs(rt_values, values)
self.assertAllEqual(rt_nrows, 5)
self.assertAllEqual(rt_row_starts, row_starts)
self.assertAllEqual(
rt,
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
self.assertAllEqual(rt,
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
def testFromRowLimits(self):
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
@ -448,9 +427,8 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
self.assertIs(rt_values, values)
self.assertAllEqual(rt_nrows, 5)
self.assertAllEqual(rt_row_limits, row_limits)
self.assertAllEqual(
rt,
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
self.assertAllEqual(rt,
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
def testFromRowLengths(self):
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
@ -469,21 +447,27 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
self.assertIs(rt_row_lengths, row_lengths) # cached_nrows
self.assertAllEqual(rt_nrows, 5)
self.assertAllEqual(rt_row_lengths, row_lengths)
self.assertAllEqual(
rt,
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
self.assertAllEqual(rt,
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
def testFromRowLengthsInt32(self):
  """Wraps a RaggedTensor with int32 row lengths in a second ragged level.

  The inner tensor is partitioned by an explicit int32 `row_lengths` tensor;
  the outer level is partitioned by a plain Python list.  The outer
  `row_lengths()` must equal the list that was passed in.
  """
  # NOTE(review): presumably this guards against a dtype mismatch between the
  # int32 inner partition and the outer partition built from a list -- confirm.
  int32_lengths = constant_op.constant([1, 0, 3], dtype=dtypes.int32)
  inner = RaggedTensor.from_row_lengths([1, 2, 3, 4], int32_lengths)
  outer = RaggedTensor.from_row_lengths(inner, [2, 1, 0])
  self.assertAllEqual([2, 1, 0], outer.row_lengths())
def testFromUniformRowLength(self):
values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]
a1 = RaggedTensor.from_uniform_row_length(values, 2)
a2 = RaggedTensor.from_uniform_row_length(values, 2, 8)
self.assertAllEqual(a1, [[1, 2], [3, 4], [5, 6], [7, 8],
[9, 10], [11, 12], [13, 14], [15, 16]])
self.assertAllEqual(
a1,
[[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]])
self.assertAllEqual(a1, a2)
self.assertEqual(a1.shape.as_list(), [8, 2])
self.assertEqual(a2.shape.as_list(), [8, 2])
self.assertAllEqual(a1.uniform_row_length, 2)
b1 = RaggedTensor.from_uniform_row_length(a1, 2)
b2 = RaggedTensor.from_uniform_row_length(a1, 2, 4)
@ -492,7 +476,6 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
self.assertAllEqual(b1, b2)
self.assertEqual(b1.shape.as_list(), [4, 2, 2])
self.assertEqual(b2.shape.as_list(), [4, 2, 2])
self.assertAllEqual(b1.uniform_row_length, 2)
c1 = RaggedTensor.from_uniform_row_length(b1, 2)
c2 = RaggedTensor.from_uniform_row_length(b1, 2, 2)
@ -501,13 +484,11 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
self.assertAllEqual(c1, c2)
self.assertEqual(c1.shape.as_list(), [2, 2, 2, 2])
self.assertEqual(c2.shape.as_list(), [2, 2, 2, 2])
self.assertAllEqual(c1.uniform_row_length, 2)
def testFromUniformRowLengthWithEmptyValues(self):
empty_values = []
a = RaggedTensor.from_uniform_row_length(empty_values, 0, nrows=10)
self.assertEqual(a.shape.as_list(), [10, 0])
self.assertAllEqual(a.uniform_row_length, 0)
b = RaggedTensor.from_uniform_row_length(a, 2)
self.assertEqual(b.shape.as_list(), [5, 2, 0])
@ -570,8 +551,7 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
self.assertAllEqual(rt_value_rowids, nested_value_rowids[0])
self.assertAllEqual(rt_values_value_rowids, nested_value_rowids[1])
self.assertAllEqual(
rt,
[[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]])
rt, [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]])
def testFromNestedValueRowIdsWithExplicitNRows(self):
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
@ -602,9 +582,8 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
self.assertAllEqual(rt_values_value_rowids, nested_value_rowids[1])
self.assertAllEqual(rt_nrows, nrows[0])
self.assertAllEqual(rt_values_nrows, nrows[1])
self.assertAllEqual(
rt, [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [],
[[b'f'], [b'g'], []], [], []])
self.assertAllEqual(rt, [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [],
[[b'f'], [b'g'], []], [], []])
def testFromNestedValueRowIdsWithExplicitNRowsMismatch(self):
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
@ -635,8 +614,8 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
]
rt = RaggedTensor.from_nested_row_splits(flat_values, nested_row_splits,
validate=False)
rt = RaggedTensor.from_nested_row_splits(
flat_values, nested_row_splits, validate=False)
self.assertEqual(rt.dtype, dtypes.string)
self.assertEqual(rt.shape.as_list(), [4, None, None])
self.assertEqual(rt.ragged_rank, 2)
@ -650,8 +629,34 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
self.assertIs(rt_row_splits, nested_row_splits[0])
self.assertIs(rt_values_row_splits, nested_row_splits[1])
self.assertAllEqual(
rt,
[[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]])
rt, [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]])
def testWithRowSplits(self):
  """Checks `with_row_splits_dtype` on a RaggedTensor with ragged_rank 2.

  Builds the tensor from int64 nested row splits, converts the splits to
  int32, and verifies that the dtype, shape, ragged rank, flat values and
  both levels of row splits are preserved, and that both levels of row
  splits now have dtype int32.
  """
  flat_values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
  nested_row_splits = [
      constant_op.constant([0, 2, 3, 3, 5], dtypes.int64),
      constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
  ]

  int64_rt = RaggedTensor.from_nested_row_splits(
      flat_values, nested_row_splits, validate=False)
  rt = int64_rt.with_row_splits_dtype(dtypes.int32)

  self.assertEqual(rt.dtype, dtypes.string)
  self.assertEqual(rt.shape.as_list(), [4, None, None])
  self.assertEqual(rt.ragged_rank, 2)

  outer_splits = rt.row_splits
  inner_rt = rt.values
  inner_splits = inner_rt.row_splits

  # The conversion itself: without these checks the test would still pass if
  # with_row_splits_dtype returned the tensor unchanged.
  self.assertEqual(outer_splits.dtype, dtypes.int32)
  self.assertEqual(inner_splits.dtype, dtypes.int32)

  # The converted splits must still describe the same partition.
  self.assertAllEqual(inner_rt.values, flat_values)
  self.assertAllEqual(outer_splits, nested_row_splits[0])
  self.assertAllEqual(inner_splits, nested_row_splits[1])
  self.assertAllEqual(
      rt, [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]])
def testFromNestedRowSplitsWithNonListInput(self):
with self.assertRaisesRegexp(TypeError,
@ -747,16 +752,13 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
rt2 = RaggedTensor.from_value_rowids(values, value_rowids)
for rt in [rt1, rt2]:
self.assertAllEqual(
rt,
[[[0, 1], [2, 3]], [], [[4, 5], [6, 7], [8, 9]], [[10, 11]],
[[12, 13]]])
self.assertAllEqual(rt, [[[0, 1], [2, 3]], [], [[4, 5], [6, 7], [8, 9]],
[[10, 11]], [[12, 13]]])
self.assertAllEqual(
rt.values,
[[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13]])
self.assertEqual(rt.values.shape.dims[0].value, 7)
self.assertAllEqual(
rt.value_rowids(), [0, 0, 2, 2, 2, 3, 4])
self.assertAllEqual(rt.value_rowids(), [0, 0, 2, 2, 2, 3, 4])
self.assertAllEqual(rt.nrows(), 5)
self.assertAllEqual(rt.row_splits, [0, 2, 2, 5, 6, 7])
self.assertAllEqual(rt.row_starts(), [0, 2, 2, 5, 6])
@ -786,11 +788,9 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
for rt in [rt1, rt2]:
self.assertAllEqual(
rt,
[[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]])
rt, [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]])
self.assertAllEqual(
rt.values,
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
rt.values, [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
self.assertEqual(rt.values.shape.dims[0].value, 5)
self.assertAllEqual(rt.value_rowids(), [0, 0, 1, 3, 3])
self.assertAllEqual(rt.nrows(), 4)
@ -798,9 +798,8 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
self.assertAllEqual(rt.row_starts(), [0, 2, 3, 3])
self.assertAllEqual(rt.row_limits(), [2, 3, 3, 5])
self.assertAllEqual(rt.row_lengths(), [2, 1, 0, 2])
self.assertAllEqual(
rt.flat_values,
[b'a', b'b', b'c', b'd', b'e', b'f', b'g'])
self.assertAllEqual(rt.flat_values,
[b'a', b'b', b'c', b'd', b'e', b'f', b'g'])
self.assertLen(rt.nested_row_splits, 2)
self.assertAllEqual(rt.nested_row_splits[0], [0, 2, 3, 3, 5])
self.assertAllEqual(rt.nested_row_splits[1], [0, 2, 2, 5, 6, 7])
@ -1024,8 +1023,8 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
'slice offsets must be integers or None'),
# Tests for other errors
(SLICE_BUILDER[..., 0, 0, 0], IndexError,
'Too many indices for RaggedTensor'),
(SLICE_BUILDER[..., 0, 0,
0], IndexError, 'Too many indices for RaggedTensor'),
)
def testRaggedTensorGetItemErrorsWithRaggedRank1(self, slice_spec, expected,
message):
@ -1106,9 +1105,8 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
[[v[::-2] for v in row] for row in EXAMPLE_RAGGED_TENSOR_4D]),
(SLICE_BUILDER[..., ::-1, :],
[[v[::-1] for v in row] for row in EXAMPLE_RAGGED_TENSOR_4D]),
(SLICE_BUILDER[..., ::-1],
[[[v[::-1] for v in col] for col in row]
for row in EXAMPLE_RAGGED_TENSOR_4D]),
(SLICE_BUILDER[..., ::-1], [[[v[::-1] for v in col] for col in row]
for row in EXAMPLE_RAGGED_TENSOR_4D]),
)
def testRaggedTensorGetItemWithRaggedRank2(self, slice_spec, expected):
"""Test that rt.__getitem__(slice_spec) == expected."""
@ -1212,11 +1210,9 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
rt_newaxis4 = rt[:, :, :, :, array_ops.newaxis]
self.assertAllEqual(
rt,
[[[[b'a', b'b'], [b'c', b'd']], [], [[b'e', b'f']]], []])
rt, [[[[b'a', b'b'], [b'c', b'd']], [], [[b'e', b'f']]], []])
self.assertAllEqual(
rt_newaxis0,
[[[[[b'a', b'b'], [b'c', b'd']], [], [[b'e', b'f']]], []]])
rt_newaxis0, [[[[[b'a', b'b'], [b'c', b'd']], [], [[b'e', b'f']]], []]])
self.assertAllEqual(
rt_newaxis1,
[[[[[b'a', b'b'], [b'c', b'd']], [], [[b'e', b'f']]]], [[]]])
@ -1330,9 +1326,10 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
else:
expected_repr = (
'tf.RaggedTensor(values=Tensor("RaggedFromRowSplits/values:0", '
'shape=(7,), dtype=string), row_splits='
'Tensor("RaggedFromRowSplits/row_splits:0", '
'shape=(6,), dtype={}))').format(splits_type)
'shape=(7,), dtype=string), '
'row_splits=Tensor('
'"RaggedFromRowSplits/RowPartitionFromRowSplits/row_splits:0",'
' shape=(6,), dtype={}))').format(splits_type)
self.assertEqual(repr(rt), expected_repr)
self.assertEqual(str(rt), expected_repr)
@ -1362,15 +1359,11 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
rt2_times_10 = rt2.with_flat_values(rt2.flat_values * 10)
rt1_expanded = rt1.with_values(array_ops.expand_dims(rt1.values, axis=1))
self.assertAllEqual(
rt1_plus_10,
[[11, 12], [13, 14, 15], [16], [], [17]])
self.assertAllEqual(
rt2_times_10,
[[[10, 20], [30, 40, 50]], [[60]], [], [[], [70]]])
self.assertAllEqual(
rt1_expanded,
[[[1], [2]], [[3], [4], [5]], [[6]], [], [[7]]])
self.assertAllEqual(rt1_plus_10, [[11, 12], [13, 14, 15], [16], [], [17]])
self.assertAllEqual(rt2_times_10,
[[[10, 20], [30, 40, 50]], [[60]], [], [[], [70]]])
self.assertAllEqual(rt1_expanded,
[[[1], [2]], [[3], [4], [5]], [[6]], [], [[7]]])
#=============================================================================
# Session.run
@ -1465,6 +1458,99 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
ragged_math_ops.reduce_sum(a)
self.assertLen(a.consumers(), 1)
@parameterized.parameters([
{
'descr': 'from_value_rowids',
'factory': RaggedTensor.from_value_rowids,
'test': RaggedTensor.value_rowids,
'values': {
'values': [1, 2, 3, 4, 5, 6],
'value_rowids': [0, 0, 1, 1, 2, 2],
},
'tensor_field': 'value_rowids',
'value_rowids': [0, 1, 2],
'nrows': 10
},
{
'descr': 'from_row_splits',
'factory': RaggedTensor.from_row_splits,
# row_splits is a property, not a function.
'test': (lambda rt: rt.row_splits),
'values': {
'values': [1, 2, 3, 4, 5, 6],
'row_splits': [0, 2, 4, 6],
},
'tensor_field': 'row_splits',
'row_splits': [0, 1, 2, 3]
},
{
'descr': 'from_row_lengths',
'factory': RaggedTensor.from_row_lengths,
'test': RaggedTensor.row_lengths,
'values': {
'values': [1, 2, 3, 4, 5, 6],
'row_lengths': [2, 2, 2],
},
'tensor_field': 'row_lengths',
'row_lengths': [1, 1, 1],
},
# from_row_starts
{
'descr': 'from_row_starts',
'factory': RaggedTensor.from_row_starts,
'test': RaggedTensor.row_starts,
'values': {
'values': [1, 2, 3, 4, 5, 6],
'row_starts': [0, 2, 4]
},
'tensor_field': 'row_starts',
'row_starts': [0, 1, 2]
},
# from_row_limits
{
'descr': 'from_row_limits',
'factory': RaggedTensor.from_row_limits,
'test': RaggedTensor.row_limits,
'values': {
'values': [1, 2, 3, 4, 5, 6],
'row_limits': [2, 4, 6]
},
'tensor_field': 'row_limits',
'row_limits': [3]
},
# from_uniform_row_length
{
'descr': 'from_uniform_row_length',
'factory': RaggedTensor.from_uniform_row_length,
# One cannot extract uniform_row_length or nvals from the result, so we
# compute uniform_row_length as nvals // nrows, where nvals = 3 (the three
# rows of the inner tensor, which are the values of the outer one).
'test': (lambda rt: 3 // (rt.shape[0])),
'values': {
'values': [1, 2, 3, 4, 5, 6],
'uniform_row_length': 2
},
'tensor_field': 'uniform_row_length',
'uniform_row_length': 3
},
])
def testFactoryTypePreference(self, descr, test, factory, values,
                              tensor_field, **kwargs):
  """Nests two `factory` calls, with the inner partition tensor cast to int32.

  The first call uses `values`, with its `tensor_field` entry replaced by an
  int32 constant.  The second call passes the first result as `values`
  together with `kwargs`.  The value read back through `test` must equal the
  `tensor_field` entry of `kwargs`.

  Args:
    descr: Label for the parameterized case; not read by the body.
    test: Callable applied to the outer RaggedTensor to get the compared value.
    factory: The RaggedTensor factory method used for both calls.
    values: Dict of keyword arguments for the inner `factory` call.
    tensor_field: Name of the partition argument; a key of `values` (where it
      is cast to int32) and of `kwargs` (where it is the expected value).
    **kwargs: Keyword arguments, other than `values`, for the outer call.
  """
  inner_args = {}
  for key, val in values.items():
    if key == tensor_field:
      val = constant_op.constant(val, dtype=dtypes.int32)
    inner_args[key] = val
  inner_rt = factory(**inner_args)

  outer_args = dict(kwargs, values=inner_rt)
  outer_rt = factory(**outer_args)
  self.assertAllEqual(kwargs[tensor_field], test(outer_rt))
@parameterized.parameters([
# from_value_rowids
{
@ -1557,7 +1643,7 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
'row_lengths': [[1, 2], [1, 0]]
},
{
'descr': 'negative row_lengths',
'descr': 'negatve row_lengths',
'factory': RaggedTensor.from_row_lengths,
'values': [1, 2, 3, 4],
'row_lengths': [3, -1, 2]
@ -1678,18 +1764,21 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
with self.assertRaises((errors.InvalidArgumentError, ValueError)):
self.evaluate(factory(**kwargs))
# Remove shape information (by wraping tensors in placeholders), and check
# Remove shape information (by wrapping tensors in placeholders), and check
# that we detect the errors when the graph is run.
if not context.executing_eagerly():
def wrap_arg(v):
return array_ops.placeholder_with_default(
constant_op.constant(v, dtype=dtypes.int64),
tensor_shape.TensorShape(None))
kwargs = dict((k, wrap_arg(v)) for (k, v) in kwargs.items())
with self.assertRaises(errors.InvalidArgumentError):
self.evaluate(factory(**kwargs))
#=============================================================================
# RaggedTensor Variant conversion
#=============================================================================
@ -2059,8 +2148,10 @@ class RaggedTensorSpecTest(test_util.TensorFlowTestCase,
self.assertAllEqual(rt1, [[1, 2], [3]])
spec2 = RaggedTensorSpec(ragged_rank=2, dtype=dtypes.int32)
rt2 = spec2._from_components([np.array([1, 2, 3]), np.array([0, 2, 3]),
np.array([0, 0, 2, 3])])
rt2 = spec2._from_components(
[np.array([1, 2, 3]),
np.array([0, 2, 3]),
np.array([0, 0, 2, 3])])
self.assertIsInstance(rt2, ragged_tensor_value.RaggedTensorValue)
self.assertAllEqual(rt2, [[[], [1, 2]], [[3]]])

View File

@ -0,0 +1,843 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""An internal class for representing the partition in a ragged tensor."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import segment_id_ops
# pylint: disable=protected-access
_eval_using_default_session = ops._eval_using_default_session
# pylint: enable=protected-access
#===============================================================================
# RowPartition
#===============================================================================
class RowPartition(object):
"""Represents the partition of a ragged tensor.
In particular, this provides a ragged representation to a flat list,
or a deeper ragged representation of a ragged tensor. However, it does
not store the values or the substructure: only the top-level representation
is represented.
The canonical representation of a partition is row_splits, which indicates how
the flat values are divided into rows. In particular, the values for row
`rt[i]` are stored in the slice
`rt.values[rt.row_splits[i]:rt.row_splits[i+1]]`.
### Alternative Row-Partitioning Schemes
In addition to `row_splits`, row partitions provide support for five other
partitioning schemes:
* `row_lengths`: a vector with shape `[nrows]`, which specifies the length
of each row.
* `value_rowids` and `nrows`: `value_rowids` is a vector with shape
`[nvals]`, corresponding one-to-one with `values`, which specifies
each value's row index. In particular, the row `rt[row]` consists of the
values `rt.values[j]` where `value_rowids[j]==row`. `nrows` is an
integer scalar that specifies the number of rows in the
`RowPartition`. (`nrows` is used to indicate trailing empty rows.)
* `row_starts` (and nvals): a vector with shape `[nrows]`, which specifies
the start offset of each row. Equivalent to `row_splits[:-1]`.
* `row_limits`: a vector with shape `[nrows]`, which specifies the stop
offset of each row. Equivalent to `row_splits[1:]`.
* `uniform_row_length` (and nvals): A scalar tensor, specifying the length
of every row. This row-partitioning scheme may only be used if all rows
have the same length.
For examples, please see the documentation on RaggedTensor.
"""
#=============================================================================
# Constructor (private)
#=============================================================================
def __init__(self,
             row_splits,
             cached_row_lengths=None,
             cached_value_rowids=None,
             cached_nrows=None,
             internal=False,
             uniform_row_length=None):
  """Creates a `RowPartition` from `row_splits` and optional cached encodings.

  This constructor is private -- please use one of the following ops to
  build `RowPartition`s:

    * `RowPartition.from_row_lengths`
    * `RowPartition.from_value_rowids`
    * `RowPartition.from_row_splits`
    * `RowPartition.from_row_starts`
    * `RowPartition.from_row_limits`

  Args:
    row_splits: A 1-D integer tensor with shape `[nrows+1]`.
    cached_row_lengths: A 1-D integer tensor with shape `[nrows]`, or None.
    cached_value_rowids: A 1-D integer tensor with shape `[nvals]`, or None.
    cached_nrows: An integer scalar tensor, or None.
    internal: True if the constructor is being called by one of the factory
      methods.  If false, an exception will be raised.
    uniform_row_length: A scalar tensor, or None.

  Raises:
    ValueError: If `internal` is false.
    TypeError: If `row_splits` is not a `Tensor`.
    ValueError: If `row_splits` is not int32 or int64, or if its static shape
      does not have rank 1.
    TypeError: If a cached value or `uniform_row_length` is neither None nor
      an int32/int64 `Tensor`.
  """
  if not internal:
    # NOTE(review): the message names RaggedTensor rather than RowPartition.
    # It is kept verbatim in case existing tests match on this text --
    # confirm before rewording.
    raise ValueError("RaggedTensor constructor is private; please use one "
                     "of the factory methods instead (e.g., "
                     "RaggedTensor.from_row_lengths())")

  int_dtypes = (dtypes.int32, dtypes.int64)

  # Validate the arguments.
  if not isinstance(row_splits, ops.Tensor):
    # Wrap the operand in a 1-tuple: a bare tuple-valued `row_splits` would
    # be unpacked as the format arguments and raise a formatting error
    # instead of this message.
    raise TypeError("Row-partitioning argument must be a Tensor, got %r" %
                    (row_splits,))
  if row_splits.dtype not in int_dtypes:
    raise ValueError("Row-partitioning argument must be int32 or int64")

  # Validate shapes & dtypes.
  row_splits.shape.assert_has_rank(1)
  row_splits.set_shape([None])
  self._row_splits = row_splits

  # Store any cached tensors.  These are used to avoid unnecessary
  # round-trip conversions when a RaggedTensor is constructed from
  # lengths or rowids, and we later want those lengths/rowids back.
  for tensor in [cached_row_lengths, cached_value_rowids, cached_nrows]:
    if tensor is None:
      continue
    if not isinstance(tensor, ops.Tensor):
      raise TypeError("Cached value must be a Tensor or None.")
    if tensor.dtype not in int_dtypes:
      raise TypeError("Cached value must be int32 or int64.")
  self._cached_row_lengths = cached_row_lengths
  self._cached_value_rowids = cached_value_rowids
  self._cached_nrows = cached_nrows

  if uniform_row_length is not None:
    if not isinstance(uniform_row_length, ops.Tensor):
      raise TypeError("uniform_row_length must be a Tensor or None.")
    if uniform_row_length.dtype not in int_dtypes:
      raise TypeError("uniform_row_length must be int32 or int64.")
  self._uniform_row_length = uniform_row_length
#=============================================================================
# Factory Methods
#=============================================================================
@classmethod
def from_value_rowids(cls,
value_rowids,
nrows=None,
name=None,
validate=True,
preferred_dtype=None):
"""Creates a `RowPartition` with rows partitioned by `value_rowids`.

The implied `RaggedTensor` corresponds with the python list defined by:

```python
result = [[values[i] for i in range(len(values)) if value_rowids[i] == row]
for row in range(nrows)]
```

The returned partition also caches `value_rowids`, `nrows` and the row
lengths computed here, so they can be read back without recomputation.

Args:
value_rowids: A 1-D integer tensor with shape `[nvals]`, which corresponds
one-to-one with `values`, and specifies each value's row index. Must be
nonnegative, and must be sorted in ascending order.
nrows: An integer scalar specifying the number of rows. This should be
specified if the partition may contain empty trailing rows. Must be
greater than `value_rowids[-1]` (or nonnegative if `value_rowids` is
empty). Defaults to `value_rowids[-1] + 1` (or zero if `value_rowids` is
empty).
name: A name prefix for the RowPartition (optional).
validate: If true, then use assertions to check that the arguments form a
valid `RowPartition`.
preferred_dtype: The dtype to encode value_rowids if it doesn't already
have one. The default is tf.int64.

Returns:
A `RowPartition`.

Raises:
TypeError: If `validate` is not a bool.
ValueError: If `nrows` is statically known to be negative, or is
incompatible with `value_rowids`.

#### Example:

>>> print(RowPartition.from_value_rowids(
... value_rowids=[0, 0, 0, 0, 2, 2, 2, 3],
... nrows=4))
tf.RowPartition(row_splits=tf.Tensor([0 4 4 7 8], shape=(5,), dtype=int64))
"""
if not isinstance(validate, bool):
raise TypeError("validate must have type bool")
with ops.name_scope(name, "RowPartitionFromValueRowIds",
[value_rowids, nrows]):
value_rowids = cls._convert_row_partition(value_rowids, "value_rowids",
preferred_dtype)
if nrows is None:
# Derive nrows from the last row id, statically when possible.
const_rowids = tensor_util.constant_value(value_rowids)
if const_rowids is None:
# Appending -1 makes element [0] well defined for an empty
# value_rowids: it is then -1, so nrows becomes 0.
nrows = array_ops.concat([value_rowids[-1:], [-1]], axis=0)[0] + 1
const_nrows = None
else:
const_nrows = const_rowids[-1] + 1 if const_rowids.size > 0 else 0
nrows = ops.convert_to_tensor(
const_nrows, value_rowids.dtype, name="nrows")
else:
nrows = ops.convert_to_tensor(nrows, value_rowids.dtype, "nrows")
const_nrows = tensor_util.constant_value(nrows)
# Checks that can be decided from static values raise immediately.
if const_nrows is not None:
if const_nrows < 0:
raise ValueError("Expected nrows >= 0; got %d" % const_nrows)
const_rowids = tensor_util.constant_value(value_rowids)
if const_rowids is not None and const_rowids.size > 0:
if not const_nrows >= const_rowids[-1] + 1:
raise ValueError(
"Expected nrows >= value_rowids[-1] + 1; got nrows=%d, "
"value_rowids[-1]=%d" % (const_nrows, const_rowids[-1]))
value_rowids.shape.assert_has_rank(1)
nrows.shape.assert_has_rank(0)
if validate:
msg = ("Arguments to from_value_rowids do not form a valid "
"RowPartition")
# Slices ([:1], [-1:]) are used instead of single-element indexing so
# the checks are also well defined when value_rowids is empty.
checks = [
check_ops.assert_rank(value_rowids, 1, message=msg),
check_ops.assert_rank(nrows, 0, message=msg),
check_ops.assert_non_negative(value_rowids[:1], message=msg),
_assert_monotonic_increasing(value_rowids, message=msg),
check_ops.assert_less(value_rowids[-1:], nrows, message=msg),
]
value_rowids = control_flow_ops.with_dependencies(checks, value_rowids)
# Convert value_rowids & nrows to row_splits.
# Note: we don't use segment_ids_to_row_splits() here because we want
# to save the intermediate value `row_lengths`, so we can cache it.
# TODO(b/116708836) Upgrade bincount to accept int64 so we can skip the
# cast.
value_rowids_int32 = math_ops.cast(value_rowids, dtypes.int32)
nrows_int32 = math_ops.cast(nrows, dtypes.int32)
# minlength and maxlength are both nrows, so row_lengths has one entry per
# row (see the set_shape call below).
row_lengths = math_ops.bincount(
value_rowids_int32,
minlength=nrows_int32,
maxlength=nrows_int32,
dtype=value_rowids.dtype)
row_splits = array_ops.concat([[0], math_ops.cumsum(row_lengths)], axis=0)
if const_nrows is not None:
# Record the statically known sizes on the derived tensors.
row_lengths.set_shape([const_nrows])
row_splits.set_shape([const_nrows + 1])
return cls(
row_splits,
cached_row_lengths=row_lengths,
cached_value_rowids=value_rowids,
cached_nrows=nrows,
internal=True)
@classmethod
def from_row_splits(cls,
                    row_splits,
                    name=None,
                    validate=True,
                    preferred_dtype=None):
  """Builds a `RowPartition` directly from a vector of row boundaries.

  Row `i` of the partition covers the value indices
  `row_splits[i]:row_splits[i + 1]`, i.e. the corresponding `RaggedTensor` is:

  ```python
  result = [values[row_splits[i]:row_splits[i + 1]]
            for i in range(len(row_splits) - 1)]
  ```

  Args:
    row_splits: A 1-D integer tensor with shape `[nrows+1]`.  Must not be
      empty, and must be sorted in ascending order.  `row_splits[0]` must be
      zero.
    name: A name prefix for the RowPartition (optional).
    validate: If true, then use assertions to check that the arguments form a
      valid `RowPartition`.
    preferred_dtype: If row_splits has an unspecified type, use this one. If
      preferred_dtype is None, defaults to dtypes.int64.

  Returns:
    A `RowPartition`.

  Raises:
    TypeError: If `validate` is not a bool.
    ValueError: If `row_splits` is an empty list or tuple.
  """
  if not isinstance(validate, bool):
    raise TypeError("validate must have type bool")
  if isinstance(row_splits, (list, tuple)) and not row_splits:
    raise ValueError("row_splits tensor may not be empty.")
  # NOTE(review): the constructor's argument check appears to accept only
  # `ops.Tensor` values, so it is unclear whether a TensorSpec can actually
  # get through this branch -- confirm.
  if isinstance(row_splits, tensor_spec.TensorSpec):
    return cls(row_splits=row_splits, internal=True)

  with ops.name_scope(name, "RowPartitionFromRowSplits", [row_splits]):
    splits = cls._convert_row_partition(row_splits, "row_splits",
                                        preferred_dtype)
    splits.shape.assert_has_rank(1)
    if not validate:
      return cls(row_splits=splits, internal=True)

    prefix = "Arguments to from_row_splits do not form a valid RaggedTensor:"
    assertions = [
        check_ops.assert_rank(splits, 1, message=(prefix + "rank")),
        _assert_zero(splits[0], message=(prefix + "zero")),
        _assert_monotonic_increasing(splits, message=(prefix + "monotonic")),
    ]
    # Gate the splits on the assertions so they run whenever it is used.
    checked_splits = control_flow_ops.with_dependencies(assertions, splits)
    return cls(row_splits=checked_splits, internal=True)
@classmethod
def from_row_lengths(cls,
                     row_lengths,
                     name=None,
                     validate=True,
                     preferred_dtype=None):
  """Builds a `RowPartition` from the length of each row.

  The corresponding `RaggedTensor` is the python list defined by:

  ```python
  result = [[values.pop(0) for i in range(length)]
            for length in row_lengths]
  ```

  The given lengths are cached on the result, so `row_lengths()` returns them
  without recomputation.

  Args:
    row_lengths: A 1-D integer tensor with shape `[nrows]`.  Must be
      nonnegative.
    name: A name prefix for the RowPartition (optional).
    validate: If true, then use assertions to check that the arguments form a
      valid `RowPartition`.
    preferred_dtype: If row_lengths has an unspecified type, use this one. If
      preferred_dtype is None, defaults to dtypes.int64.

  Returns:
    A `RowPartition`.

  Raises:
    TypeError: If `validate` is not a bool.
  """
  if not isinstance(validate, bool):
    raise TypeError("validate must have type bool")

  with ops.name_scope(name, "RowPartitionFromRowLengths", [row_lengths]):
    lengths = cls._convert_row_partition(row_lengths, "row_lengths",
                                         preferred_dtype)
    lengths.shape.assert_has_rank(1)

    if validate:
      failure = "Arguments to from_row_lengths do not form a valid RowPartition"
      assertions = [
          check_ops.assert_rank(lengths, 1, message=failure),
          check_ops.assert_non_negative(lengths, message=failure),
      ]
      lengths = control_flow_ops.with_dependencies(assertions, lengths)

    # The splits are a leading zero followed by the running sum of lengths.
    splits = array_ops.concat([[0], math_ops.cumsum(lengths)], axis=0)
    return cls(row_splits=splits, cached_row_lengths=lengths, internal=True)
@classmethod
def from_row_starts(cls,
                    row_starts,
                    nvals,
                    name=None,
                    validate=True,
                    preferred_dtype=None):
  """Builds a `RowPartition` from the start offset of each row.

  Equivalent to: `from_row_splits(concat([row_starts, [nvals]], axis=0))`.

  Args:
    row_starts: A 1-D integer tensor with shape `[nrows]`.  Must be
      nonnegative and sorted in ascending order.  If `nrows>0`, then
      `row_starts[0]` must be zero.
    nvals: A scalar tensor indicating the number of values.  It is cast to the
      dtype of `row_starts`.
    name: A name prefix for the RowPartition (optional).
    validate: If true, then use assertions to check that the arguments form a
      valid `RowPartition`.
    preferred_dtype: If row_starts has an unspecified type, use this one. If
      preferred_dtype is None, defaults to dtypes.int64.

  Returns:
    A `RowPartition`.

  Raises:
    TypeError: If `validate` is not a bool.
  """
  if not isinstance(validate, bool):
    raise TypeError("validate must have type bool")

  with ops.name_scope(name, "RowPartitionFromRowStarts", [row_starts]):
    starts = cls._convert_row_partition(row_starts, "row_starts",
                                        preferred_dtype)
    starts.shape.assert_has_rank(1)
    total = math_ops.cast(nvals, starts.dtype)

    if validate:
      failure = "Arguments to from_row_starts do not form a valid RaggedTensor"
      # Slices keep the first/last-element checks valid for empty row_starts.
      assertions = [
          check_ops.assert_rank(starts, 1, message=failure),
          _assert_zero(starts[:1], message=failure),
          _assert_monotonic_increasing(starts, message=failure),
          check_ops.assert_less_equal(starts[-1:], total, message=failure),
      ]
      starts = control_flow_ops.with_dependencies(assertions, starts)

    # Appending nvals closes the final row.
    splits = array_ops.concat([starts, [total]], axis=0)
    return cls(row_splits=splits, internal=True)
def has_cached_value_rowids(self):
  """Returns True if a `value_rowids` tensor is cached on this partition."""
  cached = self._cached_value_rowids
  return cached is not None
@classmethod
def from_row_limits(cls,
                    row_limits,
                    name=None,
                    validate=True,
                    preferred_dtype=None):
  """Builds a `RowPartition` from the end offset of each row.

  Equivalent to: `from_row_splits(concat([[0], row_limits], axis=0))`.

  Args:
    row_limits: A 1-D integer tensor with shape `[nrows]`.  Must be sorted in
      ascending order.
    name: A name prefix for the RowPartition (optional).
    validate: If true, then use assertions to check that the arguments form a
      valid `RowPartition`.
    preferred_dtype: If row_limits has an unspecified type, use this one. If
      preferred_dtype is None, defaults to dtypes.int64.

  Returns:
    A `RowPartition`.

  Raises:
    TypeError: If `validate` is not a bool.
  """
  if not isinstance(validate, bool):
    raise TypeError("validate must have type bool")

  with ops.name_scope(name, "RowPartitionFromRowLimits", [row_limits]):
    limits = cls._convert_row_partition(row_limits, "row_limits",
                                        preferred_dtype)
    limits.shape.assert_has_rank(1)

    if validate:
      failure = "Arguments to from_row_limits do not form a valid RaggedTensor"
      assertions = [
          check_ops.assert_rank(limits, 1, message=failure),
          check_ops.assert_non_negative(limits[:1], message=failure),
          _assert_monotonic_increasing(limits, message=failure),
      ]
      limits = control_flow_ops.with_dependencies(assertions, limits)

    # Prepending a zero (in the limits' dtype) opens the first row.
    leading_zero = array_ops.zeros([1], limits.dtype)
    splits = array_ops.concat([leading_zero, limits], axis=0)
    return cls(row_splits=splits, internal=True)
@classmethod
def from_uniform_row_length(cls,
nvals,
uniform_row_length,
nrows=None,
validate=True,
name=None,
preferred_dtype=None):
"""Creates a `RowPartition` with rows partitioned by `uniform_row_length`.

A `RaggedTensor` constructed with this corresponds with the python list
defined by (assuming uniform_row_length and nvals nonzero):

```python
result = [[values.pop(0) for _ in range(uniform_row_length)]
for _ in range(nrows)]
```

The result stores `uniform_row_length` and caches `nrows`. Its row splits
are `[0, uniform_row_length, 2 * uniform_row_length, ...]` with `nrows + 1`
entries.

Args:
nvals: a non-negative scalar integer tensor for the number of values.
uniform_row_length: A scalar integer tensor. Must be nonnegative. `nvals`
must be evenly divisible by `uniform_row_length`.
nrows: The number of rows in the constructed RowPartition. If not
specified, then it defaults to `nvals/uniform_row_length` (or `0` if
`uniform_row_length==0`). `nrows` only needs to be specified if
`uniform_row_length` might be zero. `uniform_row_length*nrows` must be
`nvals`.
validate: If true, then use assertions to check that the arguments form a
valid `RowPartition`.
name: A name prefix for the RowPartition (optional)
preferred_dtype: if uniform_row_length has no dtype, use this one.

Returns:
A `RowPartition`.

Raises:
TypeError: If `validate` is not a bool.
ValueError: If `validate` is true and statically known values show that
`uniform_row_length * nrows != nvals` or `uniform_row_length < 0`.
"""
if not isinstance(validate, bool):
raise TypeError("validate must have type bool")
with ops.name_scope(name, "RowPartitionFromUniformRowLength",
[uniform_row_length, nrows]):
uniform_row_length = cls._convert_row_partition(uniform_row_length,
"uniform_row_length",
preferred_dtype)
uniform_row_length.shape.assert_has_rank(0)
# Find nrows.
const_row_length = tensor_util.constant_value(uniform_row_length)
if nrows is None:
if const_row_length is None:
# Avoid division by zero if uniform_row_length==0 (and nvals==0).
rowlen_or_1 = control_flow_ops.cond(
math_ops.equal(uniform_row_length, 0),
lambda: constant_op.constant(1, uniform_row_length.dtype),
lambda: uniform_row_length)
nrows = nvals // rowlen_or_1
elif const_row_length == 0:
nrows = 0
else:
nrows = nvals // const_row_length
nrows = ops.convert_to_tensor(
nrows, uniform_row_length.dtype, name="nrows")
const_nrows = tensor_util.constant_value(nrows)
const_nvals = tensor_util.constant_value(nvals)
# Find row_splits.
if const_nrows is not None and const_row_length is not None:
# Everything is known statically, so emit the splits as a constant.
row_splits = [v * const_row_length for v in range(const_nrows + 1)]
row_splits = constant_op.constant(row_splits, uniform_row_length.dtype)
else:
row_splits = math_ops.range(nrows + 1) * uniform_row_length
if validate:
checks = []
# Each check is raised eagerly when its inputs are statically known, and
# otherwise added as an assertion gating row_splits.
if (const_nrows is None or const_row_length is None or
const_nvals is None):
checks.append(
check_ops.assert_equal(
nrows * uniform_row_length, nvals,
("uniform_row_length", uniform_row_length, "times nrows",
nrows, "must equal nvals", nvals)))
else:
if const_nrows * const_row_length != const_nvals:
raise ValueError(
"uniform_row_length=%d times nrows=%d must equal nvals=%d" %
(const_row_length, const_nrows, const_nvals))
if uniform_row_length.shape.rank is None:
checks.append(
check_ops.assert_rank(
uniform_row_length,
0,
message="uniform_row_length must be a scalar."))
# NOTE(review): uniform_row_length has not been reassigned since
# const_row_length was first computed above, so this second lookup looks
# redundant -- confirm before removing.
const_row_length = tensor_util.constant_value(uniform_row_length)
if const_row_length is None:
checks.append(
check_ops.assert_greater_equal(
uniform_row_length,
constant_op.constant(0, uniform_row_length.dtype),
message="uniform_row_length must be >= 0."))
else:
if const_row_length < 0:
raise ValueError("uniform_row_length must be >= 0.")
row_splits = control_flow_ops.with_dependencies(checks, row_splits)
return cls(
row_splits=row_splits,
uniform_row_length=uniform_row_length,
cached_nrows=nrows,
internal=True)
@classmethod
def _convert_row_partition(cls, partition, name, preferred_dtype):
  """Converts a row-partitioning argument to an int32 or int64 tensor.

  Args:
    partition: A row-partitioning tensor for the `RowPartition` being
      constructed.  I.e., one of: row_splits, row_lengths, row_starts,
      row_limits, value_rowids.
    name: The name of the row-partitioning tensor; used as the tensor name and
      in the error message.
    preferred_dtype: If partition has no dtype, give it this one. If
      no dtype is specified, use dtypes.int64.

  Returns:
    A tensor equivalent to partition.

  Raises:
    ValueError: if the converted tensor's dtype is not int32 or int64.
  """
  conversion_kwargs = {"name": name}
  # A numpy int32 array is converted without a preferred dtype; every other
  # input is converted with one (int64 unless the caller chose otherwise).
  is_int32_array = (
      isinstance(partition, np.ndarray) and partition.dtype == np.int32)
  if not is_int32_array:
    conversion_kwargs["preferred_dtype"] = (
        dtypes.int64 if preferred_dtype is None else preferred_dtype)

  tensor = ops.convert_to_tensor(partition, **conversion_kwargs)
  if tensor.dtype in (dtypes.int32, dtypes.int64):
    return tensor
  raise ValueError("%s must have dtype int32 or int64" % name)
def with_dependencies(self, dependencies):
  """Returns a copy of this RowPartition gated by control dependencies.

  Only `row_splits` is wrapped with the dependencies; the cached tensors and
  `uniform_row_length` are passed to the new object unchanged.  Used to add
  sanity checks to the constructors.

  Args:
    dependencies: a list of tensors to use as dependencies.

  Returns:
    A new RowPartition object.
  """
  gated_splits = control_flow_ops.with_dependencies(dependencies,
                                                    self._row_splits)
  carried_over = dict(
      cached_row_lengths=self._cached_row_lengths,
      cached_value_rowids=self._cached_value_rowids,
      cached_nrows=self._cached_nrows,
      uniform_row_length=self._uniform_row_length)
  return RowPartition(row_splits=gated_splits, internal=True, **carried_over)
#=============================================================================
# Accessors
#=============================================================================
@property
def dtype(self):
  """The `DType` of this partition's `row_splits` tensor."""
  splits = self._row_splits
  return splits.dtype
@property
def row_splits(self):
  """The row-split indices for this row partition.

  `row_splits` specifies where the values for each row begin and end: the
  values for row `i` occupy the index range
  `row_splits[i]:row_splits[i+1]`.

  Returns:
    A 1-D integer `Tensor` with shape `[self.nrows+1]`.
    The returned tensor is non-empty, and is sorted in ascending order.
    `self.row_splits[0]` is zero, and `self.row_splits[-1]` is equal to the
    number of values.
  """
  splits = self._row_splits
  return splits
def value_rowids(self, name=None):
  """Returns the row index of every value in this row partition.

  The result has one entry per value; entry `i` is the row that value `i`
  belongs to.  A cached tensor is returned as-is when one is available.

  Args:
    name: A name prefix for the returned tensor (optional).

  Returns:
    A 1-D integer `Tensor` with shape `[nvals]`.
    The returned tensor is nonnegative, and is sorted in ascending order.
  """
  cached = self._cached_value_rowids
  if cached is None:
    with ops.name_scope(name, "RaggedValueRowIds", [self]):
      return segment_id_ops.row_splits_to_segment_ids(self.row_splits)
  return cached
def nrows_as_dimension(self):
  """Returns the static number of rows as a `tf.Dimension`.

  Computed as the static length of `row_splits` minus one.
  """
  nsplits = tensor_shape.dimension_at_index(self._row_splits.shape, 0)
  return nsplits - 1
def nvals(self, out_type=None, name=None):
  """Returns the number of values in this row partition.

  This is the last entry of `row_splits`, and should equal the outermost
  dimension of the values associated with this row partition.

  Args:
    out_type: `dtype` for the returned tensor.  Defaults to
      `self.row_splits.dtype`.
    name: A name for the returned tensor (optional).  Only used when
      `out_type` is given.

  Returns:
    the number of values in this row partition as a tensor scalar.
  """
  if out_type is None:
    return self.row_splits[-1]
  target_dtype = dtypes.as_dtype(out_type)
  total = self.row_splits[-1]
  return math_ops.cast(total, name=name, dtype=target_dtype)
def nrows(self, out_type=None, name=None):
  """Returns the number of rows in this row partition.

  A cached `nrows` tensor is used when available; otherwise the count is the
  length of `row_splits` minus one, taken from the static shape when it is
  known and computed dynamically when it is not.

  Args:
    out_type: `dtype` for the returned tensor.  Defaults to
      `self.row_splits.dtype`.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A scalar `Tensor` with dtype `out_type`.
  """
  result_dtype = (
      self._row_splits.dtype if out_type is None else dtypes.as_dtype(out_type))

  cached = self._cached_nrows
  if cached is not None:
    return math_ops.cast(cached, result_dtype)

  with ops.name_scope(name, "RaggedNRows", [self]):
    static_nsplits = tensor_shape.dimension_at_index(self.row_splits.shape,
                                                     0).value
    if static_nsplits is not None:
      return constant_op.constant(static_nsplits - 1, dtype=result_dtype)
    return array_ops.shape(self.row_splits, out_type=result_dtype)[0] - 1
def uniform_row_length(self):
  """Returns the stored uniform row length, or `None` if there is none."""
  length = self._uniform_row_length
  return length
def row_starts(self, name=None):
  """Returns the start indices for rows in this row partition.

  These indices specify where the values for each row begin.
  `row_starts()` is equal to `row_splits[:-1]`.

  Args:
    name: A name prefix for the returned tensor (optional).

  Returns:
    A 1-D integer Tensor with shape `[nrows]`.
    The returned tensor is nonnegative, and is sorted in ascending order.
  """
  with ops.name_scope(name, "RaggedRowStarts", [self]):
    all_but_last = self.row_splits[:-1]
    return all_but_last
def row_limits(self, name=None):
  """Returns the limit indices for rows in this row partition.

  These indices specify where the values for each row end.
  `row_limits()` is equal to `row_splits[1:]`.

  Args:
    name: A name prefix for the returned tensor (optional).

  Returns:
    A 1-D integer Tensor with shape `[nrows]`.
    The returned tensor is nonnegative, and is sorted in ascending order.
  """
  with ops.name_scope(name, "RaggedRowLimits", [self]):
    all_but_first = self.row_splits[1:]
    return all_but_first
def row_lengths(self, name=None):
  """Returns the length of each row in this row partition.

  A cached `row_lengths` tensor is returned as-is when available; otherwise
  the lengths are the differences between consecutive `row_splits` entries.

  Args:
    name: A name prefix for the returned tensor (optional).

  Returns:
    A 1-D integer Tensor with one entry per row.
  """
  cached = self._cached_row_lengths
  if cached is not None:
    return cached
  boundaries = self.row_splits
  with ops.name_scope(name, "RaggedRowLengths", [self]):
    ends, starts = boundaries[1:], boundaries[:-1]
    return ends - starts
#=============================================================================
# Transformation
#=============================================================================
def with_row_splits_dtype(self, dtype):
  """Returns a copy of this RowPartition with the given `row_splits` dtype.

  The cached `row_lengths`, `value_rowids` and `nrows` tensors and the
  `uniform_row_length` tensor are cast along with `row_splits` (when present),
  so every tensor stored on the result shares one dtype.

  Args:
    dtype: The dtype for `row_splits`.  One of `tf.int32` or `tf.int64`.

  Returns:
    `self` if `row_splits` already has the requested dtype; otherwise a new
    `RowPartition` whose tensors are cast to `dtype`.

  Raises:
    ValueError: If `dtype` is not int32 or int64.
  """
  dtype = dtypes.as_dtype(dtype)
  if dtype not in (dtypes.int32, dtypes.int64):
    raise ValueError("dtype must be int32 or int64")
  if self._row_splits.dtype == dtype:
    return self

  def _cast_optional(tensor):
    """Casts `tensor` to `dtype`, passing `None` through unchanged."""
    return None if tensor is None else math_ops.cast(tensor, dtype)

  row_splits = math_ops.cast(self._row_splits, dtype)
  cached_row_lengths = _cast_optional(self._cached_row_lengths)
  cached_value_rowids = _cast_optional(self._cached_value_rowids)
  # The earlier code re-tested and re-cast `cached_value_rowids` at this point,
  # which left `cached_nrows` in the old dtype.
  cached_nrows = _cast_optional(self._cached_nrows)
  uniform_row_length = _cast_optional(self._uniform_row_length)

  return RowPartition(
      row_splits,
      cached_row_lengths,
      cached_value_rowids,
      cached_nrows,
      internal=True,
      uniform_row_length=uniform_row_length)
#=============================================================================
# String Encoding
#=============================================================================
def __repr__(self):
  """Returns a string showing this partition's `row_splits`."""
  template = "tf.RowPartition(row_splits=%s)"
  return template % (self._row_splits)
#===============================================================================
# Helper Functions
#===============================================================================
def _assert_monotonic_increasing(tensor, message=None):
  """Asserts that consecutive entries of `tensor` never decrease.

  Args:
    tensor: The tensor whose adjacent differences are checked.
    message: Optional message forwarded to the underlying assertion.

  Returns:
    The op produced by `check_ops.assert_non_negative` on the differences
    `tensor[1:] - tensor[:-1]`.
  """
  adjacent_deltas = tensor[1:] - tensor[:-1]
  return check_ops.assert_non_negative(adjacent_deltas, message=message)
def _assert_zero(tensor, message=None):
  """Asserts that `tensor` equals zero.

  Args:
    tensor: The tensor to compare against a zero constant of its own dtype.
    message: Optional message forwarded to the underlying assertion.

  Returns:
    The op produced by `check_ops.assert_equal`.
  """
  zero = constant_op.constant(0, dtype=tensor.dtype)
  return check_ops.assert_equal(tensor, zero, message=message)

View File

@ -0,0 +1,559 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.python.ops.ragged.row_partition."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops.ragged.row_partition import RowPartition
from tensorflow.python.platform import googletest
class _SliceBuilder(object):
  """Helper that captures the argument Python passes to `__getitem__`.

  Usage: `_SliceBuilder()[<expr>]` returns the slice_spec Python generates for
  `<expr>`.
  """

  def __getitem__(self, slice_spec):
    """Returns `slice_spec` unchanged."""
    return slice_spec


# Shared builder instance: `SLICE_BUILDER[<expr>]` yields the slice spec.
SLICE_BUILDER = _SliceBuilder()
def _make_tensor_slice_spec(slice_spec, use_constant=True):
  """Replaces every integer in an extended slice spec with a scalar tensor.

  This function is used to help test slicing when the slice spec contains
  tensors, rather than integers.  Integers inside `slice` objects (start, stop
  and step) are replaced too; anything else is passed through unchanged.

  Args:
    slice_spec: The extended slice spec.
    use_constant: If true, then wrap each integer with a tf.constant.  If false,
      then wrap each integer with a placeholder whose default value is that
      constant.

  Returns:
    A copy of slice_spec, but with each integer i replaced with a scalar tensor
    holding i.
  """

  def make_piece_scalar(piece):
    # Slices are rebuilt component by component.
    if isinstance(piece, slice):
      parts = (piece.start, piece.stop, piece.step)
      return slice(*[make_piece_scalar(part) for part in parts])
    if not isinstance(piece, int):
      return piece
    scalar = constant_op.constant(piece)
    if use_constant:
      return scalar
    return array_ops.placeholder_with_default(scalar, [])

  if not isinstance(slice_spec, tuple):
    return make_piece_scalar(slice_spec)
  return tuple(make_piece_scalar(piece) for piece in slice_spec)
# Example 2D ragged tensor value with one ragged dimension and with scalar
# values, expressed as nested python lists and as splits+values.
# Row i of the nested list holds VALUES[SPLITS[i]:SPLITS[i + 1]].
EXAMPLE_RAGGED_TENSOR_2D = [[b'a', b'b'], [b'c', b'd', b'e'], [b'f'], [],
[b'g']]
EXAMPLE_RAGGED_TENSOR_2D_SPLITS = [0, 2, 5, 6, 6, 7]
EXAMPLE_RAGGED_TENSOR_2D_VALUES = ['a', 'b', 'c', 'd', 'e', 'f', 'g']
# Example 4D ragged tensor value, with two ragged dimensions and with values
# whose shape is [2], expressed as nested python lists and as splits+values.
# SPLITS1 partitions the outer rows; SPLITS2 partitions VALUES into the inner
# rows.
EXAMPLE_RAGGED_TENSOR_4D = [
[ # rt[0]
[[1, 2], [3, 4], [5, 6]], # rt[0][0]
[[7, 8], [9, 10], [11, 12]]], # rt[0][1]
[], # rt[1]
[ # rt[2]
[[13, 14], [15, 16], [17, 18]]], # rt[2][0]
[ # rt[3]
[[19, 20]]] # rt[3][0]
] # pyformat: disable
EXAMPLE_RAGGED_TENSOR_4D_SPLITS1 = [0, 2, 2, 3, 4]
EXAMPLE_RAGGED_TENSOR_4D_SPLITS2 = [0, 3, 6, 9, 10]
EXAMPLE_RAGGED_TENSOR_4D_VALUES = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10],
[11, 12], [13, 14], [15, 16], [17, 18],
[19, 20]]
# Example 3D ragged tensor with uniform_row_lengths: every outer row holds
# ROWLEN (3) inner rows, and SPLITS partitions VALUES into those inner rows.
EXAMPLE_RAGGED_TENSOR_3D = [[[1, 2, 3], [4], [5, 6]], [[], [7, 8, 9], []]]
EXAMPLE_RAGGED_TENSOR_3D_ROWLEN = 3
EXAMPLE_RAGGED_TENSOR_3D_SPLITS = [0, 3, 4, 6, 6, 9, 9]
EXAMPLE_RAGGED_TENSOR_3D_VALUES = [1, 2, 3, 4, 5, 6, 7, 8, 9]
def int32array(values):
  """Returns a new numpy array holding `values` with dtype int32."""
  as_int32 = np.array(values, dtype=np.int32)
  return as_int32
@test_util.run_all_in_graph_and_eager_modes
class RowPartitionTest(test_util.TensorFlowTestCase, parameterized.TestCase):
longMessage = True # Property in unittest.Testcase. pylint: disable=invalid-name
#=============================================================================
# RaggedTensor class docstring examples
#=============================================================================
def testClassDocStringExamples(self):
  """Runs the partitioning examples adapted from the class docstring."""
  expected_splits = [0, 4, 4, 7, 8, 8]

  # From section: "Component Tensors"
  component = RowPartition.from_row_splits(row_splits=[0, 4, 4, 7, 8, 8])
  self.assertAllEqual(component.row_splits, expected_splits)

  # From section: "Alternative Row-Partitioning Schemes"
  # Each factory below describes the same partition.
  equivalent_partitions = (
      RowPartition.from_row_splits(row_splits=[0, 4, 4, 7, 8, 8]),
      RowPartition.from_row_lengths(row_lengths=[4, 0, 3, 1, 0]),
      RowPartition.from_value_rowids(
          value_rowids=[0, 0, 0, 0, 2, 2, 2, 3], nrows=5),
      RowPartition.from_row_starts(row_starts=[0, 4, 4, 7, 8], nvals=8),
      RowPartition.from_row_limits(row_limits=[4, 4, 7, 8, 8]),
  )
  for partition in equivalent_partitions:
    self.assertAllEqual(partition.row_splits, expected_splits)

  # From section: "Multiple Ragged Dimensions"
  # Construction only; nothing is asserted about these two partitions.
  RowPartition.from_row_splits(row_splits=[0, 4, 4, 7, 8, 8])
  RowPartition.from_row_splits(row_splits=[0, 3, 3, 5])
#=============================================================================
# RowPartition Constructor (private)
#=============================================================================
def testRaggedTensorConstruction(self):
  """The private constructor keeps the `row_splits` values it is given."""
  expected = [0, 2, 2, 5, 6, 7]
  splits_tensor = constant_op.constant(expected, dtypes.int64)
  partition = RowPartition(row_splits=splits_tensor, internal=True)
  self.assertAllEqual(partition.row_splits, expected)
def testRaggedTensorConstructionErrors(self):
  """The private constructor rejects misuse and malformed arguments."""
  row_splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)

  # (expected error, message pattern, thunk that should raise it).  Thunks
  # keep all construction inside the assertRaises context.
  cases = [
      (ValueError, 'RaggedTensor constructor is private',
       lambda: RowPartition(row_splits=row_splits)),
      (TypeError, 'Row-partitioning argument must be a Tensor',
       lambda: RowPartition(row_splits=[0, 2, 2, 5, 6, 7], internal=True)),
      (ValueError, r'Shape \(6, 1\) must have rank 1',
       lambda: RowPartition(
           row_splits=array_ops.expand_dims(row_splits, 1), internal=True)),
      (TypeError, 'Cached value must be a Tensor or None.',
       lambda: RowPartition(
           row_splits=row_splits, cached_row_lengths=[2, 3, 4],
           internal=True)),
  ]
  for error_type, pattern, build in cases:
    with self.assertRaisesRegexp(error_type, pattern):
      build()
#=============================================================================
# RowPartition Factory Ops
#=============================================================================
def testFromValueRowIdsWithDerivedNRows(self):
  """`nrows` is derived from constant `value_rowids`."""
  # nrows is known at graph creation time.
  rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
  # TODO(martinz): add nrows
  partition = RowPartition.from_value_rowids(rowids, validate=False)
  self.assertEqual(partition.dtype, dtypes.int64)

  splits = partition.row_splits
  returned_rowids = partition.value_rowids()
  nrows = partition.nrows()

  # The factory caches the tensor it was given, so the same object comes back.
  self.assertIs(returned_rowids, rowids)
  self.assertAllEqual(returned_rowids, rowids)
  self.assertAllEqual(nrows, 5)
  self.assertAllEqual(splits, [0, 2, 2, 5, 6, 7])
def testFromValueRowIdsWithDerivedNRowsDynamic(self):
  """`nrows` is derived from `value_rowids` fed through a placeholder."""
  # nrows is not known at graph creation time.
  static_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
  rowids = array_ops.placeholder_with_default(static_rowids, shape=None)
  partition = RowPartition.from_value_rowids(rowids, validate=False)

  returned_rowids = partition.value_rowids()
  nrows = partition.nrows()

  # The factory caches the tensor it was given, so the same object comes back.
  self.assertIs(returned_rowids, rowids)
  self.assertAllEqual(returned_rowids, rowids)
  self.assertAllEqual(nrows, 5)
def testFromValueRowIdsWithExplicitNRows(self):
  """An explicit `nrows` larger than the default adds trailing empty rows."""
  rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
  requested_nrows = constant_op.constant(7, dtypes.int64)
  partition = RowPartition.from_value_rowids(
      rowids, requested_nrows, validate=False)

  returned_rowids = partition.value_rowids()
  returned_nrows = partition.nrows()
  splits = partition.row_splits

  # Both inputs are cached, so the very same tensors come back.
  self.assertIs(returned_rowids, rowids)
  self.assertIs(returned_nrows, requested_nrows)
  self.assertAllEqual(splits, [0, 2, 2, 5, 6, 7, 7, 7])
def testFromValueRowIdsWithExplicitNRowsEqualToDefault(self):
  """An explicit `nrows` equal to the derived default changes nothing."""
  rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
  requested_nrows = constant_op.constant(5, dtypes.int64)
  partition = RowPartition.from_value_rowids(
      rowids, requested_nrows, validate=False)

  returned_rowids = partition.value_rowids()
  returned_nrows = partition.nrows()
  splits = partition.row_splits

  # Both inputs are cached, so the very same tensors come back.
  self.assertIs(returned_rowids, rowids)
  self.assertIs(returned_nrows, requested_nrows)
  self.assertAllEqual(returned_rowids, rowids)
  self.assertAllEqual(returned_nrows, requested_nrows)
  self.assertAllEqual(splits, [0, 2, 2, 5, 6, 7])
def testFromValueRowIdsWithEmptyValues(self):
  """An empty `value_rowids` list yields an int64 partition with zero rows."""
  partition = RowPartition.from_value_rowids([])
  nrows = partition.nrows()
  self.assertEqual(partition.dtype, dtypes.int64)
  rowids_shape = partition.value_rowids().shape.as_list()
  self.assertEqual(rowids_shape, [0])
  self.assertAllEqual(nrows, 0)
def testFromRowSplits(self):
  """`from_row_splits` keeps the given tensor and reports five rows."""
  splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
  partition = RowPartition.from_row_splits(splits, validate=False)
  self.assertEqual(partition.dtype, dtypes.int64)

  returned_splits = partition.row_splits
  nrows = partition.nrows()

  # With validate=False the input tensor itself is stored.
  self.assertIs(returned_splits, splits)
  self.assertAllEqual(nrows, 5)
def testFromRowSplitsWithDifferentSplitTypes(self):
  """The `row_splits` dtype follows the input; untyped input becomes int64."""
  # (input, expected dtype of the resulting row_splits).
  cases = [
      ([0, 2, 2, 5, 6, 7], dtypes.int64),
      (np.array([0, 2, 2, 5, 6, 7], np.int64), dtypes.int64),
      (np.array([0, 2, 2, 5, 6, 7], np.int32), dtypes.int32),
      (constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64), dtypes.int64),
      (constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int32), dtypes.int32),
  ]
  partitions = [RowPartition.from_row_splits(splits) for splits, _ in cases]
  for partition, (_, expected_dtype) in zip(partitions, cases):
    self.assertEqual(partition.row_splits.dtype, expected_dtype)
def testFromRowSplitsWithEmptySplits(self):
  # An empty row_splits is expected to be rejected with a ValueError.
  #
  # RowPartition.from_row_splits takes row_splits as its first argument (a
  # partition has no `values`), so only the empty splits list is passed here.
  # A second positional argument would be bound to some other parameter and
  # could make the call fail for an unrelated reason.
  err_msg = 'row_splits tensor may not be empty'
  with self.assertRaisesRegexp(ValueError, err_msg):
    RowPartition.from_row_splits([])
def testFromRowStarts(self):
  # Build a partition from per-row start offsets plus the total value count.
  total_values = constant_op.constant(7)
  starts_in = constant_op.constant([0, 2, 2, 5, 6], dtypes.int64)
  partition = RowPartition.from_row_starts(
      starts_in, total_values, validate=False)
  self.assertEqual(partition.dtype, dtypes.int64)

  starts_out = partition.row_starts()
  splits_out = partition.row_splits
  num_rows = partition.nrows()

  self.assertAllEqual(num_rows, 5)
  self.assertAllEqual(starts_out, starts_in)
  # The expected splits are the starts followed by the total value count.
  self.assertAllEqual(splits_out, [0, 2, 2, 5, 6, 7])
def testFromRowLimits(self):
  # Build a partition from per-row end offsets.
  limits_in = constant_op.constant([2, 2, 5, 6, 7], dtypes.int64)
  partition = RowPartition.from_row_limits(limits_in, validate=False)
  self.assertEqual(partition.dtype, dtypes.int64)

  limits_out = partition.row_limits()
  splits_out = partition.row_splits
  num_rows = partition.nrows()

  self.assertAllEqual(num_rows, 5)
  self.assertAllEqual(limits_out, limits_in)
  # The expected splits are a leading zero followed by the limits.
  self.assertAllEqual(splits_out, [0, 2, 2, 5, 6, 7])
def testFromRowLengths(self):
  # Build a partition from the length of each row.
  lengths_in = constant_op.constant([2, 0, 3, 1, 1], dtypes.int64)
  partition = RowPartition.from_row_lengths(lengths_in, validate=False)
  self.assertEqual(partition.dtype, dtypes.int64)

  lengths_out = partition.row_lengths()
  num_rows = partition.nrows()

  # Identity check: row_lengths() is expected to return the same tensor
  # object that was passed to the factory (i.e. the cached row lengths).
  self.assertIs(lengths_out, lengths_in)
  self.assertAllEqual(num_rows, 5)
  self.assertAllEqual(lengths_out, lengths_in)
def testFromUniformRowLength(self):
  # 16 values split into rows of length 2 should give 8 rows.
  partition = RowPartition.from_uniform_row_length(16, 2)
  self.assertAllEqual(partition.uniform_row_length(), 2)
  self.assertAllEqual(partition.nrows(), 8)
def testFromUniformRowLengthWithEmptyValues(self):
  # With zero values and a row length of zero, the row count cannot be
  # derived from nvals / uniform_row_length, so nrows is given explicitly.
  partition = RowPartition.from_uniform_row_length(
      nvals=0, uniform_row_length=0, nrows=10)

  nvals_value = self.evaluate(partition.nvals())
  self.assertEqual(nvals_value, 0)

  nrows_value = self.evaluate(partition.nrows())
  self.assertEqual(nrows_value, 10)
def testFromUniformRowLengthWithPlaceholders1(self):
  # nvals is an int64 tensor wrapped in placeholder_with_default with shape
  # None; the row length is a plain Python int.
  six = constant_op.constant(6, dtype=dtypes.int64)
  nvals_placeholder = array_ops.placeholder_with_default(six, None)
  partition = RowPartition.from_uniform_row_length(nvals_placeholder, 3)

  evaluated_nvals = self.evaluate(partition.nvals())
  self.assertEqual(evaluated_nvals, 6)
def testFromUniformRowLengthWithPlaceholders2(self):
  # Both nvals and the uniform row length are wrapped in
  # placeholder_with_default with shape None.
  nvals_placeholder = array_ops.placeholder_with_default(6, None)
  rowlen_placeholder = array_ops.placeholder_with_default(3, None)
  partition = RowPartition.from_uniform_row_length(
      nvals_placeholder, rowlen_placeholder)

  evaluated_nvals = self.evaluate(partition.nvals())
  self.assertEqual(evaluated_nvals, 6)
def testFromValueRowIdsWithBadNRows(self):
  # Shared inputs: seven values assigned to rows 0..4, and a valid nrows.
  rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
  five_rows = constant_op.constant(5, dtypes.int64)

  # Negative nrows. value_rowids is wrapped in a placeholder with shape None
  # for this case.
  negative_nrows_msg = r'Expected nrows >= 0; got -2'
  with self.assertRaisesRegexp(ValueError, negative_nrows_msg):
    RowPartition.from_value_rowids(
        value_rowids=array_ops.placeholder_with_default(rowids, None),
        nrows=-2)

  # nrows smaller than the last row id + 1 (nrows=2).
  nrows_two_msg = (r'Expected nrows >= value_rowids\[-1\] \+ 1; got nrows=2, '
                   r'value_rowids\[-1\]=4')
  with self.assertRaisesRegexp(ValueError, nrows_two_msg):
    RowPartition.from_value_rowids(value_rowids=rowids, nrows=2)

  # nrows equal to the last row id, i.e. still one too small (nrows=4).
  nrows_four_msg = (r'Expected nrows >= value_rowids\[-1\] \+ 1; got nrows=4, '
                    r'value_rowids\[-1\]=4')
  with self.assertRaisesRegexp(ValueError, nrows_four_msg):
    RowPartition.from_value_rowids(value_rowids=rowids, nrows=4)

  # value_rowids with an extra dimension (shape (7, 1)).
  rowids_rank_msg = r'Shape \(7, 1\) must have rank 1'
  with self.assertRaisesRegexp(ValueError, rowids_rank_msg):
    RowPartition.from_value_rowids(
        value_rowids=array_ops.expand_dims(rowids, 1), nrows=five_rows)

  # nrows with an extra dimension (shape (1,)).
  nrows_rank_msg = r'Shape \(1,\) must have rank 0'
  with self.assertRaisesRegexp(ValueError, nrows_rank_msg):
    RowPartition.from_value_rowids(
        value_rowids=rowids, nrows=array_ops.expand_dims(five_rows, 0))
#=============================================================================
# RowPartition.__str__ and RowPartition.__repr__
#=============================================================================
def testRowPartitionStr(self):
  # repr() and str() of a partition are expected to be identical. The text
  # embeds the row_splits tensor, whose own rendering differs between eager
  # mode (values shown) and graph mode (tensor name shown).
  partition = RowPartition.from_row_splits([0, 2, 5, 6, 6, 7], validate=False)

  if not context.executing_eagerly():
    graph_template = ('tf.RowPartition(row_splits='
                      'Tensor("RowPartitionFromRowSplits/row_splits:0", '
                      'shape=(6,), dtype={}))')
    expected_text = graph_template.format('int64')
  else:
    expected_text = ('tf.RowPartition(row_splits=tf.Tensor([0 2 5 6 6 7], '
                     'shape=(6,), dtype=int64))')

  self.assertEqual(repr(partition), expected_text)
  self.assertEqual(str(partition), expected_text)
@parameterized.parameters([
    # from_value_rowids
    {
        'descr': 'bad rank for value_rowids',
        'factory': RowPartition.from_value_rowids,
        'value_rowids': [[1, 2], [3, 4]],
        'nrows': 10
    },
    {
        'descr': 'bad rank for nrows',
        'factory': RowPartition.from_value_rowids,
        'value_rowids': [1, 2, 3, 4],
        'nrows': [10]
    },
    {
        'descr': 'negative value_rowid',
        'factory': RowPartition.from_value_rowids,
        'value_rowids': [-5, 2, 3, 4],
        'nrows': 10
    },
    {
        'descr': 'non-monotonic-increasing value_rowid',
        'factory': RowPartition.from_value_rowids,
        'value_rowids': [4, 3, 2, 1],
        'nrows': 10
    },
    {
        'descr': 'value_rowid > nrows',
        'factory': RowPartition.from_value_rowids,
        'value_rowids': [1, 2, 3, 4],
        'nrows': 2
    },
    # from_row_splits
    {
        'descr': 'bad rank for row_splits',
        'factory': RowPartition.from_row_splits,
        'row_splits': [[1, 2], [3, 4]]
    },
    {
        'descr': 'row_splits[0] != 0',
        'factory': RowPartition.from_row_splits,
        'row_splits': [2, 3, 4]
    },
    {
        'descr': 'non-monotonic-increasing row_splits',
        'factory': RowPartition.from_row_splits,
        'row_splits': [0, 3, 2, 4]
    },
    # from_row_lengths
    {
        'descr': 'bad rank for row_lengths',
        'factory': RowPartition.from_row_lengths,
        'row_lengths': [[1, 2], [1, 0]]
    },
    {
        'descr': 'negative row_lengths',
        'factory': RowPartition.from_row_lengths,
        'row_lengths': [3, -1, 2]
    },
    # from_row_starts
    {
        'descr': 'bad rank for row_starts',
        'factory': RowPartition.from_row_starts,
        'nvals': 2,
        'row_starts': [[1, 2], [3, 4]]
    },
    {
        'descr': 'row_starts[0] != 0',
        'factory': RowPartition.from_row_starts,
        'nvals': 5,
        'row_starts': [2, 3, 4]
    },
    {
        'descr': 'non-monotonic-increasing row_starts',
        'factory': RowPartition.from_row_starts,
        'nvals': 4,
        'row_starts': [0, 3, 2, 4]
    },
    {
        # The last start (5) lies beyond nvals (4).
        'descr': 'row_starts[-1] > nvals',
        'factory': RowPartition.from_row_starts,
        'nvals': 4,
        'row_starts': [0, 2, 3, 5]
    },
    # from_row_limits
    {
        'descr': 'bad rank for row_limits',
        'factory': RowPartition.from_row_limits,
        'row_limits': [[1, 2], [3, 4]]
    },
    {
        'descr': 'row_limits[0] < 0',
        'factory': RowPartition.from_row_limits,
        'row_limits': [-1, 3, 4]
    },
    {
        'descr': 'non-monotonic-increasing row_limits',
        'factory': RowPartition.from_row_limits,
        'row_limits': [0, 3, 2, 4]
    },
    # from_uniform_row_length
    {
        'descr': 'rowlen * nrows != nvals (1)',
        'factory': RowPartition.from_uniform_row_length,
        'nvals': 5,
        'uniform_row_length': 3
    },
    {
        'descr': 'rowlen * nrows != nvals (2)',
        'factory': RowPartition.from_uniform_row_length,
        'nvals': 5,
        'uniform_row_length': 6
    },
    {
        'descr': 'rowlen * nrows != nvals (3)',
        'factory': RowPartition.from_uniform_row_length,
        'nvals': 6,
        'uniform_row_length': 3,
        'nrows': 3
    },
    {
        'descr': 'rowlen must be a scalar',
        'factory': RowPartition.from_uniform_row_length,
        'nvals': 4,
        'uniform_row_length': [2]
    },
    {
        'descr': 'rowlen must be nonnegative',
        'factory': RowPartition.from_uniform_row_length,
        'nvals': 4,
        'uniform_row_length': -1
    },
])
def testFactoryValidation(self, descr, factory, **kwargs):
  # Checks that each RowPartition factory rejects the invalid arguments given
  # in the parameter table above.
  #
  # Args:
  #   descr: Human-readable label for the case. It only documents the
  #     parameter table and is not used by the checks.
  #   factory: The RowPartition factory method under test.
  #   **kwargs: The (invalid) keyword arguments to pass to `factory`.
  del descr  # Documentation only.

  # When input tensors have shape information, some of these errors will be
  # detected statically (ValueError); others only surface when the result is
  # evaluated (InvalidArgumentError). Either is accepted here.
  with self.assertRaises((errors.InvalidArgumentError, ValueError)):
    partition = factory(**kwargs)
    self.evaluate(partition.row_splits)

  # Remove shape information (by wrapping tensors in placeholders), and check
  # that we detect the errors when the graph is run.
  if not context.executing_eagerly():

    def wrap_arg(v):
      # Wrap `v` as an int64 constant behind a placeholder of unknown shape.
      return array_ops.placeholder_with_default(
          constant_op.constant(v, dtype=dtypes.int64),
          tensor_shape.TensorShape(None))

    kwargs = {k: wrap_arg(v) for (k, v) in kwargs.items()}
    # With no static shape available, only the runtime error is accepted.
    with self.assertRaises(errors.InvalidArgumentError):
      partition = factory(**kwargs)
      self.evaluate(partition.row_splits)
# When this file is executed directly (rather than imported), hand control to
# googletest.main().
if __name__ == '__main__':
  googletest.main()

View File

@ -19,6 +19,10 @@ tf_class {
name: "ragged_rank"
mtype: "<type \'property\'>"
}
member {
name: "row_partition"
mtype: "<type \'property\'>"
}
member {
name: "row_splits"
mtype: "<type \'property\'>"
@ -37,7 +41,7 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'values\', \'row_splits\', \'cached_row_lengths\', \'cached_value_rowids\', \'cached_nrows\', \'internal\', \'uniform_row_length\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], "
argspec: "args=[\'self\', \'values\', \'row_partition\', \'internal\'], varargs=None, keywords=None, defaults=[\'False\'], "
}
member_method {
name: "bounding_shape"

View File

@ -19,6 +19,10 @@ tf_class {
name: "ragged_rank"
mtype: "<type \'property\'>"
}
member {
name: "row_partition"
mtype: "<type \'property\'>"
}
member {
name: "row_splits"
mtype: "<type \'property\'>"
@ -37,7 +41,7 @@ tf_class {
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'values\', \'row_splits\', \'cached_row_lengths\', \'cached_value_rowids\', \'cached_nrows\', \'internal\', \'uniform_row_length\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], "
argspec: "args=[\'self\', \'values\', \'row_partition\', \'internal\'], varargs=None, keywords=None, defaults=[\'False\'], "
}
member_method {
name: "bounding_shape"