Pull partition (row_splits, value_rowid, et cetera) of a RaggedTensor into a separate object. This allows one RowPartition to be shared by multiple RaggedTensor objects.
Copied over tests from RaggedTensorTest to RowPartitionTest. PiperOrigin-RevId: 298916432 Change-Id: Ibf05c1b5969c0f4d3d21b301b5ee99e7ca1350e1
This commit is contained in:
parent
0b0432ae2d
commit
6c26c995db
@ -269,6 +269,27 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "row_partition",
|
||||
srcs = ["row_partition.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":segment_id_ops",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:check_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python:tensor_spec",
|
||||
"//tensorflow/python:tensor_util",
|
||||
"//tensorflow/python:util",
|
||||
"//third_party/py/numpy",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "ragged_tensor",
|
||||
srcs = ["ragged_tensor.py"],
|
||||
@ -277,7 +298,7 @@ py_library(
|
||||
":ragged_config",
|
||||
":ragged_tensor_value",
|
||||
":ragged_util",
|
||||
":segment_id_ops",
|
||||
":row_partition",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:check_ops",
|
||||
"//tensorflow/python:composite_tensor",
|
||||
@ -292,6 +313,7 @@ py_library(
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python:tensor_spec",
|
||||
"//tensorflow/python:tensor_util",
|
||||
"//tensorflow/python:tf2",
|
||||
"//tensorflow/python:type_spec",
|
||||
"//tensorflow/python:util",
|
||||
"//third_party/py/numpy",
|
||||
@ -449,12 +471,44 @@ py_test(
|
||||
":ragged_tensor_value",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python:tensor_spec",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//third_party/py/numpy",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "row_partition_test",
|
||||
size = "medium",
|
||||
timeout = "long",
|
||||
srcs = ["row_partition_test.py"],
|
||||
python_version = "PY3",
|
||||
shard_count = 4,
|
||||
srcs_version = "PY2AND3",
|
||||
tags = [
|
||||
"no_windows",
|
||||
],
|
||||
deps = [
|
||||
":ragged", # fixdeps: keep
|
||||
":row_partition",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:constant_op",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:framework_test_lib",
|
||||
"//tensorflow/python:platform_test",
|
||||
"//tensorflow/python:tensor_shape",
|
||||
"//tensorflow/python:tensor_util",
|
||||
"//tensorflow/python/data/ops:dataset_ops",
|
||||
"//tensorflow/python/eager:context",
|
||||
"//third_party/py/numpy",
|
||||
"@absl_py//absl/testing:parameterized",
|
||||
|
@ -342,5 +342,5 @@ def placeholder(dtype, ragged_rank, value_shape=None, name=None):
|
||||
for i in reversed(range(ragged_rank)):
|
||||
row_splits = array_ops.placeholder(dtypes.int64, [None],
|
||||
"row_splits_%d" % i)
|
||||
result = ragged_tensor.RaggedTensor(result, row_splits, internal=True)
|
||||
result = ragged_tensor.RaggedTensor.from_row_splits(result, row_splits)
|
||||
return result
|
||||
|
@ -37,22 +37,27 @@ class RaggedPlaceholderOpTest(test_util.TensorFlowTestCase,
|
||||
(dtypes.int32, 1, [], 'ph',
|
||||
'tf.RaggedTensor('
|
||||
'values=Tensor("ph/flat_values:0", shape=(None,), dtype=int32), '
|
||||
'row_splits=Tensor("ph/row_splits_0:0", shape=(None,), dtype=int64))'),
|
||||
'row_splits=Tensor("ph/RaggedFromRowSplits/control_dependency:0", '
|
||||
'shape=(None,), dtype=int64))'),
|
||||
(dtypes.string, 1, [5], 'ph',
|
||||
'tf.RaggedTensor('
|
||||
'values=Tensor("ph/flat_values:0", shape=(None, 5), dtype=string), '
|
||||
'row_splits=Tensor("ph/row_splits_0:0", shape=(None,), dtype=int64))'),
|
||||
'row_splits=Tensor("ph/RaggedFromRowSplits/control_dependency:0", '
|
||||
'shape=(None,), dtype=int64))'),
|
||||
(dtypes.float32, 2, [], 'ph',
|
||||
'tf.RaggedTensor(values=tf.RaggedTensor('
|
||||
'values=Tensor("ph/flat_values:0", shape=(None,), dtype=float32), '
|
||||
'row_splits=Tensor("ph/row_splits_1:0", shape=(None,), dtype=int64)), '
|
||||
'row_splits=Tensor("ph/row_splits_0:0", shape=(None,), dtype=int64))'),
|
||||
'row_splits=Tensor("ph/RaggedFromRowSplits/control_dependency:0", '
|
||||
'shape=(None,), dtype=int64)), '
|
||||
'row_splits=Tensor("ph/RaggedFromRowSplits_1/control_dependency:0", '
|
||||
'shape=(None,), dtype=int64))'),
|
||||
(dtypes.int32, 2, [3, 5], 'ph',
|
||||
'tf.RaggedTensor(values=tf.RaggedTensor('
|
||||
'values=Tensor("ph/flat_values:0", shape=(None, 3, 5), dtype=int32), '
|
||||
'row_splits=Tensor("ph/row_splits_1:0", shape=(None,), dtype=int64)), '
|
||||
'row_splits=Tensor("ph/row_splits_0:0", shape=(None,), dtype=int64))'),
|
||||
|
||||
'row_splits=Tensor("ph/RaggedFromRowSplits/control_dependency:0", '
|
||||
'shape=(None,), dtype=int64)), '
|
||||
'row_splits=Tensor("ph/RaggedFromRowSplits_1/control_dependency:0", '
|
||||
'shape=(None,), dtype=int64))'),
|
||||
])
|
||||
def testRaggedPlaceholder(self, dtype, ragged_rank, value_shape, name,
|
||||
expected):
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -40,6 +40,8 @@ from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.ops.ragged import ragged_tensor_value
|
||||
from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensor
|
||||
from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensorSpec
|
||||
from tensorflow.python.ops.ragged.row_partition import RowPartition
|
||||
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
@ -128,8 +130,7 @@ def int32array(values):
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RaggedTensorTest(test_util.TensorFlowTestCase,
|
||||
parameterized.TestCase):
|
||||
class RaggedTensorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
longMessage = True # Property in unittest.Testcase. pylint: disable=invalid-name
|
||||
|
||||
#=============================================================================
|
||||
@ -161,25 +162,22 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
|
||||
outer_rt = RaggedTensor.from_row_splits(
|
||||
values=inner_rt, row_splits=[0, 3, 3, 5])
|
||||
self.assertEqual(outer_rt.ragged_rank, 2)
|
||||
self.assertAllEqual(
|
||||
outer_rt,
|
||||
[[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]])
|
||||
self.assertAllEqual(outer_rt,
|
||||
[[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]])
|
||||
del inner_rt, outer_rt
|
||||
|
||||
# From section: "Multiple Ragged Dimensions"
|
||||
rt = RaggedTensor.from_nested_row_splits(
|
||||
flat_values=[3, 1, 4, 1, 5, 9, 2, 6],
|
||||
nested_row_splits=([0, 3, 3, 5], [0, 4, 4, 7, 8, 8]))
|
||||
self.assertAllEqual(
|
||||
rt, [[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]])
|
||||
self.assertAllEqual(rt, [[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]])
|
||||
del rt
|
||||
|
||||
# From section: "Uniform Inner Dimensions"
|
||||
rt = RaggedTensor.from_row_splits(
|
||||
values=array_ops.ones([5, 3]), row_splits=[0, 2, 5])
|
||||
self.assertAllEqual(
|
||||
rt,
|
||||
[[[1, 1, 1], [1, 1, 1]], [[1, 1, 1], [1, 1, 1], [1, 1, 1]]])
|
||||
rt, [[[1, 1, 1], [1, 1, 1]], [[1, 1, 1], [1, 1, 1], [1, 1, 1]]])
|
||||
self.assertEqual(rt.shape.as_list(), [2, None, 3])
|
||||
del rt
|
||||
|
||||
@ -223,42 +221,29 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
|
||||
def testRaggedTensorConstruction(self):
|
||||
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
|
||||
row_splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
|
||||
rt = RaggedTensor(values=values, row_splits=row_splits, internal=True)
|
||||
rp = RowPartition(row_splits=row_splits, internal=True)
|
||||
rt = RaggedTensor(values=values, row_partition=rp, internal=True)
|
||||
|
||||
self.assertAllEqual(
|
||||
rt,
|
||||
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
|
||||
self.assertAllEqual(rt,
|
||||
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
|
||||
|
||||
def testRaggedTensorConstructionErrors(self):
|
||||
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
|
||||
row_splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
|
||||
rp = RowPartition(row_splits=row_splits, internal=True)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
'RaggedTensor constructor is private'):
|
||||
RaggedTensor(values=values, row_splits=row_splits)
|
||||
RaggedTensor(values=values, row_partition=rp)
|
||||
|
||||
with self.assertRaisesRegexp(TypeError,
|
||||
'values must be a Tensor or RaggedTensor'):
|
||||
RaggedTensor(values=range(7), row_splits=row_splits, internal=True)
|
||||
RaggedTensor(values=range(7), row_partition=rp, internal=True)
|
||||
|
||||
with self.assertRaisesRegexp(TypeError,
|
||||
'Row-partitioning argument must be a Tensor'):
|
||||
RaggedTensor(values=values, row_splits=[0, 2, 2, 5, 6, 7], internal=True)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
r'Shape \(6, 1\) must have rank 1'):
|
||||
RaggedTensor(
|
||||
values=values,
|
||||
row_splits=array_ops.expand_dims(row_splits, 1),
|
||||
internal=True)
|
||||
|
||||
with self.assertRaisesRegexp(TypeError,
|
||||
'Cached value must be a Tensor or None.'):
|
||||
RaggedTensor(
|
||||
values=values,
|
||||
row_splits=row_splits,
|
||||
cached_row_lengths=[2, 3, 4],
|
||||
internal=True)
|
||||
'row_partition must be a RowPartition'):
|
||||
RaggedTensor(values=values, row_partition=[0, 2, 2, 5, 6, 7],
|
||||
internal=True)
|
||||
|
||||
#=============================================================================
|
||||
# RaggedTensor Factory Ops
|
||||
@ -282,9 +267,8 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
|
||||
self.assertIs(rt_value_rowids, value_rowids) # cached_value_rowids
|
||||
self.assertAllEqual(rt_value_rowids, value_rowids)
|
||||
self.assertAllEqual(rt_nrows, 5)
|
||||
self.assertAllEqual(
|
||||
rt,
|
||||
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
|
||||
self.assertAllEqual(rt,
|
||||
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
|
||||
|
||||
def testFromValueRowIdsWithDerivedNRowsDynamic(self):
|
||||
# nrows is not known at graph creation time.
|
||||
@ -308,17 +292,16 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
|
||||
self.assertIs(rt_value_rowids, value_rowids) # cached_value_rowids
|
||||
self.assertAllEqual(rt_value_rowids, value_rowids)
|
||||
self.assertAllEqual(rt_nrows, 5)
|
||||
self.assertAllEqual(
|
||||
rt,
|
||||
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
|
||||
self.assertAllEqual(rt,
|
||||
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
|
||||
|
||||
def testFromValueRowIdsWithExplicitNRows(self):
|
||||
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
|
||||
value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
|
||||
nrows = constant_op.constant(7, dtypes.int64)
|
||||
|
||||
rt = RaggedTensor.from_value_rowids(values, value_rowids, nrows,
|
||||
validate=False)
|
||||
rt = RaggedTensor.from_value_rowids(
|
||||
values, value_rowids, nrows, validate=False)
|
||||
self.assertEqual(rt.dtype, dtypes.string)
|
||||
self.assertEqual(rt.shape.as_list(), [7, None])
|
||||
self.assertEqual(rt.ragged_rank, 1)
|
||||
@ -331,16 +314,15 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
|
||||
self.assertIs(rt_value_rowids, value_rowids) # cached_value_rowids
|
||||
self.assertIs(rt_nrows, nrows) # cached_nrows
|
||||
self.assertAllEqual(
|
||||
rt,
|
||||
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g'], [], []])
|
||||
rt, [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g'], [], []])
|
||||
|
||||
def testFromValueRowIdsWithExplicitNRowsEqualToDefault(self):
|
||||
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
|
||||
value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
|
||||
nrows = constant_op.constant(5, dtypes.int64)
|
||||
|
||||
rt = RaggedTensor.from_value_rowids(values, value_rowids, nrows,
|
||||
validate=False)
|
||||
rt = RaggedTensor.from_value_rowids(
|
||||
values, value_rowids, nrows, validate=False)
|
||||
self.assertEqual(rt.dtype, dtypes.string)
|
||||
self.assertEqual(rt.shape.as_list(), [5, None])
|
||||
self.assertEqual(rt.ragged_rank, 1)
|
||||
@ -354,9 +336,8 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
|
||||
self.assertIs(rt_nrows, nrows) # cached_nrows
|
||||
self.assertAllEqual(rt_value_rowids, value_rowids)
|
||||
self.assertAllEqual(rt_nrows, nrows)
|
||||
self.assertAllEqual(
|
||||
rt,
|
||||
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
|
||||
self.assertAllEqual(rt,
|
||||
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
|
||||
|
||||
def testFromValueRowIdsWithEmptyValues(self):
|
||||
rt = RaggedTensor.from_value_rowids([], [])
|
||||
@ -385,9 +366,8 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
|
||||
self.assertIs(rt_values, values)
|
||||
self.assertIs(rt_row_splits, row_splits)
|
||||
self.assertAllEqual(rt_nrows, 5)
|
||||
self.assertAllEqual(
|
||||
rt,
|
||||
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
|
||||
self.assertAllEqual(rt,
|
||||
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
|
||||
|
||||
def testFromRowSplitsWithDifferentSplitTypes(self):
|
||||
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
|
||||
@ -428,9 +408,8 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
|
||||
self.assertIs(rt_values, values)
|
||||
self.assertAllEqual(rt_nrows, 5)
|
||||
self.assertAllEqual(rt_row_starts, row_starts)
|
||||
self.assertAllEqual(
|
||||
rt,
|
||||
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
|
||||
self.assertAllEqual(rt,
|
||||
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
|
||||
|
||||
def testFromRowLimits(self):
|
||||
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
|
||||
@ -448,9 +427,8 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
|
||||
self.assertIs(rt_values, values)
|
||||
self.assertAllEqual(rt_nrows, 5)
|
||||
self.assertAllEqual(rt_row_limits, row_limits)
|
||||
self.assertAllEqual(
|
||||
rt,
|
||||
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
|
||||
self.assertAllEqual(rt,
|
||||
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
|
||||
|
||||
def testFromRowLengths(self):
|
||||
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
|
||||
@ -469,21 +447,27 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
|
||||
self.assertIs(rt_row_lengths, row_lengths) # cached_nrows
|
||||
self.assertAllEqual(rt_nrows, 5)
|
||||
self.assertAllEqual(rt_row_lengths, row_lengths)
|
||||
self.assertAllEqual(
|
||||
rt,
|
||||
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
|
||||
self.assertAllEqual(rt,
|
||||
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
|
||||
|
||||
def testFromRowLengthsInt32(self):
|
||||
rt = RaggedTensor.from_row_lengths([1, 2, 3, 4],
|
||||
constant_op.constant([1, 0, 3],
|
||||
dtype=dtypes.int32))
|
||||
rt2 = RaggedTensor.from_row_lengths(rt, [2, 1, 0])
|
||||
self.assertAllEqual([2, 1, 0], rt2.row_lengths())
|
||||
|
||||
def testFromUniformRowLength(self):
|
||||
values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]
|
||||
|
||||
a1 = RaggedTensor.from_uniform_row_length(values, 2)
|
||||
a2 = RaggedTensor.from_uniform_row_length(values, 2, 8)
|
||||
self.assertAllEqual(a1, [[1, 2], [3, 4], [5, 6], [7, 8],
|
||||
[9, 10], [11, 12], [13, 14], [15, 16]])
|
||||
self.assertAllEqual(
|
||||
a1,
|
||||
[[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]])
|
||||
self.assertAllEqual(a1, a2)
|
||||
self.assertEqual(a1.shape.as_list(), [8, 2])
|
||||
self.assertEqual(a2.shape.as_list(), [8, 2])
|
||||
self.assertAllEqual(a1.uniform_row_length, 2)
|
||||
|
||||
b1 = RaggedTensor.from_uniform_row_length(a1, 2)
|
||||
b2 = RaggedTensor.from_uniform_row_length(a1, 2, 4)
|
||||
@ -492,7 +476,6 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
|
||||
self.assertAllEqual(b1, b2)
|
||||
self.assertEqual(b1.shape.as_list(), [4, 2, 2])
|
||||
self.assertEqual(b2.shape.as_list(), [4, 2, 2])
|
||||
self.assertAllEqual(b1.uniform_row_length, 2)
|
||||
|
||||
c1 = RaggedTensor.from_uniform_row_length(b1, 2)
|
||||
c2 = RaggedTensor.from_uniform_row_length(b1, 2, 2)
|
||||
@ -501,13 +484,11 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
|
||||
self.assertAllEqual(c1, c2)
|
||||
self.assertEqual(c1.shape.as_list(), [2, 2, 2, 2])
|
||||
self.assertEqual(c2.shape.as_list(), [2, 2, 2, 2])
|
||||
self.assertAllEqual(c1.uniform_row_length, 2)
|
||||
|
||||
def testFromUniformRowLengthWithEmptyValues(self):
|
||||
empty_values = []
|
||||
a = RaggedTensor.from_uniform_row_length(empty_values, 0, nrows=10)
|
||||
self.assertEqual(a.shape.as_list(), [10, 0])
|
||||
self.assertAllEqual(a.uniform_row_length, 0)
|
||||
|
||||
b = RaggedTensor.from_uniform_row_length(a, 2)
|
||||
self.assertEqual(b.shape.as_list(), [5, 2, 0])
|
||||
@ -570,8 +551,7 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
|
||||
self.assertAllEqual(rt_value_rowids, nested_value_rowids[0])
|
||||
self.assertAllEqual(rt_values_value_rowids, nested_value_rowids[1])
|
||||
self.assertAllEqual(
|
||||
rt,
|
||||
[[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]])
|
||||
rt, [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]])
|
||||
|
||||
def testFromNestedValueRowIdsWithExplicitNRows(self):
|
||||
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
|
||||
@ -602,9 +582,8 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
|
||||
self.assertAllEqual(rt_values_value_rowids, nested_value_rowids[1])
|
||||
self.assertAllEqual(rt_nrows, nrows[0])
|
||||
self.assertAllEqual(rt_values_nrows, nrows[1])
|
||||
self.assertAllEqual(
|
||||
rt, [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [],
|
||||
[[b'f'], [b'g'], []], [], []])
|
||||
self.assertAllEqual(rt, [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [],
|
||||
[[b'f'], [b'g'], []], [], []])
|
||||
|
||||
def testFromNestedValueRowIdsWithExplicitNRowsMismatch(self):
|
||||
values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
|
||||
@ -635,8 +614,8 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
|
||||
constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
|
||||
]
|
||||
|
||||
rt = RaggedTensor.from_nested_row_splits(flat_values, nested_row_splits,
|
||||
validate=False)
|
||||
rt = RaggedTensor.from_nested_row_splits(
|
||||
flat_values, nested_row_splits, validate=False)
|
||||
self.assertEqual(rt.dtype, dtypes.string)
|
||||
self.assertEqual(rt.shape.as_list(), [4, None, None])
|
||||
self.assertEqual(rt.ragged_rank, 2)
|
||||
@ -650,8 +629,34 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
|
||||
self.assertIs(rt_row_splits, nested_row_splits[0])
|
||||
self.assertIs(rt_values_row_splits, nested_row_splits[1])
|
||||
self.assertAllEqual(
|
||||
rt,
|
||||
[[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]])
|
||||
rt, [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]])
|
||||
|
||||
def testWithRowSplits(self):
|
||||
flat_values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
|
||||
nested_row_splits = [
|
||||
constant_op.constant([0, 2, 3, 3, 5], dtypes.int64),
|
||||
constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
|
||||
]
|
||||
|
||||
rt = RaggedTensor.from_nested_row_splits(
|
||||
flat_values, nested_row_splits, validate=False)
|
||||
|
||||
rt = rt.with_row_splits_dtype(dtypes.int32)
|
||||
|
||||
self.assertEqual(rt.dtype, dtypes.string)
|
||||
self.assertEqual(rt.shape.as_list(), [4, None, None])
|
||||
self.assertEqual(rt.ragged_rank, 2)
|
||||
|
||||
rt_values = rt.values
|
||||
rt_row_splits = rt.row_splits
|
||||
rt_values_values = rt_values.values
|
||||
rt_values_row_splits = rt_values.row_splits
|
||||
|
||||
self.assertAllEqual(rt_values_values, flat_values)
|
||||
self.assertAllEqual(rt_row_splits, nested_row_splits[0])
|
||||
self.assertAllEqual(rt_values_row_splits, nested_row_splits[1])
|
||||
self.assertAllEqual(
|
||||
rt, [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]])
|
||||
|
||||
def testFromNestedRowSplitsWithNonListInput(self):
|
||||
with self.assertRaisesRegexp(TypeError,
|
||||
@ -747,16 +752,13 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
|
||||
rt2 = RaggedTensor.from_value_rowids(values, value_rowids)
|
||||
|
||||
for rt in [rt1, rt2]:
|
||||
self.assertAllEqual(
|
||||
rt,
|
||||
[[[0, 1], [2, 3]], [], [[4, 5], [6, 7], [8, 9]], [[10, 11]],
|
||||
[[12, 13]]])
|
||||
self.assertAllEqual(rt, [[[0, 1], [2, 3]], [], [[4, 5], [6, 7], [8, 9]],
|
||||
[[10, 11]], [[12, 13]]])
|
||||
self.assertAllEqual(
|
||||
rt.values,
|
||||
[[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13]])
|
||||
self.assertEqual(rt.values.shape.dims[0].value, 7)
|
||||
self.assertAllEqual(
|
||||
rt.value_rowids(), [0, 0, 2, 2, 2, 3, 4])
|
||||
self.assertAllEqual(rt.value_rowids(), [0, 0, 2, 2, 2, 3, 4])
|
||||
self.assertAllEqual(rt.nrows(), 5)
|
||||
self.assertAllEqual(rt.row_splits, [0, 2, 2, 5, 6, 7])
|
||||
self.assertAllEqual(rt.row_starts(), [0, 2, 2, 5, 6])
|
||||
@ -786,11 +788,9 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
|
||||
|
||||
for rt in [rt1, rt2]:
|
||||
self.assertAllEqual(
|
||||
rt,
|
||||
[[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]])
|
||||
rt, [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]])
|
||||
self.assertAllEqual(
|
||||
rt.values,
|
||||
[[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
|
||||
rt.values, [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
|
||||
self.assertEqual(rt.values.shape.dims[0].value, 5)
|
||||
self.assertAllEqual(rt.value_rowids(), [0, 0, 1, 3, 3])
|
||||
self.assertAllEqual(rt.nrows(), 4)
|
||||
@ -798,9 +798,8 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
|
||||
self.assertAllEqual(rt.row_starts(), [0, 2, 3, 3])
|
||||
self.assertAllEqual(rt.row_limits(), [2, 3, 3, 5])
|
||||
self.assertAllEqual(rt.row_lengths(), [2, 1, 0, 2])
|
||||
self.assertAllEqual(
|
||||
rt.flat_values,
|
||||
[b'a', b'b', b'c', b'd', b'e', b'f', b'g'])
|
||||
self.assertAllEqual(rt.flat_values,
|
||||
[b'a', b'b', b'c', b'd', b'e', b'f', b'g'])
|
||||
self.assertLen(rt.nested_row_splits, 2)
|
||||
self.assertAllEqual(rt.nested_row_splits[0], [0, 2, 3, 3, 5])
|
||||
self.assertAllEqual(rt.nested_row_splits[1], [0, 2, 2, 5, 6, 7])
|
||||
@ -1024,8 +1023,8 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
|
||||
'slice offsets must be integers or None'),
|
||||
|
||||
# Tests for other errors
|
||||
(SLICE_BUILDER[..., 0, 0, 0], IndexError,
|
||||
'Too many indices for RaggedTensor'),
|
||||
(SLICE_BUILDER[..., 0, 0,
|
||||
0], IndexError, 'Too many indices for RaggedTensor'),
|
||||
)
|
||||
def testRaggedTensorGetItemErrorsWithRaggedRank1(self, slice_spec, expected,
|
||||
message):
|
||||
@ -1106,9 +1105,8 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
|
||||
[[v[::-2] for v in row] for row in EXAMPLE_RAGGED_TENSOR_4D]),
|
||||
(SLICE_BUILDER[..., ::-1, :],
|
||||
[[v[::-1] for v in row] for row in EXAMPLE_RAGGED_TENSOR_4D]),
|
||||
(SLICE_BUILDER[..., ::-1],
|
||||
[[[v[::-1] for v in col] for col in row]
|
||||
for row in EXAMPLE_RAGGED_TENSOR_4D]),
|
||||
(SLICE_BUILDER[..., ::-1], [[[v[::-1] for v in col] for col in row]
|
||||
for row in EXAMPLE_RAGGED_TENSOR_4D]),
|
||||
)
|
||||
def testRaggedTensorGetItemWithRaggedRank2(self, slice_spec, expected):
|
||||
"""Test that rt.__getitem__(slice_spec) == expected."""
|
||||
@ -1212,11 +1210,9 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
|
||||
rt_newaxis4 = rt[:, :, :, :, array_ops.newaxis]
|
||||
|
||||
self.assertAllEqual(
|
||||
rt,
|
||||
[[[[b'a', b'b'], [b'c', b'd']], [], [[b'e', b'f']]], []])
|
||||
rt, [[[[b'a', b'b'], [b'c', b'd']], [], [[b'e', b'f']]], []])
|
||||
self.assertAllEqual(
|
||||
rt_newaxis0,
|
||||
[[[[[b'a', b'b'], [b'c', b'd']], [], [[b'e', b'f']]], []]])
|
||||
rt_newaxis0, [[[[[b'a', b'b'], [b'c', b'd']], [], [[b'e', b'f']]], []]])
|
||||
self.assertAllEqual(
|
||||
rt_newaxis1,
|
||||
[[[[[b'a', b'b'], [b'c', b'd']], [], [[b'e', b'f']]]], [[]]])
|
||||
@ -1330,9 +1326,10 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
|
||||
else:
|
||||
expected_repr = (
|
||||
'tf.RaggedTensor(values=Tensor("RaggedFromRowSplits/values:0", '
|
||||
'shape=(7,), dtype=string), row_splits='
|
||||
'Tensor("RaggedFromRowSplits/row_splits:0", '
|
||||
'shape=(6,), dtype={}))').format(splits_type)
|
||||
'shape=(7,), dtype=string), '
|
||||
'row_splits=Tensor('
|
||||
'"RaggedFromRowSplits/RowPartitionFromRowSplits/row_splits:0",'
|
||||
' shape=(6,), dtype={}))').format(splits_type)
|
||||
self.assertEqual(repr(rt), expected_repr)
|
||||
self.assertEqual(str(rt), expected_repr)
|
||||
|
||||
@ -1362,15 +1359,11 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
|
||||
rt2_times_10 = rt2.with_flat_values(rt2.flat_values * 10)
|
||||
rt1_expanded = rt1.with_values(array_ops.expand_dims(rt1.values, axis=1))
|
||||
|
||||
self.assertAllEqual(
|
||||
rt1_plus_10,
|
||||
[[11, 12], [13, 14, 15], [16], [], [17]])
|
||||
self.assertAllEqual(
|
||||
rt2_times_10,
|
||||
[[[10, 20], [30, 40, 50]], [[60]], [], [[], [70]]])
|
||||
self.assertAllEqual(
|
||||
rt1_expanded,
|
||||
[[[1], [2]], [[3], [4], [5]], [[6]], [], [[7]]])
|
||||
self.assertAllEqual(rt1_plus_10, [[11, 12], [13, 14, 15], [16], [], [17]])
|
||||
self.assertAllEqual(rt2_times_10,
|
||||
[[[10, 20], [30, 40, 50]], [[60]], [], [[], [70]]])
|
||||
self.assertAllEqual(rt1_expanded,
|
||||
[[[1], [2]], [[3], [4], [5]], [[6]], [], [[7]]])
|
||||
|
||||
#=============================================================================
|
||||
# Session.run
|
||||
@ -1465,6 +1458,99 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
|
||||
ragged_math_ops.reduce_sum(a)
|
||||
self.assertLen(a.consumers(), 1)
|
||||
|
||||
@parameterized.parameters([
|
||||
{
|
||||
'descr': 'from_value_rowids',
|
||||
'factory': RaggedTensor.from_value_rowids,
|
||||
'test': RaggedTensor.value_rowids,
|
||||
'values': {
|
||||
'values': [1, 2, 3, 4, 5, 6],
|
||||
'value_rowids': [0, 0, 1, 1, 2, 2],
|
||||
},
|
||||
'tensor_field': 'value_rowids',
|
||||
'value_rowids': [0, 1, 2],
|
||||
'nrows': 10
|
||||
},
|
||||
{
|
||||
'descr': 'from_row_splits',
|
||||
'factory': RaggedTensor.from_row_splits,
|
||||
# row_splits is a property, not a function.
|
||||
'test': (lambda rt: rt.row_splits),
|
||||
'values': {
|
||||
'values': [1, 2, 3, 4, 5, 6],
|
||||
'row_splits': [0, 2, 4, 6],
|
||||
},
|
||||
'tensor_field': 'row_splits',
|
||||
'row_splits': [0, 1, 2, 3]
|
||||
},
|
||||
{
|
||||
'descr': 'from_row_lengths',
|
||||
'factory': RaggedTensor.from_row_lengths,
|
||||
'test': RaggedTensor.row_lengths,
|
||||
'values': {
|
||||
'values': [1, 2, 3, 4, 5, 6],
|
||||
'row_lengths': [2, 2, 2],
|
||||
},
|
||||
'tensor_field': 'row_lengths',
|
||||
'row_lengths': [1, 1, 1],
|
||||
},
|
||||
# from_row_starts
|
||||
{
|
||||
'descr': 'from_row_starts',
|
||||
'factory': RaggedTensor.from_row_starts,
|
||||
'test': RaggedTensor.row_starts,
|
||||
'values': {
|
||||
'values': [1, 2, 3, 4, 5, 6],
|
||||
'row_starts': [0, 2, 4]
|
||||
},
|
||||
'tensor_field': 'row_starts',
|
||||
'row_starts': [0, 1, 2]
|
||||
},
|
||||
# from_row_limits
|
||||
{
|
||||
'descr': 'from_row_limits',
|
||||
'factory': RaggedTensor.from_row_limits,
|
||||
'test': RaggedTensor.row_limits,
|
||||
'values': {
|
||||
'values': [1, 2, 3, 4, 5, 6],
|
||||
'row_limits': [2, 4, 6]
|
||||
},
|
||||
'tensor_field': 'row_limits',
|
||||
'row_limits': [3]
|
||||
},
|
||||
# from_uniform_row_length
|
||||
{
|
||||
'descr': 'from_uniform_row_length',
|
||||
'factory': RaggedTensor.from_uniform_row_length,
|
||||
# One cannot extract uniform_row_length or nvals, so we return
|
||||
# nvals//nrows = uniform_row_length, where nvals = 3
|
||||
'test': (lambda rt: 3 // (rt.shape[0])),
|
||||
'values': {
|
||||
'values': [1, 2, 3, 4, 5, 6],
|
||||
'uniform_row_length': 2
|
||||
},
|
||||
'tensor_field': 'uniform_row_length',
|
||||
'uniform_row_length': 3
|
||||
},
|
||||
])
|
||||
def testFactoryTypePreference(self, descr, test, factory, values,
|
||||
tensor_field, **kwargs):
|
||||
# When input tensors have shape information, some of these errors will be
|
||||
# detected statically.
|
||||
def op_cast(k, v):
|
||||
if k == tensor_field:
|
||||
return constant_op.constant(v, dtype=dtypes.int32)
|
||||
else:
|
||||
return v
|
||||
|
||||
value_copy = {k: op_cast(k, v) for k, v in values.items()}
|
||||
rt = factory(**value_copy)
|
||||
|
||||
kw_copy = {k: v for k, v in kwargs.items()}
|
||||
kw_copy['values'] = rt
|
||||
rt2 = factory(**kw_copy)
|
||||
self.assertAllEqual(kwargs[tensor_field], test(rt2))
|
||||
|
||||
@parameterized.parameters([
|
||||
# from_value_rowids
|
||||
{
|
||||
@ -1557,7 +1643,7 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
|
||||
'row_lengths': [[1, 2], [1, 0]]
|
||||
},
|
||||
{
|
||||
'descr': 'negative row_lengths',
|
||||
'descr': 'negatve row_lengths',
|
||||
'factory': RaggedTensor.from_row_lengths,
|
||||
'values': [1, 2, 3, 4],
|
||||
'row_lengths': [3, -1, 2]
|
||||
@ -1678,18 +1764,21 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
|
||||
with self.assertRaises((errors.InvalidArgumentError, ValueError)):
|
||||
self.evaluate(factory(**kwargs))
|
||||
|
||||
# Remove shape information (by wraping tensors in placeholders), and check
|
||||
# Remove shape information (by wrapping tensors in placeholders), and check
|
||||
# that we detect the errors when the graph is run.
|
||||
if not context.executing_eagerly():
|
||||
|
||||
def wrap_arg(v):
|
||||
return array_ops.placeholder_with_default(
|
||||
constant_op.constant(v, dtype=dtypes.int64),
|
||||
tensor_shape.TensorShape(None))
|
||||
|
||||
kwargs = dict((k, wrap_arg(v)) for (k, v) in kwargs.items())
|
||||
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
self.evaluate(factory(**kwargs))
|
||||
|
||||
|
||||
#=============================================================================
|
||||
# RaggedTensor Variant conversion
|
||||
#=============================================================================
|
||||
@ -2059,8 +2148,10 @@ class RaggedTensorSpecTest(test_util.TensorFlowTestCase,
|
||||
self.assertAllEqual(rt1, [[1, 2], [3]])
|
||||
|
||||
spec2 = RaggedTensorSpec(ragged_rank=2, dtype=dtypes.int32)
|
||||
rt2 = spec2._from_components([np.array([1, 2, 3]), np.array([0, 2, 3]),
|
||||
np.array([0, 0, 2, 3])])
|
||||
rt2 = spec2._from_components(
|
||||
[np.array([1, 2, 3]),
|
||||
np.array([0, 2, 3]),
|
||||
np.array([0, 0, 2, 3])])
|
||||
self.assertIsInstance(rt2, ragged_tensor_value.RaggedTensorValue)
|
||||
self.assertAllEqual(rt2, [[[], [1, 2]], [[3]]])
|
||||
|
||||
|
843
tensorflow/python/ops/ragged/row_partition.py
Normal file
843
tensorflow/python/ops/ragged/row_partition.py
Normal file
@ -0,0 +1,843 @@
|
||||
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""An internal class for representing the partition in a ragged tensor."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops.ragged import segment_id_ops
|
||||
|
||||
# pylint: disable=protected-access
|
||||
_eval_using_default_session = ops._eval_using_default_session
|
||||
# pylint: enable=protected-access
|
||||
|
||||
#===============================================================================
|
||||
# RowPartition
|
||||
#===============================================================================
|
||||
|
||||
|
||||
class RowPartition(object):
|
||||
"""Represents the partition of a ragged tensor.
|
||||
|
||||
In particular, this provides a ragged representation to a flat list,
|
||||
or a deeper ragged representation of a ragged tensor. However, it does
|
||||
not store the values or the substructure: only the top-level representation
|
||||
is represented.
|
||||
|
||||
The canonical representation of a partition is row_splits, which indicates how
|
||||
the flat values are divided into rows. In particular, the values for row
|
||||
`rt[i]` are stored in the slice
|
||||
`rt.values[rt.row_splits[i]:rt.row_splits[i+1]]`.
|
||||
|
||||
### Alternative Row-Partitioning Schemes
|
||||
|
||||
In addition to `row_splits`, row partitions provide support for five other
|
||||
partitioning schemes:
|
||||
|
||||
* `row_lengths`: a vector with shape `[nrows]`, which specifies the length
|
||||
of each row.
|
||||
|
||||
* `value_rowids` and `nrows`: `value_rowids` is a vector with shape
|
||||
`[nvals]`, corresponding one-to-one with `values`, which specifies
|
||||
each value's row index. In particular, the row `rt[row]` consists of the
|
||||
values `rt.values[j]` where `value_rowids[j]==row`. `nrows` is an
|
||||
integer scalar that specifies the number of rows in the
|
||||
`RowPartition`. (`nrows` is used to indicate trailing empty rows.)
|
||||
|
||||
* `row_starts` (and nvals): a vector with shape `[nrows]`, which specifies
|
||||
the start offset of each row. Equivalent to `row_splits[:-1]`.
|
||||
|
||||
* `row_limits`: a vector with shape `[nrows]`, which specifies the stop
|
||||
offset of each row. Equivalent to `row_splits[1:]`.
|
||||
|
||||
* `uniform_row_length` (and nvals): A scalar tensor, specifying the length
|
||||
of every row. This row-partitioning scheme may only be used if all rows
|
||||
have the same length.
|
||||
|
||||
For examples, please see the documentation on RaggedTensor.
|
||||
"""
|
||||
|
||||
#=============================================================================
|
||||
# Constructor (private)
|
||||
#=============================================================================
|
||||
def __init__(self,
|
||||
row_splits,
|
||||
cached_row_lengths=None,
|
||||
cached_value_rowids=None,
|
||||
cached_nrows=None,
|
||||
internal=False,
|
||||
uniform_row_length=None):
|
||||
"""Creates a `RowPartition` with a specified partitioning for `values`.
|
||||
|
||||
This constructor is private -- please use one of the following ops to
|
||||
build `RowPartition`s:
|
||||
|
||||
* `RowPartition.from_row_lengths`
|
||||
* `RowPartition.from_value_rowids`
|
||||
* `RowPartition.from_row_splits`
|
||||
* `RowPartition.from_row_starts`
|
||||
* `RowPartition.from_row_limits`
|
||||
|
||||
Args:
|
||||
row_splits: A 1-D integer tensor with shape `[nrows+1]`.
|
||||
cached_row_lengths: A 1-D integer tensor with shape `[nrows]`
|
||||
cached_value_rowids: A 1-D integer tensor with shape `[nvals]`.
|
||||
cached_nrows: A 1-D integer scalar tensor.
|
||||
internal: True if the constructor is being called by one of the factory
|
||||
methods. If false, an exception will be raised.
|
||||
uniform_row_length: A scalar tensor.
|
||||
|
||||
Raises:
|
||||
TypeError: If a row partitioning tensor has an inappropriate dtype.
|
||||
TypeError: If exactly one row partitioning argument was not specified.
|
||||
ValueError: If a row partitioning tensor has an inappropriate shape.
|
||||
ValueError: If multiple partitioning arguments are specified.
|
||||
ValueError: If nrows is specified but value_rowids is not None.
|
||||
"""
|
||||
if not internal:
|
||||
raise ValueError("RaggedTensor constructor is private; please use one "
|
||||
"of the factory methods instead (e.g., "
|
||||
"RaggedTensor.from_row_lengths())")
|
||||
|
||||
# Validate the arguments.
|
||||
if not isinstance(row_splits, ops.Tensor):
|
||||
raise TypeError("Row-partitioning argument must be a Tensor, got %r" %
|
||||
row_splits)
|
||||
if row_splits.dtype not in (dtypes.int32, dtypes.int64):
|
||||
raise ValueError("Row-partitioning argument must be int32 or int64")
|
||||
|
||||
# Validate shapes & dtypes.
|
||||
row_splits.shape.assert_has_rank(1)
|
||||
row_splits.set_shape([None])
|
||||
self._row_splits = row_splits
|
||||
|
||||
# Store any cached tensors. These are used to avoid unnecessary
|
||||
# round-trip conversions when a RaggedTensor is constructed from
|
||||
# lengths or rowids, and we later want those lengths/rowids back.
|
||||
for tensor in [cached_row_lengths, cached_value_rowids, cached_nrows]:
|
||||
if tensor is not None:
|
||||
if not isinstance(tensor, ops.Tensor):
|
||||
raise TypeError("Cached value must be a Tensor or None.")
|
||||
elif tensor.dtype not in (dtypes.int32, dtypes.int64):
|
||||
raise TypeError("Cached value must be int32 or int64.")
|
||||
self._cached_row_lengths = cached_row_lengths
|
||||
self._cached_value_rowids = cached_value_rowids
|
||||
self._cached_nrows = cached_nrows
|
||||
|
||||
if uniform_row_length is not None:
|
||||
if not isinstance(uniform_row_length, ops.Tensor):
|
||||
raise TypeError("uniform_row_length must be a Tensor or None.")
|
||||
elif uniform_row_length.dtype not in (dtypes.int32, dtypes.int64):
|
||||
raise TypeError("uniform_row_length must be int32 or int64.")
|
||||
self._uniform_row_length = uniform_row_length
|
||||
|
||||
#=============================================================================
|
||||
# Factory Methods
|
||||
#=============================================================================
|
||||
|
||||
@classmethod
|
||||
def from_value_rowids(cls,
|
||||
value_rowids,
|
||||
nrows=None,
|
||||
name=None,
|
||||
validate=True,
|
||||
preferred_dtype=None):
|
||||
"""Creates a `RowPartition` with rows partitioned by `value_rowids`.
|
||||
|
||||
The implied `RaggedTensor` corresponds with the python list defined by:
|
||||
|
||||
```python
|
||||
result = [[values[i] for i in range(len(values)) if value_rowids[i] == row]
|
||||
for row in range(nrows)]
|
||||
```
|
||||
|
||||
Args:
|
||||
value_rowids: A 1-D integer tensor with shape `[nvals]`, which corresponds
|
||||
one-to-one with `values`, and specifies each value's row index. Must be
|
||||
nonnegative, and must be sorted in ascending order.
|
||||
nrows: An integer scalar specifying the number of rows. This should be
|
||||
specified if the `RaggedTensor` may containing empty training rows. Must
|
||||
be greater than `value_rowids[-1]` (or zero if `value_rowids` is empty).
|
||||
Defaults to `value_rowids[-1]` (or zero if `value_rowids` is empty).
|
||||
name: A name prefix for the RaggedTensor (optional).
|
||||
validate: If true, then use assertions to check that the arguments form a
|
||||
valid `RowPartition`.
|
||||
preferred_dtype: The dtype to encode value_rowids if it doesn't already
|
||||
have one. The default is tf.int64.
|
||||
|
||||
Returns:
|
||||
A `RowPartition`.
|
||||
|
||||
Raises:
|
||||
ValueError: If `nrows` is incompatible with `value_rowids`.
|
||||
|
||||
#### Example:
|
||||
|
||||
>>> print(RowPartition.from_value_rowids(
|
||||
... value_rowids=[0, 0, 0, 0, 2, 2, 2, 3],
|
||||
... nrows=4))
|
||||
tf.RowPartition(row_splits=tf.Tensor([0 4 4 7 8], shape=(5,), dtype=int64))
|
||||
"""
|
||||
if not isinstance(validate, bool):
|
||||
raise TypeError("validate must have type bool")
|
||||
with ops.name_scope(name, "RowPartitionFromValueRowIds",
|
||||
[value_rowids, nrows]):
|
||||
value_rowids = cls._convert_row_partition(value_rowids, "value_rowids",
|
||||
preferred_dtype)
|
||||
if nrows is None:
|
||||
const_rowids = tensor_util.constant_value(value_rowids)
|
||||
if const_rowids is None:
|
||||
nrows = array_ops.concat([value_rowids[-1:], [-1]], axis=0)[0] + 1
|
||||
const_nrows = None
|
||||
else:
|
||||
const_nrows = const_rowids[-1] + 1 if const_rowids.size > 0 else 0
|
||||
nrows = ops.convert_to_tensor(
|
||||
const_nrows, value_rowids.dtype, name="nrows")
|
||||
else:
|
||||
nrows = ops.convert_to_tensor(nrows, value_rowids.dtype, "nrows")
|
||||
const_nrows = tensor_util.constant_value(nrows)
|
||||
if const_nrows is not None:
|
||||
if const_nrows < 0:
|
||||
raise ValueError("Expected nrows >= 0; got %d" % const_nrows)
|
||||
const_rowids = tensor_util.constant_value(value_rowids)
|
||||
if const_rowids is not None and const_rowids.size > 0:
|
||||
if not const_nrows >= const_rowids[-1] + 1:
|
||||
raise ValueError(
|
||||
"Expected nrows >= value_rowids[-1] + 1; got nrows=%d, "
|
||||
"value_rowids[-1]=%d" % (const_nrows, const_rowids[-1]))
|
||||
|
||||
value_rowids.shape.assert_has_rank(1)
|
||||
nrows.shape.assert_has_rank(0)
|
||||
|
||||
if validate:
|
||||
msg = ("Arguments to from_value_rowids do not form a valid "
|
||||
"RowPartition")
|
||||
checks = [
|
||||
check_ops.assert_rank(value_rowids, 1, message=msg),
|
||||
check_ops.assert_rank(nrows, 0, message=msg),
|
||||
check_ops.assert_non_negative(value_rowids[:1], message=msg),
|
||||
_assert_monotonic_increasing(value_rowids, message=msg),
|
||||
check_ops.assert_less(value_rowids[-1:], nrows, message=msg),
|
||||
]
|
||||
value_rowids = control_flow_ops.with_dependencies(checks, value_rowids)
|
||||
|
||||
# Convert value_rowids & nrows to row_splits.
|
||||
# Note: we don't use segment_ids_to_row_splits() here because we want
|
||||
# to save the intermediate value `row_lengths`, so we can cache it.
|
||||
# TODO(b/116708836) Upgrade bincount to accept int64 so we can skip the
|
||||
# cast.
|
||||
value_rowids_int32 = math_ops.cast(value_rowids, dtypes.int32)
|
||||
nrows_int32 = math_ops.cast(nrows, dtypes.int32)
|
||||
row_lengths = math_ops.bincount(
|
||||
value_rowids_int32,
|
||||
minlength=nrows_int32,
|
||||
maxlength=nrows_int32,
|
||||
dtype=value_rowids.dtype)
|
||||
row_splits = array_ops.concat([[0], math_ops.cumsum(row_lengths)], axis=0)
|
||||
if const_nrows is not None:
|
||||
row_lengths.set_shape([const_nrows])
|
||||
row_splits.set_shape([const_nrows + 1])
|
||||
|
||||
return cls(
|
||||
row_splits,
|
||||
cached_row_lengths=row_lengths,
|
||||
cached_value_rowids=value_rowids,
|
||||
cached_nrows=nrows,
|
||||
internal=True)
|
||||
|
||||
@classmethod
|
||||
def from_row_splits(cls,
|
||||
row_splits,
|
||||
name=None,
|
||||
validate=True,
|
||||
preferred_dtype=None):
|
||||
"""Creates a `RowPartition` with rows partitioned by `row_splits`.
|
||||
|
||||
A `RaggedTensor` constructed with this corresponds with the python list
|
||||
defined by:
|
||||
|
||||
```python
|
||||
result = [values[row_splits[i]:row_splits[i + 1]]
|
||||
for i in range(len(row_splits) - 1)]
|
||||
```
|
||||
|
||||
Args:
|
||||
row_splits: A 1-D integer tensor with shape `[nrows+1]`. Must not be
|
||||
empty, and must be sorted in ascending order. `row_splits[0]` must be
|
||||
zero.
|
||||
name: A name prefix for the RaggedTensor (optional).
|
||||
validate: If true, then use assertions to check that the arguments form a
|
||||
valid `RowPartition`.
|
||||
preferred_dtype: If row_splits has an unspecified type, use this one. If
|
||||
preferred_dtype is None, defaults to dtypes.int64.
|
||||
|
||||
Returns:
|
||||
A `RowPartition`.
|
||||
|
||||
Raises:
|
||||
ValueError: If `row_splits` is an empty list.
|
||||
|
||||
"""
|
||||
if not isinstance(validate, bool):
|
||||
raise TypeError("validate must have type bool")
|
||||
if isinstance(row_splits, (list, tuple)) and not row_splits:
|
||||
raise ValueError("row_splits tensor may not be empty.")
|
||||
if isinstance(row_splits, tensor_spec.TensorSpec):
|
||||
return cls(row_splits=row_splits, internal=True)
|
||||
|
||||
with ops.name_scope(name, "RowPartitionFromRowSplits", [row_splits]):
|
||||
row_splits = cls._convert_row_partition(row_splits, "row_splits",
|
||||
preferred_dtype)
|
||||
row_splits.shape.assert_has_rank(1)
|
||||
|
||||
if validate:
|
||||
msg = "Arguments to from_row_splits do not form a valid RaggedTensor:"
|
||||
checks = [
|
||||
check_ops.assert_rank(row_splits, 1, message=(msg + "rank")),
|
||||
_assert_zero(row_splits[0], message=(msg + "zero")),
|
||||
_assert_monotonic_increasing(
|
||||
row_splits, message=(msg + "monotonic")),
|
||||
]
|
||||
row_splits = control_flow_ops.with_dependencies(checks, row_splits)
|
||||
|
||||
return cls(row_splits=row_splits, internal=True)
|
||||
|
||||
@classmethod
|
||||
def from_row_lengths(cls,
|
||||
row_lengths,
|
||||
name=None,
|
||||
validate=True,
|
||||
preferred_dtype=None):
|
||||
"""Creates a `RowPartition` with rows partitioned by `row_lengths`.
|
||||
|
||||
A `RaggedTensor` constructed with this corresponds with the python list
|
||||
defined by:
|
||||
|
||||
```python
|
||||
result = [[values.pop(0) for i in range(length)]
|
||||
for length in row_lengths]
|
||||
```
|
||||
|
||||
Args:
|
||||
row_lengths: A 1-D integer tensor with shape `[nrows]`. Must be
|
||||
nonnegative.
|
||||
name: A name prefix for the RowPartition (optional).
|
||||
validate: If true, then use assertions to check that the arguments form a
|
||||
valid `RowPartition`.
|
||||
preferred_dtype: If row_lengths has an unspecified type, use this one. If
|
||||
preferred_dtype is None, defaults to dtypes.int64.
|
||||
|
||||
Returns:
|
||||
A `RowPartition`.
|
||||
"""
|
||||
if not isinstance(validate, bool):
|
||||
raise TypeError("validate must have type bool")
|
||||
with ops.name_scope(name, "RowPartitionFromRowLengths", [row_lengths]):
|
||||
row_lengths = cls._convert_row_partition(row_lengths, "row_lengths",
|
||||
preferred_dtype)
|
||||
row_lengths.shape.assert_has_rank(1)
|
||||
|
||||
if validate:
|
||||
msg = "Arguments to from_row_lengths do not form a valid RowPartition"
|
||||
checks = [
|
||||
check_ops.assert_rank(row_lengths, 1, message=msg),
|
||||
check_ops.assert_non_negative(row_lengths, message=msg),
|
||||
]
|
||||
row_lengths = control_flow_ops.with_dependencies(checks, row_lengths)
|
||||
|
||||
row_limits = math_ops.cumsum(row_lengths)
|
||||
row_splits = array_ops.concat([[0], row_limits], axis=0)
|
||||
return cls(
|
||||
row_splits=row_splits, cached_row_lengths=row_lengths, internal=True)
|
||||
|
||||
@classmethod
|
||||
def from_row_starts(cls,
|
||||
row_starts,
|
||||
nvals,
|
||||
name=None,
|
||||
validate=True,
|
||||
preferred_dtype=None):
|
||||
"""Creates a `RowPartition` with rows partitioned by `row_starts`.
|
||||
|
||||
Equivalent to: `from_row_splits(concat([row_starts, nvals]))`.
|
||||
|
||||
Args:
|
||||
row_starts: A 1-D integer tensor with shape `[nrows]`. Must be
|
||||
nonnegative and sorted in ascending order. If `nrows>0`, then
|
||||
`row_starts[0]` must be zero.
|
||||
nvals: A scalar tensor indicating the number of values.
|
||||
name: A name prefix for the RowPartition (optional).
|
||||
validate: If true, then use assertions to check that the arguments form a
|
||||
valid `RowPartition`.
|
||||
preferred_dtype: If row_limits has an unspecified type, use this one. If
|
||||
preferred_dtype is None, defaults to dtypes.int64.
|
||||
|
||||
Returns:
|
||||
A `RowPartition`.
|
||||
"""
|
||||
if not isinstance(validate, bool):
|
||||
raise TypeError("validate must have type bool")
|
||||
with ops.name_scope(name, "RowPartitionFromRowStarts", [row_starts]):
|
||||
row_starts = cls._convert_row_partition(row_starts, "row_starts",
|
||||
preferred_dtype)
|
||||
row_starts.shape.assert_has_rank(1)
|
||||
nvals = math_ops.cast(nvals, row_starts.dtype)
|
||||
if validate:
|
||||
msg = "Arguments to from_row_starts do not form a valid RaggedTensor"
|
||||
checks = [
|
||||
check_ops.assert_rank(row_starts, 1, message=msg),
|
||||
_assert_zero(row_starts[:1], message=msg),
|
||||
_assert_monotonic_increasing(row_starts, message=msg),
|
||||
check_ops.assert_less_equal(row_starts[-1:], nvals, message=msg),
|
||||
]
|
||||
row_starts = control_flow_ops.with_dependencies(checks, row_starts)
|
||||
|
||||
row_splits = array_ops.concat([row_starts, [nvals]], axis=0)
|
||||
return cls(row_splits=row_splits, internal=True)
|
||||
|
||||
def has_cached_value_rowids(self):
|
||||
return self._cached_value_rowids is not None
|
||||
|
||||
@classmethod
|
||||
def from_row_limits(cls,
|
||||
row_limits,
|
||||
name=None,
|
||||
validate=True,
|
||||
preferred_dtype=None):
|
||||
"""Creates a `RowPartition` with rows partitioned by `row_limits`.
|
||||
|
||||
Equivalent to: `from_row_splits(values, concat([0, row_limits]))`.
|
||||
|
||||
Args:
|
||||
row_limits: A 1-D integer tensor with shape `[nrows]`. Must be sorted in
|
||||
ascending order.
|
||||
name: A name prefix for the RaggedTensor (optional).
|
||||
validate: If true, then use assertions to check that the arguments form a
|
||||
valid `RowPartition`.
|
||||
preferred_dtype: If row_limits has an unspecified type, use this one. If
|
||||
preferred_dtype is None, defaults to dtypes.int64.
|
||||
Returns:
|
||||
A `RowPartition`.
|
||||
"""
|
||||
if not isinstance(validate, bool):
|
||||
raise TypeError("validate must have type bool")
|
||||
with ops.name_scope(name, "RowPartitionFromRowLimits", [row_limits]):
|
||||
row_limits = cls._convert_row_partition(row_limits, "row_limits",
|
||||
preferred_dtype)
|
||||
row_limits.shape.assert_has_rank(1)
|
||||
|
||||
if validate:
|
||||
msg = "Arguments to from_row_limits do not form a valid RaggedTensor"
|
||||
checks = [
|
||||
check_ops.assert_rank(row_limits, 1, message=msg),
|
||||
check_ops.assert_non_negative(row_limits[:1], message=msg),
|
||||
_assert_monotonic_increasing(row_limits, message=msg),
|
||||
]
|
||||
row_limits = control_flow_ops.with_dependencies(checks, row_limits)
|
||||
|
||||
zero = array_ops.zeros([1], row_limits.dtype)
|
||||
row_splits = array_ops.concat([zero, row_limits], axis=0)
|
||||
return cls(row_splits=row_splits, internal=True)
|
||||
|
||||
@classmethod
|
||||
def from_uniform_row_length(cls,
|
||||
nvals,
|
||||
uniform_row_length,
|
||||
nrows=None,
|
||||
validate=True,
|
||||
name=None,
|
||||
preferred_dtype=None):
|
||||
"""Creates a `RowPartition` with rows partitioned by `uniform_row_length`.
|
||||
|
||||
A `RaggedTensor` constructed with this corresponds with the python list
|
||||
defined by (assuming uniform_row_length and nvals nonzero):
|
||||
|
||||
```python
|
||||
result = [[values.pop(0) for _ in range(uniform_row_length)]
|
||||
for _ in range(nrows)]
|
||||
```
|
||||
|
||||
|
||||
Note that `rt1` only contains one ragged dimension (the innermost
|
||||
dimension). In contrast, if `from_row_splits` is used to construct a similar
|
||||
`RaggedTensor`, then that `RaggedTensor` will have two ragged dimensions:
|
||||
|
||||
Args:
|
||||
nvals: a non-negative scalar integer tensor for the number of values.
|
||||
uniform_row_length: A scalar integer tensor. Must be nonnegative. The
|
||||
size of the outer axis of `values` must be evenly divisible by
|
||||
`uniform_row_length`.
|
||||
nrows: The number of rows in the constructed RaggedTensor. If not
|
||||
specified, then it defaults to `nvals/uniform_row_length` (or `0` if
|
||||
`uniform_row_length==0`). `nrows` only needs to be specified if
|
||||
`uniform_row_length` might be zero. `uniform_row_length*nrows` must be
|
||||
`nvals`.
|
||||
validate: If true, then use assertions to check that the arguments form a
|
||||
valid `RaggedTensor`.
|
||||
name: A name prefix for the RaggedTensor (optional)
|
||||
preferred_dtype: if uniform_row_length has no dtype, use this one.
|
||||
|
||||
Returns:
|
||||
A `RowPartition`.
|
||||
"""
|
||||
if not isinstance(validate, bool):
|
||||
raise TypeError("validate must have type bool")
|
||||
with ops.name_scope(name, "RowPartitionFromUniformRowLength",
|
||||
[uniform_row_length, nrows]):
|
||||
uniform_row_length = cls._convert_row_partition(uniform_row_length,
|
||||
"uniform_row_length",
|
||||
preferred_dtype)
|
||||
uniform_row_length.shape.assert_has_rank(0)
|
||||
|
||||
# Find nrows.
|
||||
const_row_length = tensor_util.constant_value(uniform_row_length)
|
||||
if nrows is None:
|
||||
if const_row_length is None:
|
||||
# Avoid division by zero if uniform_row_length==0 (and nvals==0).
|
||||
rowlen_or_1 = control_flow_ops.cond(
|
||||
math_ops.equal(uniform_row_length, 0),
|
||||
lambda: constant_op.constant(1, uniform_row_length.dtype),
|
||||
lambda: uniform_row_length)
|
||||
nrows = nvals // rowlen_or_1
|
||||
elif const_row_length == 0:
|
||||
nrows = 0
|
||||
else:
|
||||
nrows = nvals // const_row_length
|
||||
nrows = ops.convert_to_tensor(
|
||||
nrows, uniform_row_length.dtype, name="nrows")
|
||||
const_nrows = tensor_util.constant_value(nrows)
|
||||
const_nvals = tensor_util.constant_value(nvals)
|
||||
|
||||
# Find row_splits.
|
||||
if const_nrows is not None and const_row_length is not None:
|
||||
row_splits = [v * const_row_length for v in range(const_nrows + 1)]
|
||||
row_splits = constant_op.constant(row_splits, uniform_row_length.dtype)
|
||||
else:
|
||||
row_splits = math_ops.range(nrows + 1) * uniform_row_length
|
||||
|
||||
if validate:
|
||||
checks = []
|
||||
|
||||
if (const_nrows is None or const_row_length is None or
|
||||
const_nvals is None):
|
||||
checks.append(
|
||||
check_ops.assert_equal(
|
||||
nrows * uniform_row_length, nvals,
|
||||
("uniform_row_length", uniform_row_length, "times nrows",
|
||||
nrows, "must equal nvals", nvals)))
|
||||
else:
|
||||
if const_nrows * const_row_length != const_nvals:
|
||||
raise ValueError(
|
||||
"uniform_row_length=%d times nrows=%d must equal nvals=%d" %
|
||||
(const_row_length, const_nrows, const_nvals))
|
||||
|
||||
if uniform_row_length.shape.rank is None:
|
||||
checks.append(
|
||||
check_ops.assert_rank(
|
||||
uniform_row_length,
|
||||
0,
|
||||
message="uniform_row_length must be a scalar."))
|
||||
|
||||
const_row_length = tensor_util.constant_value(uniform_row_length)
|
||||
if const_row_length is None:
|
||||
checks.append(
|
||||
check_ops.assert_greater_equal(
|
||||
uniform_row_length,
|
||||
constant_op.constant(0, uniform_row_length.dtype),
|
||||
message="uniform_row_length must be >= 0."))
|
||||
else:
|
||||
if const_row_length < 0:
|
||||
raise ValueError("uniform_row_length must be >= 0.")
|
||||
|
||||
row_splits = control_flow_ops.with_dependencies(checks, row_splits)
|
||||
|
||||
return cls(
|
||||
row_splits=row_splits,
|
||||
uniform_row_length=uniform_row_length,
|
||||
cached_nrows=nrows,
|
||||
internal=True)
|
||||
|
||||
@classmethod
|
||||
def _convert_row_partition(cls, partition, name, preferred_dtype):
|
||||
"""Converts `partition` to Tensors.
|
||||
|
||||
Args:
|
||||
partition: A row-partitioning tensor for the `RowPartition` being
|
||||
constructed. I.e., one of: row_splits, row_lengths, row_starts,
|
||||
row_limits, value_rowids.
|
||||
name: The name of the row-partitioning tensor.
|
||||
preferred_dtype: If partition has no dtype, give it this one. If
|
||||
no dtype is specified, use dtypes.int64.
|
||||
|
||||
Returns:
|
||||
A tensor equivalent to partition.
|
||||
|
||||
Raises:
|
||||
ValueError: if dtype is not int32 or int64.
|
||||
"""
|
||||
if preferred_dtype is None:
|
||||
preferred_dtype = dtypes.int64
|
||||
if isinstance(partition, np.ndarray) and partition.dtype == np.int32:
|
||||
partition = ops.convert_to_tensor(partition, name=name)
|
||||
else:
|
||||
partition = ops.convert_to_tensor(
|
||||
partition, preferred_dtype=preferred_dtype, name=name)
|
||||
if partition.dtype not in (dtypes.int32, dtypes.int64):
|
||||
raise ValueError("%s must have dtype int32 or int64" % name)
|
||||
|
||||
return partition
|
||||
|
||||
def with_dependencies(self, dependencies):
|
||||
"""Returns a new RowPartition equal to self with control dependencies.
|
||||
|
||||
Specifically, self._row_splits is gated by the given control dependencies.
|
||||
Used to add sanity checks to the constructors.
|
||||
|
||||
Args:
|
||||
dependencies: a list of tensors to use as dependencies.
|
||||
|
||||
Returns:
|
||||
A new RowPartition object.
|
||||
"""
|
||||
new_row_splits = control_flow_ops.with_dependencies(dependencies,
|
||||
self._row_splits)
|
||||
return RowPartition(
|
||||
row_splits=new_row_splits,
|
||||
cached_row_lengths=self._cached_row_lengths,
|
||||
cached_value_rowids=self._cached_value_rowids,
|
||||
cached_nrows=self._cached_nrows,
|
||||
internal=True,
|
||||
uniform_row_length=self._uniform_row_length)
|
||||
|
||||
#=============================================================================
|
||||
# Accessors
|
||||
#=============================================================================
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
"""The `DType` of the row partition."""
|
||||
return self._row_splits.dtype
|
||||
|
||||
@property
|
||||
def row_splits(self):
|
||||
"""The row-split indices for this row partition.
|
||||
|
||||
`rt.row_splits` specifies where the values for each row begin and end in
|
||||
`rt.values`. In particular, the values for row `rt[i]` are stored in
|
||||
the slice `rt.values[rt.row_splits[i]:rt.row_splits[i+1]]`.
|
||||
|
||||
Returns:
|
||||
A 1-D integer `Tensor` with shape `[self.nrows+1]`.
|
||||
The returned tensor is non-empty, and is sorted in ascending order.
|
||||
`self.row_splits[0]` is zero, and `self.row_splits[-1]` is equal to
|
||||
`self.values.shape[0]`.
|
||||
|
||||
"""
|
||||
return self._row_splits
|
||||
|
||||
def value_rowids(self, name=None):
|
||||
"""Returns the row indices for this row partition.
|
||||
|
||||
Returns a vector with a number of entries equal to nvals, where
|
||||
the ith value in the tensor indicates the row of the ith value.
|
||||
|
||||
Args:
|
||||
name: A name prefix for the returned tensor (optional).
|
||||
|
||||
Returns:
|
||||
A 1-D integer `Tensor` with shape `self.values.shape[:1]`.
|
||||
The returned tensor is nonnegative, and is sorted in ascending order.
|
||||
|
||||
"""
|
||||
if self._cached_value_rowids is not None:
|
||||
return self._cached_value_rowids
|
||||
|
||||
with ops.name_scope(name, "RaggedValueRowIds", [self]):
|
||||
return segment_id_ops.row_splits_to_segment_ids(self.row_splits)
|
||||
|
||||
def nrows_as_dimension(self):
|
||||
"""Returns the first dimension of the shape as a `tf.Dimension`."""
|
||||
return tensor_shape.dimension_at_index(self._row_splits.shape, 0) - 1
|
||||
|
||||
def nvals(self, out_type=None, name=None):
|
||||
"""Returns the number of values in this row partition.
|
||||
|
||||
Specifically, should be equal to the outermost dimension of the
|
||||
values associated with this row partition.
|
||||
|
||||
Args:
|
||||
out_type: `dtype` for the returned tensor. Defaults to
|
||||
`self.row_splits.dtype`.
|
||||
name: A name prefix for the returned tensor (optional).
|
||||
|
||||
Returns:
|
||||
the number of values in this row partition as a tensor scalar.
|
||||
"""
|
||||
if out_type is None:
|
||||
return self.row_splits[-1]
|
||||
else:
|
||||
out_type = dtypes.as_dtype(out_type)
|
||||
return math_ops.cast(self.row_splits[-1], name=name, dtype=out_type)
|
||||
|
||||
def nrows(self, out_type=None, name=None):
|
||||
"""Returns the number of rows in this ragged tensor.
|
||||
|
||||
I.e., the size of the outermost dimension of the tensor.
|
||||
|
||||
Args:
|
||||
out_type: `dtype` for the returned tensor. Defaults to
|
||||
`self.row_splits.dtype`.
|
||||
name: A name prefix for the returned tensor (optional).
|
||||
|
||||
Returns:
|
||||
A scalar `Tensor` with dtype `out_type`.
|
||||
|
||||
"""
|
||||
if out_type is None:
|
||||
out_type = self._row_splits.dtype
|
||||
else:
|
||||
out_type = dtypes.as_dtype(out_type)
|
||||
if self._cached_nrows is not None:
|
||||
return math_ops.cast(self._cached_nrows, out_type)
|
||||
with ops.name_scope(name, "RaggedNRows", [self]):
|
||||
nsplits = tensor_shape.dimension_at_index(self.row_splits.shape, 0)
|
||||
if nsplits.value is None:
|
||||
return array_ops.shape(self.row_splits, out_type=out_type)[0] - 1
|
||||
else:
|
||||
return constant_op.constant(nsplits.value - 1, dtype=out_type)
|
||||
|
||||
def uniform_row_length(self):
|
||||
"""Returns the uniform row length, or `None` if unspecified."""
|
||||
return self._uniform_row_length
|
||||
|
||||
def row_starts(self, name=None):
|
||||
"""Returns the start indices for rows in this row partition.
|
||||
|
||||
These indices specify where the values for each row begin in
|
||||
`self.values`. `rt.row_starts()` is equal to `rt.row_splits[:-1]`.
|
||||
|
||||
Args:
|
||||
name: A name prefix for the returned tensor (optional).
|
||||
|
||||
Returns:
|
||||
A 1-D integer Tensor with shape `[nrows]`.
|
||||
The returned tensor is nonnegative, and is sorted in ascending order.
|
||||
"""
|
||||
with ops.name_scope(name, "RaggedRowStarts", [self]):
|
||||
return self.row_splits[:-1]
|
||||
|
||||
def row_limits(self, name=None):
|
||||
"""Returns the limit indices for rows in this row partition.
|
||||
|
||||
These indices specify where the values for each row end in
|
||||
`self.values`. `rt.row_limits(self)` is equal to `rt.row_splits[:-1]`.
|
||||
|
||||
Args:
|
||||
name: A name prefix for the returned tensor (optional).
|
||||
|
||||
Returns:
|
||||
A 1-D integer Tensor with shape `[nrows]`.
|
||||
The returned tensor is nonnegative, and is sorted in ascending order.
|
||||
"""
|
||||
with ops.name_scope(name, "RaggedRowLimits", [self]):
|
||||
return self.row_splits[1:]
|
||||
|
||||
def row_lengths(self, name=None):
|
||||
if self._cached_row_lengths is not None:
|
||||
return self._cached_row_lengths
|
||||
splits = self.row_splits
|
||||
with ops.name_scope(name, "RaggedRowLengths", [self]):
|
||||
return splits[1:] - splits[:-1]
|
||||
|
||||
#=============================================================================
|
||||
# Transformation
|
||||
#=============================================================================
|
||||
|
||||
def with_row_splits_dtype(self, dtype):
|
||||
"""Returns a copy of this RowPartition with the given `row_splits` dtype.
|
||||
|
||||
For RaggedTensors with multiple ragged dimensions, the `row_splits` for all
|
||||
nested `RaggedTensor` objects are cast to the given dtype.
|
||||
|
||||
Args:
|
||||
dtype: The dtype for `row_splits`. One of `tf.int32` or `tf.int64`.
|
||||
|
||||
Returns:
|
||||
A copy of this RaggedTensor, with the `row_splits` cast to the given
|
||||
type.
|
||||
"""
|
||||
dtype = dtypes.as_dtype(dtype)
|
||||
if dtype not in (dtypes.int32, dtypes.int64):
|
||||
raise ValueError("dtype must be int32 or int64")
|
||||
if self._row_splits.dtype == dtype:
|
||||
return self
|
||||
|
||||
row_splits = math_ops.cast(self._row_splits, dtype)
|
||||
|
||||
cached_row_lengths = self._cached_row_lengths
|
||||
if cached_row_lengths is not None:
|
||||
cached_row_lengths = math_ops.cast(cached_row_lengths, dtype)
|
||||
cached_value_rowids = self._cached_value_rowids
|
||||
if cached_value_rowids is not None:
|
||||
cached_value_rowids = math_ops.cast(cached_value_rowids, dtype)
|
||||
cached_nrows = self._cached_nrows
|
||||
if cached_value_rowids is not None:
|
||||
cached_value_rowids = math_ops.cast(cached_value_rowids, dtype)
|
||||
uniform_row_length = self._uniform_row_length
|
||||
if uniform_row_length is not None:
|
||||
uniform_row_length = math_ops.cast(uniform_row_length, dtype)
|
||||
|
||||
return RowPartition(
|
||||
row_splits,
|
||||
cached_row_lengths,
|
||||
cached_value_rowids,
|
||||
cached_nrows,
|
||||
internal=True,
|
||||
uniform_row_length=uniform_row_length)
|
||||
|
||||
#=============================================================================
|
||||
# String Encoding
|
||||
#=============================================================================
|
||||
|
||||
def __repr__(self):
|
||||
return "tf.RowPartition(row_splits=%s)" % (self._row_splits)
|
||||
|
||||
|
||||
#===============================================================================
|
||||
# Helper Functions
|
||||
#===============================================================================
|
||||
|
||||
|
||||
def _assert_monotonic_increasing(tensor, message=None):
|
||||
return check_ops.assert_non_negative(
|
||||
tensor[1:] - tensor[:-1], message=message)
|
||||
|
||||
|
||||
def _assert_zero(tensor, message=None):
|
||||
return check_ops.assert_equal(
|
||||
tensor, constant_op.constant(0, dtype=tensor.dtype), message=message)
|
559
tensorflow/python/ops/ragged/row_partition_test.py
Normal file
559
tensorflow/python/ops/ragged/row_partition_test.py
Normal file
@ -0,0 +1,559 @@
|
||||
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for third_party.tensorflow.python.ops.ragged_tensor."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops.ragged.row_partition import RowPartition
|
||||
from tensorflow.python.platform import googletest
|
||||
|
||||
|
||||
class _SliceBuilder(object):
|
||||
"""Helper to construct arguments for __getitem__.
|
||||
|
||||
Usage: _SliceBuilder()[<expr>] slice_spec Python generates for <expr>.
|
||||
"""
|
||||
|
||||
def __getitem__(self, slice_spec):
|
||||
return slice_spec
|
||||
|
||||
|
||||
SLICE_BUILDER = _SliceBuilder()
|
||||
|
||||
|
||||
def _make_tensor_slice_spec(slice_spec, use_constant=True):
|
||||
"""Wraps all integers in an extended slice spec w/ a tensor.
|
||||
|
||||
This function is used to help test slicing when the slice spec contains
|
||||
tensors, rather than integers.
|
||||
|
||||
Args:
|
||||
slice_spec: The extended slice spec.
|
||||
use_constant: If true, then wrap each integer with a tf.constant. If false,
|
||||
then wrap each integer with a tf.placeholder.
|
||||
|
||||
Returns:
|
||||
A copy of slice_spec, but with each integer i replaced with tf.constant(i).
|
||||
"""
|
||||
|
||||
def make_piece_scalar(piece):
|
||||
if isinstance(piece, int):
|
||||
scalar = constant_op.constant(piece)
|
||||
if use_constant:
|
||||
return scalar
|
||||
else:
|
||||
return array_ops.placeholder_with_default(scalar, [])
|
||||
elif isinstance(piece, slice):
|
||||
return slice(
|
||||
make_piece_scalar(piece.start), make_piece_scalar(piece.stop),
|
||||
make_piece_scalar(piece.step))
|
||||
else:
|
||||
return piece
|
||||
|
||||
if isinstance(slice_spec, tuple):
|
||||
return tuple(make_piece_scalar(piece) for piece in slice_spec)
|
||||
else:
|
||||
return make_piece_scalar(slice_spec)
|
||||
|
||||
|
||||
# Example 2D ragged tensor value with one ragged dimension and with scalar
|
||||
# values, expressed as nested python lists and as splits+values.
|
||||
EXAMPLE_RAGGED_TENSOR_2D = [[b'a', b'b'], [b'c', b'd', b'e'], [b'f'], [],
|
||||
[b'g']]
|
||||
EXAMPLE_RAGGED_TENSOR_2D_SPLITS = [0, 2, 5, 6, 6, 7]
|
||||
EXAMPLE_RAGGED_TENSOR_2D_VALUES = ['a', 'b', 'c', 'd', 'e', 'f', 'g']
|
||||
|
||||
# Example 4D ragged tensor value, with two ragged dimensions and with values
|
||||
# whose shape is [2], expressed as nested python lists and as splits+values.
|
||||
EXAMPLE_RAGGED_TENSOR_4D = [
|
||||
[ # rt[0]
|
||||
[[1, 2], [3, 4], [5, 6]], # rt[0][0]
|
||||
[[7, 8], [9, 10], [11, 12]]], # rt[0][1]
|
||||
[], # rt[1]
|
||||
[ # rt[2]
|
||||
[[13, 14], [15, 16], [17, 18]]], # rt[2][0]
|
||||
[ # rt[3]
|
||||
[[19, 20]]] # rt[3][0]
|
||||
] # pyformat: disable
|
||||
EXAMPLE_RAGGED_TENSOR_4D_SPLITS1 = [0, 2, 2, 3, 4]
|
||||
EXAMPLE_RAGGED_TENSOR_4D_SPLITS2 = [0, 3, 6, 9, 10]
|
||||
EXAMPLE_RAGGED_TENSOR_4D_VALUES = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10],
|
||||
[11, 12], [13, 14], [15, 16], [17, 18],
|
||||
[19, 20]]
|
||||
|
||||
# Example 3D ragged tensor with uniform_row_lengths.
|
||||
EXAMPLE_RAGGED_TENSOR_3D = [[[1, 2, 3], [4], [5, 6]], [[], [7, 8, 9], []]]
|
||||
EXAMPLE_RAGGED_TENSOR_3D_ROWLEN = 3
|
||||
EXAMPLE_RAGGED_TENSOR_3D_SPLITS = [0, 3, 4, 6, 6, 9, 9]
|
||||
EXAMPLE_RAGGED_TENSOR_3D_VALUES = [1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||
|
||||
|
||||
def int32array(values):
|
||||
return np.array(values, dtype=np.int32)
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RowPartitionTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
longMessage = True # Property in unittest.Testcase. pylint: disable=invalid-name
|
||||
|
||||
#=============================================================================
|
||||
# RaggedTensor class docstring examples
|
||||
#=============================================================================
|
||||
|
||||
def testClassDocStringExamples(self):
|
||||
# From section: "Component Tensors"
|
||||
rt = RowPartition.from_row_splits(row_splits=[0, 4, 4, 7, 8, 8])
|
||||
self.assertAllEqual(rt.row_splits, [0, 4, 4, 7, 8, 8])
|
||||
del rt
|
||||
|
||||
# From section: "Alternative Row-Partitioning Schemes"
|
||||
rt1 = RowPartition.from_row_splits(row_splits=[0, 4, 4, 7, 8, 8])
|
||||
rt2 = RowPartition.from_row_lengths(row_lengths=[4, 0, 3, 1, 0])
|
||||
rt3 = RowPartition.from_value_rowids(
|
||||
value_rowids=[0, 0, 0, 0, 2, 2, 2, 3], nrows=5)
|
||||
rt4 = RowPartition.from_row_starts(row_starts=[0, 4, 4, 7, 8], nvals=8)
|
||||
rt5 = RowPartition.from_row_limits(row_limits=[4, 4, 7, 8, 8])
|
||||
for rt in (rt1, rt2, rt3, rt4, rt5):
|
||||
self.assertAllEqual(rt.row_splits, [0, 4, 4, 7, 8, 8])
|
||||
del rt1, rt2, rt3, rt4, rt5
|
||||
|
||||
# From section: "Multiple Ragged Dimensions"
|
||||
inner_rt = RowPartition.from_row_splits(row_splits=[0, 4, 4, 7, 8, 8])
|
||||
outer_rt = RowPartition.from_row_splits(row_splits=[0, 3, 3, 5])
|
||||
del inner_rt, outer_rt
|
||||
|
||||
#=============================================================================
|
||||
# RaggedTensor Constructor (private)
|
||||
#=============================================================================
|
||||
|
||||
def testRaggedTensorConstruction(self):
|
||||
row_splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
|
||||
rt = RowPartition(row_splits=row_splits, internal=True)
|
||||
self.assertAllEqual(rt.row_splits, [0, 2, 2, 5, 6, 7])
|
||||
|
||||
def testRaggedTensorConstructionErrors(self):
|
||||
row_splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
'RaggedTensor constructor is private'):
|
||||
RowPartition(row_splits=row_splits)
|
||||
|
||||
with self.assertRaisesRegexp(TypeError,
|
||||
'Row-partitioning argument must be a Tensor'):
|
||||
RowPartition(row_splits=[0, 2, 2, 5, 6, 7], internal=True)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
r'Shape \(6, 1\) must have rank 1'):
|
||||
RowPartition(
|
||||
row_splits=array_ops.expand_dims(row_splits, 1), internal=True)
|
||||
|
||||
with self.assertRaisesRegexp(TypeError,
|
||||
'Cached value must be a Tensor or None.'):
|
||||
RowPartition(
|
||||
row_splits=row_splits, cached_row_lengths=[2, 3, 4], internal=True)
|
||||
|
||||
#=============================================================================
|
||||
# RaggedTensor Factory Ops
|
||||
#=============================================================================
|
||||
|
||||
def testFromValueRowIdsWithDerivedNRows(self):
|
||||
# nrows is known at graph creation time.
|
||||
value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
|
||||
# TODO(martinz): add nrows
|
||||
rt = RowPartition.from_value_rowids(value_rowids, validate=False)
|
||||
self.assertEqual(rt.dtype, dtypes.int64)
|
||||
|
||||
rt_row_splits = rt.row_splits
|
||||
rt_value_rowids = rt.value_rowids()
|
||||
rt_nrows = rt.nrows()
|
||||
|
||||
self.assertIs(rt_value_rowids, value_rowids) # cached_value_rowids
|
||||
self.assertAllEqual(rt_value_rowids, value_rowids)
|
||||
self.assertAllEqual(rt_nrows, 5)
|
||||
self.assertAllEqual(rt_row_splits, [0, 2, 2, 5, 6, 7])
|
||||
|
||||
def testFromValueRowIdsWithDerivedNRowsDynamic(self):
|
||||
# nrows is not known at graph creation time.
|
||||
value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
|
||||
value_rowids = array_ops.placeholder_with_default(value_rowids, shape=None)
|
||||
|
||||
rt = RowPartition.from_value_rowids(value_rowids, validate=False)
|
||||
|
||||
rt_value_rowids = rt.value_rowids()
|
||||
rt_nrows = rt.nrows()
|
||||
|
||||
self.assertIs(rt_value_rowids, value_rowids) # cached_value_rowids
|
||||
self.assertAllEqual(rt_value_rowids, value_rowids)
|
||||
self.assertAllEqual(rt_nrows, 5)
|
||||
|
||||
def testFromValueRowIdsWithExplicitNRows(self):
|
||||
value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
|
||||
nrows = constant_op.constant(7, dtypes.int64)
|
||||
|
||||
rt = RowPartition.from_value_rowids(value_rowids, nrows, validate=False)
|
||||
|
||||
rt_value_rowids = rt.value_rowids()
|
||||
rt_nrows = rt.nrows()
|
||||
rt_row_splits = rt.row_splits
|
||||
|
||||
self.assertIs(rt_value_rowids, value_rowids) # cached_value_rowids
|
||||
self.assertIs(rt_nrows, nrows) # cached_nrows
|
||||
self.assertAllEqual(rt_row_splits, [0, 2, 2, 5, 6, 7, 7, 7])
|
||||
|
||||
def testFromValueRowIdsWithExplicitNRowsEqualToDefault(self):
|
||||
value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
|
||||
nrows = constant_op.constant(5, dtypes.int64)
|
||||
|
||||
rt = RowPartition.from_value_rowids(value_rowids, nrows, validate=False)
|
||||
|
||||
rt_value_rowids = rt.value_rowids()
|
||||
rt_nrows = rt.nrows()
|
||||
rt_row_splits = rt.row_splits
|
||||
|
||||
self.assertIs(rt_value_rowids, value_rowids) # cached_value_rowids
|
||||
self.assertIs(rt_nrows, nrows) # cached_nrows
|
||||
self.assertAllEqual(rt_value_rowids, value_rowids)
|
||||
self.assertAllEqual(rt_nrows, nrows)
|
||||
self.assertAllEqual(rt_row_splits, [0, 2, 2, 5, 6, 7])
|
||||
|
||||
def testFromValueRowIdsWithEmptyValues(self):
|
||||
rt = RowPartition.from_value_rowids([])
|
||||
rt_nrows = rt.nrows()
|
||||
self.assertEqual(rt.dtype, dtypes.int64)
|
||||
self.assertEqual(rt.value_rowids().shape.as_list(), [0])
|
||||
self.assertAllEqual(rt_nrows, 0)
|
||||
|
||||
def testFromRowSplits(self):
|
||||
row_splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
|
||||
|
||||
rt = RowPartition.from_row_splits(row_splits, validate=False)
|
||||
self.assertEqual(rt.dtype, dtypes.int64)
|
||||
|
||||
rt_row_splits = rt.row_splits
|
||||
rt_nrows = rt.nrows()
|
||||
|
||||
self.assertIs(rt_row_splits, row_splits)
|
||||
self.assertAllEqual(rt_nrows, 5)
|
||||
|
||||
def testFromRowSplitsWithDifferentSplitTypes(self):
|
||||
splits1 = [0, 2, 2, 5, 6, 7]
|
||||
splits2 = np.array([0, 2, 2, 5, 6, 7], np.int64)
|
||||
splits3 = np.array([0, 2, 2, 5, 6, 7], np.int32)
|
||||
splits4 = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
|
||||
splits5 = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int32)
|
||||
rt1 = RowPartition.from_row_splits(splits1)
|
||||
rt2 = RowPartition.from_row_splits(splits2)
|
||||
rt3 = RowPartition.from_row_splits(splits3)
|
||||
rt4 = RowPartition.from_row_splits(splits4)
|
||||
rt5 = RowPartition.from_row_splits(splits5)
|
||||
self.assertEqual(rt1.row_splits.dtype, dtypes.int64)
|
||||
self.assertEqual(rt2.row_splits.dtype, dtypes.int64)
|
||||
self.assertEqual(rt3.row_splits.dtype, dtypes.int32)
|
||||
self.assertEqual(rt4.row_splits.dtype, dtypes.int64)
|
||||
self.assertEqual(rt5.row_splits.dtype, dtypes.int32)
|
||||
|
||||
def testFromRowSplitsWithEmptySplits(self):
|
||||
err_msg = 'row_splits tensor may not be empty'
|
||||
with self.assertRaisesRegexp(ValueError, err_msg):
|
||||
RowPartition.from_row_splits([], [])
|
||||
|
||||
def testFromRowStarts(self):
|
||||
nvals = constant_op.constant(7)
|
||||
row_starts = constant_op.constant([0, 2, 2, 5, 6], dtypes.int64)
|
||||
|
||||
rt = RowPartition.from_row_starts(row_starts, nvals, validate=False)
|
||||
self.assertEqual(rt.dtype, dtypes.int64)
|
||||
|
||||
rt_row_starts = rt.row_starts()
|
||||
rt_row_splits = rt.row_splits
|
||||
rt_nrows = rt.nrows()
|
||||
|
||||
self.assertAllEqual(rt_nrows, 5)
|
||||
self.assertAllEqual(rt_row_starts, row_starts)
|
||||
self.assertAllEqual(rt_row_splits, [0, 2, 2, 5, 6, 7])
|
||||
|
||||
def testFromRowLimits(self):
|
||||
row_limits = constant_op.constant([2, 2, 5, 6, 7], dtypes.int64)
|
||||
|
||||
rt = RowPartition.from_row_limits(row_limits, validate=False)
|
||||
self.assertEqual(rt.dtype, dtypes.int64)
|
||||
|
||||
rt_row_limits = rt.row_limits()
|
||||
rt_row_splits = rt.row_splits
|
||||
rt_nrows = rt.nrows()
|
||||
|
||||
self.assertAllEqual(rt_nrows, 5)
|
||||
self.assertAllEqual(rt_row_limits, row_limits)
|
||||
self.assertAllEqual(rt_row_splits, [0, 2, 2, 5, 6, 7])
|
||||
|
||||
def testFromRowLengths(self):
|
||||
row_lengths = constant_op.constant([2, 0, 3, 1, 1], dtypes.int64)
|
||||
|
||||
rt = RowPartition.from_row_lengths(row_lengths, validate=False)
|
||||
self.assertEqual(rt.dtype, dtypes.int64)
|
||||
|
||||
rt_row_lengths = rt.row_lengths()
|
||||
rt_nrows = rt.nrows()
|
||||
|
||||
self.assertIs(rt_row_lengths, row_lengths) # cached_nrows
|
||||
self.assertAllEqual(rt_nrows, 5)
|
||||
self.assertAllEqual(rt_row_lengths, row_lengths)
|
||||
|
||||
def testFromUniformRowLength(self):
|
||||
nvals = 16
|
||||
a1 = RowPartition.from_uniform_row_length(nvals, 2)
|
||||
self.assertAllEqual(a1.uniform_row_length(), 2)
|
||||
self.assertAllEqual(a1.nrows(), 8)
|
||||
|
||||
def testFromUniformRowLengthWithEmptyValues(self):
|
||||
a = RowPartition.from_uniform_row_length(
|
||||
nvals=0, uniform_row_length=0, nrows=10)
|
||||
self.assertEqual(self.evaluate(a.nvals()), 0)
|
||||
self.assertEqual(self.evaluate(a.nrows()), 10)
|
||||
|
||||
def testFromUniformRowLengthWithPlaceholders1(self):
|
||||
nvals = array_ops.placeholder_with_default(
|
||||
constant_op.constant(6, dtype=dtypes.int64), None)
|
||||
rt1 = RowPartition.from_uniform_row_length(nvals, 3)
|
||||
const_nvals1 = self.evaluate(rt1.nvals())
|
||||
self.assertEqual(const_nvals1, 6)
|
||||
|
||||
def testFromUniformRowLengthWithPlaceholders2(self):
|
||||
nvals = array_ops.placeholder_with_default(6, None)
|
||||
ph_rowlen = array_ops.placeholder_with_default(3, None)
|
||||
rt2 = RowPartition.from_uniform_row_length(nvals, ph_rowlen)
|
||||
const_nvals2 = self.evaluate(rt2.nvals())
|
||||
self.assertEqual(const_nvals2, 6)
|
||||
|
||||
def testFromValueRowIdsWithBadNRows(self):
|
||||
value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
|
||||
nrows = constant_op.constant(5, dtypes.int64)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, r'Expected nrows >= 0; got -2'):
|
||||
RowPartition.from_value_rowids(
|
||||
value_rowids=array_ops.placeholder_with_default(value_rowids, None),
|
||||
nrows=-2)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, r'Expected nrows >= value_rowids\[-1\] \+ 1; got nrows=2, '
|
||||
r'value_rowids\[-1\]=4'):
|
||||
RowPartition.from_value_rowids(value_rowids=value_rowids, nrows=2)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
ValueError, r'Expected nrows >= value_rowids\[-1\] \+ 1; got nrows=4, '
|
||||
r'value_rowids\[-1\]=4'):
|
||||
RowPartition.from_value_rowids(value_rowids=value_rowids, nrows=4)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
r'Shape \(7, 1\) must have rank 1'):
|
||||
RowPartition.from_value_rowids(
|
||||
value_rowids=array_ops.expand_dims(value_rowids, 1), nrows=nrows)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, r'Shape \(1,\) must have rank 0'):
|
||||
RowPartition.from_value_rowids(
|
||||
value_rowids=value_rowids, nrows=array_ops.expand_dims(nrows, 0))
|
||||
|
||||
#=============================================================================
|
||||
# RowPartition.__str__
|
||||
#=============================================================================
|
||||
def testRowPartitionStr(self):
|
||||
row_splits = [0, 2, 5, 6, 6, 7]
|
||||
rt = RowPartition.from_row_splits(row_splits, validate=False)
|
||||
splits_type = 'int64'
|
||||
if context.executing_eagerly():
|
||||
expected_repr = ('tf.RowPartition(row_splits=tf.Tensor([0 2 5 6 6 7], '
|
||||
'shape=(6,), dtype=int64))')
|
||||
else:
|
||||
expected_repr = ('tf.RowPartition(row_splits='
|
||||
'Tensor("RowPartitionFromRowSplits/row_splits:0", '
|
||||
'shape=(6,), dtype={}))').format(splits_type)
|
||||
self.assertEqual(repr(rt), expected_repr)
|
||||
self.assertEqual(str(rt), expected_repr)
|
||||
|
||||
@parameterized.parameters([
|
||||
# from_value_rowids
|
||||
{
|
||||
'descr': 'bad rank for value_rowids',
|
||||
'factory': RowPartition.from_value_rowids,
|
||||
'value_rowids': [[1, 2], [3, 4]],
|
||||
'nrows': 10
|
||||
},
|
||||
{
|
||||
'descr': 'bad rank for nrows',
|
||||
'factory': RowPartition.from_value_rowids,
|
||||
'value_rowids': [1, 2, 3, 4],
|
||||
'nrows': [10]
|
||||
},
|
||||
{
|
||||
'descr': 'negative value_rowid',
|
||||
'factory': RowPartition.from_value_rowids,
|
||||
'value_rowids': [-5, 2, 3, 4],
|
||||
'nrows': 10
|
||||
},
|
||||
{
|
||||
'descr': 'non-monotonic-increasing value_rowid',
|
||||
'factory': RowPartition.from_value_rowids,
|
||||
'value_rowids': [4, 3, 2, 1],
|
||||
'nrows': 10
|
||||
},
|
||||
{
|
||||
'descr': 'value_rowid > nrows',
|
||||
'factory': RowPartition.from_value_rowids,
|
||||
'value_rowids': [1, 2, 3, 4],
|
||||
'nrows': 2
|
||||
},
|
||||
|
||||
# from_row_splits
|
||||
{
|
||||
'descr': 'bad rank for row_splits',
|
||||
'factory': RowPartition.from_row_splits,
|
||||
'row_splits': [[1, 2], [3, 4]]
|
||||
},
|
||||
{
|
||||
'descr': 'row_splits[0] != 0',
|
||||
'factory': RowPartition.from_row_splits,
|
||||
'row_splits': [2, 3, 4]
|
||||
},
|
||||
{
|
||||
'descr': 'non-monotonic-increasing row_splits',
|
||||
'factory': RowPartition.from_row_splits,
|
||||
'row_splits': [0, 3, 2, 4]
|
||||
},
|
||||
|
||||
# from_row_lengths
|
||||
{
|
||||
'descr': 'bad rank for row_lengths',
|
||||
'factory': RowPartition.from_row_lengths,
|
||||
'row_lengths': [[1, 2], [1, 0]]
|
||||
},
|
||||
{
|
||||
'descr': 'negatve row_lengths',
|
||||
'factory': RowPartition.from_row_lengths,
|
||||
'row_lengths': [3, -1, 2]
|
||||
},
|
||||
|
||||
# from_row_starts
|
||||
{
|
||||
'descr': 'bad rank for row_starts',
|
||||
'factory': RowPartition.from_row_starts,
|
||||
'nvals': 2,
|
||||
'row_starts': [[1, 2], [3, 4]]
|
||||
},
|
||||
{
|
||||
'descr': 'row_starts[0] != 0',
|
||||
'factory': RowPartition.from_row_starts,
|
||||
'nvals': 5,
|
||||
'row_starts': [2, 3, 4]
|
||||
},
|
||||
{
|
||||
'descr': 'non-monotonic-increasing row_starts',
|
||||
'factory': RowPartition.from_row_starts,
|
||||
'nvals': 4,
|
||||
'row_starts': [0, 3, 2, 4]
|
||||
},
|
||||
{
|
||||
'descr': 'row_starts[0] > nvals',
|
||||
'factory': RowPartition.from_row_starts,
|
||||
'nvals': 4,
|
||||
'row_starts': [0, 2, 3, 5]
|
||||
},
|
||||
|
||||
# from_row_limits
|
||||
{
|
||||
'descr': 'bad rank for row_limits',
|
||||
'factory': RowPartition.from_row_limits,
|
||||
'row_limits': [[1, 2], [3, 4]]
|
||||
},
|
||||
{
|
||||
'descr': 'row_limits[0] < 0',
|
||||
'factory': RowPartition.from_row_limits,
|
||||
'row_limits': [-1, 3, 4]
|
||||
},
|
||||
{
|
||||
'descr': 'non-monotonic-increasing row_limits',
|
||||
'factory': RowPartition.from_row_limits,
|
||||
'row_limits': [0, 3, 2, 4]
|
||||
},
|
||||
|
||||
# from_uniform_row_length
|
||||
{
|
||||
'descr': 'rowlen * nrows != nvals (1)',
|
||||
'factory': RowPartition.from_uniform_row_length,
|
||||
'nvals': 5,
|
||||
'uniform_row_length': 3
|
||||
},
|
||||
{
|
||||
'descr': 'rowlen * nrows != nvals (2)',
|
||||
'factory': RowPartition.from_uniform_row_length,
|
||||
'nvals': 5,
|
||||
'uniform_row_length': 6
|
||||
},
|
||||
{
|
||||
'descr': 'rowlen * nrows != nvals (3)',
|
||||
'factory': RowPartition.from_uniform_row_length,
|
||||
'nvals': 6,
|
||||
'uniform_row_length': 3,
|
||||
'nrows': 3
|
||||
},
|
||||
{
|
||||
'descr': 'rowlen must be a scalar',
|
||||
'factory': RowPartition.from_uniform_row_length,
|
||||
'nvals': 4,
|
||||
'uniform_row_length': [2]
|
||||
},
|
||||
{
|
||||
'descr': 'rowlen must be nonnegative',
|
||||
'factory': RowPartition.from_uniform_row_length,
|
||||
'nvals': 4,
|
||||
'uniform_row_length': -1
|
||||
},
|
||||
])
|
||||
def testFactoryValidation(self, descr, factory, **kwargs):
|
||||
# When input tensors have shape information, some of these errors will be
|
||||
# detected statically.
|
||||
with self.assertRaises((errors.InvalidArgumentError, ValueError)):
|
||||
partition = factory(**kwargs)
|
||||
self.evaluate(partition.row_splits)
|
||||
|
||||
# Remove shape information (by wrapping tensors in placeholders), and check
|
||||
# that we detect the errors when the graph is run.
|
||||
if not context.executing_eagerly():
|
||||
|
||||
def wrap_arg(v):
|
||||
return array_ops.placeholder_with_default(
|
||||
constant_op.constant(v, dtype=dtypes.int64),
|
||||
tensor_shape.TensorShape(None))
|
||||
|
||||
kwargs = dict((k, wrap_arg(v)) for (k, v) in kwargs.items())
|
||||
|
||||
with self.assertRaises(errors.InvalidArgumentError):
|
||||
partition = factory(**kwargs)
|
||||
self.evaluate(partition.row_splits)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
googletest.main()
|
@ -19,6 +19,10 @@ tf_class {
|
||||
name: "ragged_rank"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "row_partition"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "row_splits"
|
||||
mtype: "<type \'property\'>"
|
||||
@ -37,7 +41,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'values\', \'row_splits\', \'cached_row_lengths\', \'cached_value_rowids\', \'cached_nrows\', \'internal\', \'uniform_row_length\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'values\', \'row_partition\', \'internal\'], varargs=None, keywords=None, defaults=[\'False\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "bounding_shape"
|
||||
|
@ -19,6 +19,10 @@ tf_class {
|
||||
name: "ragged_rank"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "row_partition"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "row_splits"
|
||||
mtype: "<type \'property\'>"
|
||||
@ -37,7 +41,7 @@ tf_class {
|
||||
}
|
||||
member_method {
|
||||
name: "__init__"
|
||||
argspec: "args=[\'self\', \'values\', \'row_splits\', \'cached_row_lengths\', \'cached_value_rowids\', \'cached_nrows\', \'internal\', \'uniform_row_length\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], "
|
||||
argspec: "args=[\'self\', \'values\', \'row_partition\', \'internal\'], varargs=None, keywords=None, defaults=[\'False\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "bounding_shape"
|
||||
|
Loading…
Reference in New Issue
Block a user