- assertEquals -> assertEqual - assertRaisesRegexp -> assertRaisesRegex - assertRegexpMatches -> assertRegex PiperOrigin-RevId: 319118081 Change-Id: Ieb457128522920ab55d6b69a7f244ab798a7d689
895 lines
36 KiB
Python
895 lines
36 KiB
Python
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Tests for tf.ragged.RowPartition."""
|
|
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
|
|
from absl.testing import parameterized
|
|
import numpy as np
|
|
|
|
from tensorflow.python.eager import context
|
|
from tensorflow.python.framework import constant_op
|
|
from tensorflow.python.framework import dtypes
|
|
from tensorflow.python.framework import errors
|
|
from tensorflow.python.framework import tensor_shape
|
|
from tensorflow.python.framework import tensor_spec
|
|
from tensorflow.python.framework import test_util
|
|
from tensorflow.python.ops import array_ops
|
|
from tensorflow.python.ops.ragged import row_partition
|
|
from tensorflow.python.ops.ragged.row_partition import RowPartition
|
|
from tensorflow.python.ops.ragged.row_partition import RowPartitionSpec
|
|
from tensorflow.python.platform import googletest
|
|
|
|
|
|
@test_util.run_all_in_graph_and_eager_modes
|
|
class RowPartitionTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
|
#=============================================================================
|
|
# RaggedTensor class docstring examples
|
|
#=============================================================================
|
|
|
|
def testClassDocStringExamples(self):
  # Runs the construction examples quoted from the class docstring sections
  # named below.
  expected_splits = [0, 4, 4, 7, 8, 8]

  # From section: "Component Tensors"
  from_splits = RowPartition.from_row_splits(row_splits=[0, 4, 4, 7, 8, 8])
  self.assertAllEqual(from_splits.row_splits(), expected_splits)

  # From section: "Alternative Row-Partitioning Schemes"
  # Five different encodings that must all yield the same row_splits.
  equivalent_partitions = (
      RowPartition.from_row_splits(row_splits=[0, 4, 4, 7, 8, 8]),
      RowPartition.from_row_lengths(row_lengths=[4, 0, 3, 1, 0]),
      RowPartition.from_value_rowids(
          value_rowids=[0, 0, 0, 0, 2, 2, 2, 3], nrows=5),
      RowPartition.from_row_starts(row_starts=[0, 4, 4, 7, 8], nvals=8),
      RowPartition.from_row_limits(row_limits=[4, 4, 7, 8, 8]),
  )
  for partition in equivalent_partitions:
    self.assertAllEqual(partition.row_splits(), expected_splits)

  # From section: "Multiple Ragged Dimensions"
  # These two are only constructed; nothing is asserted about them.
  RowPartition.from_row_splits(row_splits=[0, 4, 4, 7, 8, 8])
  RowPartition.from_row_splits(row_splits=[0, 3, 3, 5])
|
|
|
|
#=============================================================================
|
|
# RowPartition Constructor (private)
|
|
#=============================================================================
|
|
|
|
def testRaggedTensorConstruction(self):
  # Calls the constructor directly, passing the module's factory key as
  # `internal`, and checks that row_splits() returns the supplied values.
  splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
  partition = RowPartition(
      row_splits=splits, internal=row_partition._row_partition_factory_key)
  self.assertAllEqual(partition.row_splits(), [0, 2, 2, 5, 6, 7])
|
|
|
|
def testRaggedTensorConstructionErrors(self):
  # Each `with` block below checks one rejected direct-constructor call.
  factory_key = row_partition._row_partition_factory_key
  splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)

  # No `internal` key supplied.
  with self.assertRaisesRegex(ValueError,
                              'RaggedTensor constructor is private'):
    RowPartition(row_splits=splits)

  # row_splits passed as a Python list rather than a Tensor.
  with self.assertRaisesRegex(TypeError,
                              'Row-partitioning argument must be a Tensor'):
    RowPartition(row_splits=[0, 2, 2, 5, 6, 7], internal=factory_key)

  # row_splits expanded to shape (6, 1).
  with self.assertRaisesRegex(ValueError, r'Shape \(6, 1\) must have rank 1'):
    RowPartition(
        row_splits=array_ops.expand_dims(splits, 1), internal=factory_key)

  # row_lengths passed as a Python list rather than a Tensor.
  with self.assertRaisesRegex(TypeError,
                              'Cached value must be a Tensor or None.'):
    RowPartition(
        row_splits=splits, row_lengths=[2, 3, 4], internal=factory_key)
|
|
|
|
#=============================================================================
|
|
# RowPartition Factory Ops
|
|
#=============================================================================
|
|
|
|
def testFromValueRowIdsWithDerivedNRows(self):
  # nrows is known at graph creation time.
  rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
  # TODO(martinz): add nrows
  partition = RowPartition.from_value_rowids(rowids, validate=False)
  self.assertEqual(partition.dtype, dtypes.int64)

  actual_splits = partition.row_splits()
  actual_rowids = partition.value_rowids()
  actual_nrows = partition.nrows()

  # value_rowids() must return the same tensor object that was passed in.
  self.assertIs(actual_rowids, rowids)
  self.assertAllEqual(actual_rowids, rowids)
  self.assertAllEqual(actual_nrows, 5)
  self.assertAllEqual(actual_splits, [0, 2, 2, 5, 6, 7])
|
|
|
|
def testFromValueRowIdsWithDerivedNRowsDynamic(self):
  # nrows is not known at graph creation time.
  static_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
  rowids = array_ops.placeholder_with_default(static_rowids, shape=None)

  partition = RowPartition.from_value_rowids(rowids, validate=False)

  actual_rowids = partition.value_rowids()
  actual_nrows = partition.nrows()

  # value_rowids() must return the same (placeholder) tensor object.
  self.assertIs(actual_rowids, rowids)
  self.assertAllEqual(actual_rowids, rowids)
  self.assertAllEqual(actual_nrows, 5)
|
|
|
|
def testFromValueRowIdsWithExplicitNRows(self):
  # The explicit nrows (7) exceeds the largest row id plus one (5); the
  # expected row_splits therefore end with repeated entries.
  rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
  explicit_nrows = constant_op.constant(7, dtypes.int64)

  partition = RowPartition.from_value_rowids(
      rowids, explicit_nrows, validate=False)

  actual_rowids = partition.value_rowids()
  actual_nrows = partition.nrows()
  actual_splits = partition.row_splits()

  # Both accessors must return the same tensor objects that were passed in.
  self.assertIs(actual_rowids, rowids)
  self.assertIs(actual_nrows, explicit_nrows)
  self.assertAllEqual(actual_splits, [0, 2, 2, 5, 6, 7, 7, 7])
|
|
|
|
def testFromValueRowIdsWithExplicitNRowsEqualToDefault(self):
  # The explicit nrows (5) equals the largest row id plus one.
  rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
  explicit_nrows = constant_op.constant(5, dtypes.int64)

  partition = RowPartition.from_value_rowids(
      rowids, explicit_nrows, validate=False)

  actual_rowids = partition.value_rowids()
  actual_nrows = partition.nrows()
  actual_splits = partition.row_splits()

  # Both accessors must return the same tensor objects that were passed in.
  self.assertIs(actual_rowids, rowids)
  self.assertIs(actual_nrows, explicit_nrows)
  self.assertAllEqual(actual_rowids, rowids)
  self.assertAllEqual(actual_nrows, explicit_nrows)
  self.assertAllEqual(actual_splits, [0, 2, 2, 5, 6, 7])
|
|
|
|
def testFromValueRowIdsWithEmptyValues(self):
  # An empty list of row ids gives an int64 partition with zero rows.
  partition = RowPartition.from_value_rowids([])
  nrows = partition.nrows()
  self.assertEqual(partition.dtype, dtypes.int64)
  self.assertEqual(partition.value_rowids().shape.as_list(), [0])
  self.assertAllEqual(nrows, 0)
|
|
|
|
def testFromRowSplits(self):
  splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)

  partition = RowPartition.from_row_splits(splits, validate=False)
  self.assertEqual(partition.dtype, dtypes.int64)

  actual_splits = partition.row_splits()
  actual_nrows = partition.nrows()

  # row_splits() must return the same tensor object that was passed in.
  self.assertIs(actual_splits, splits)
  self.assertAllEqual(actual_nrows, 5)
|
|
|
|
def testFromRowSplitsWithDifferentSplitTypes(self):
  # Pairs of (row_splits input, dtype expected from row_splits()).
  cases = [
      ([0, 2, 2, 5, 6, 7], dtypes.int64),
      (np.array([0, 2, 2, 5, 6, 7], np.int64), dtypes.int64),
      (np.array([0, 2, 2, 5, 6, 7], np.int32), dtypes.int32),
      (constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64), dtypes.int64),
      (constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int32), dtypes.int32),
  ]
  # Build every partition first, then check the resulting dtypes.
  partitions = [RowPartition.from_row_splits(splits) for splits, _ in cases]
  for partition, (_, expected_dtype) in zip(partitions, cases):
    self.assertEqual(partition.row_splits().dtype, expected_dtype)
|
|
|
|
def testFromRowSplitsWithEmptySplits(self):
  # An empty row_splits list is rejected with a ValueError.
  with self.assertRaisesRegex(ValueError,
                              'row_splits tensor may not be empty'):
    RowPartition.from_row_splits([])
|
|
|
|
def testFromRowStarts(self):
  total_values = constant_op.constant(7)
  starts = constant_op.constant([0, 2, 2, 5, 6], dtypes.int64)

  partition = RowPartition.from_row_starts(
      starts, total_values, validate=False)
  self.assertEqual(partition.dtype, dtypes.int64)

  actual_starts = partition.row_starts()
  actual_splits = partition.row_splits()
  actual_nrows = partition.nrows()

  self.assertAllEqual(actual_nrows, 5)
  self.assertAllEqual(actual_starts, starts)
  # The expected splits are the starts followed by nvals.
  self.assertAllEqual(actual_splits, [0, 2, 2, 5, 6, 7])
|
|
|
|
def testFromRowLimits(self):
  limits = constant_op.constant([2, 2, 5, 6, 7], dtypes.int64)

  partition = RowPartition.from_row_limits(limits, validate=False)
  self.assertEqual(partition.dtype, dtypes.int64)

  actual_limits = partition.row_limits()
  actual_splits = partition.row_splits()
  actual_nrows = partition.nrows()

  self.assertAllEqual(actual_nrows, 5)
  self.assertAllEqual(actual_limits, limits)
  # The expected splits are a leading 0 followed by the limits.
  self.assertAllEqual(actual_splits, [0, 2, 2, 5, 6, 7])
|
|
|
|
def testFromRowLengths(self):
  lengths = constant_op.constant([2, 0, 3, 1, 1], dtypes.int64)

  partition = RowPartition.from_row_lengths(lengths, validate=False)
  self.assertEqual(partition.dtype, dtypes.int64)

  actual_lengths = partition.row_lengths()
  actual_nrows = partition.nrows()

  # row_lengths() must return the same tensor object that was passed in.
  self.assertIs(actual_lengths, lengths)
  self.assertAllEqual(actual_nrows, 5)
  self.assertAllEqual(actual_lengths, lengths)
|
|
|
|
def testFromUniformRowLength(self):
  # 16 values split into rows of length 2 should give 8 rows.
  partition = RowPartition.from_uniform_row_length(
      nvals=16, uniform_row_length=2)
  self.assertAllEqual(partition.uniform_row_length(), 2)
  self.assertAllEqual(partition.nrows(), 8)
|
|
|
|
def testFromUniformRowLengthWithEmptyValues(self):
  # Zero-length rows: nrows is given explicitly because nvals is also zero.
  partition = RowPartition.from_uniform_row_length(
      nvals=0, uniform_row_length=0, nrows=10)
  actual_nvals = self.evaluate(partition.nvals())
  self.assertEqual(actual_nvals, 0)
  actual_nrows = self.evaluate(partition.nrows())
  self.assertEqual(actual_nrows, 10)
|
|
|
|
def testFromUniformRowLengthWithPlaceholders1(self):
  # nvals is an int64 placeholder; uniform_row_length is a Python int.
  default_nvals = constant_op.constant(6, dtype=dtypes.int64)
  nvals = array_ops.placeholder_with_default(default_nvals, None)
  partition = RowPartition.from_uniform_row_length(
      nvals=nvals, uniform_row_length=3)
  self.assertEqual(self.evaluate(partition.nvals()), 6)
|
|
|
|
def testFromUniformRowLengthWithPlaceholders2(self):
  # Both nvals and uniform_row_length are supplied as placeholders.
  nvals = array_ops.placeholder_with_default(6, None)
  row_length = array_ops.placeholder_with_default(3, None)
  partition = RowPartition.from_uniform_row_length(
      nvals=nvals, uniform_row_length=row_length)
  self.assertEqual(self.evaluate(partition.nvals()), 6)
|
|
|
|
def testFromValueRowIdsWithBadNRows(self):
  # Each `with` block below checks one invalid value_rowids/nrows pairing.
  rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
  scalar_nrows = constant_op.constant(5, dtypes.int64)

  # Negative nrows, with value_rowids wrapped in a placeholder.
  with self.assertRaisesRegex(ValueError, r'Expected nrows >= 0; got -2'):
    RowPartition.from_value_rowids(
        value_rowids=array_ops.placeholder_with_default(rowids, None),
        nrows=-2)

  # nrows values (2 and 4) that do not exceed the last row id (4).
  with self.assertRaisesRegex(
      ValueError, r'Expected nrows >= value_rowids\[-1\] \+ 1; got nrows=2, '
      r'value_rowids\[-1\]=4'):
    RowPartition.from_value_rowids(value_rowids=rowids, nrows=2)

  with self.assertRaisesRegex(
      ValueError, r'Expected nrows >= value_rowids\[-1\] \+ 1; got nrows=4, '
      r'value_rowids\[-1\]=4'):
    RowPartition.from_value_rowids(value_rowids=rowids, nrows=4)

  # value_rowids expanded to shape (7, 1).
  with self.assertRaisesRegex(ValueError, r'Shape \(7, 1\) must have rank 1'):
    RowPartition.from_value_rowids(
        value_rowids=array_ops.expand_dims(rowids, 1), nrows=scalar_nrows)

  # nrows expanded to shape (1,).
  with self.assertRaisesRegex(ValueError, r'Shape \(1,\) must have rank 0'):
    RowPartition.from_value_rowids(
        value_rowids=rowids, nrows=array_ops.expand_dims(scalar_nrows, 0))
|
|
|
|
#=============================================================================
|
|
# RowPartition.__str__
|
|
#=============================================================================
|
|
def testRowPartitionStr(self):
  # repr() and str() are expected to give the same text. That text differs
  # between eager mode and graph mode.
  partition = RowPartition.from_row_splits([0, 2, 5, 6, 6, 7], validate=False)
  dtype_name = 'int64'
  if not context.executing_eagerly():
    expected_text = ('tf.RowPartition(row_splits='
                     'Tensor("RowPartitionFromRowSplits/row_splits:0", '
                     'shape=(6,), dtype={}))').format(dtype_name)
  else:
    expected_text = ('tf.RowPartition(row_splits=tf.Tensor([0 2 5 6 6 7], '
                     'shape=(6,), dtype=int64))')
  self.assertEqual(repr(partition), expected_text)
  self.assertEqual(str(partition), expected_text)
|
|
|
|
@parameterized.parameters([
    # from_value_rowids
    {
        'descr': 'bad rank for value_rowids',
        'factory': RowPartition.from_value_rowids,
        'value_rowids': [[1, 2], [3, 4]],
        'nrows': 10
    },
    {
        'descr': 'bad rank for nrows',
        'factory': RowPartition.from_value_rowids,
        'value_rowids': [1, 2, 3, 4],
        'nrows': [10]
    },
    {
        'descr': 'negative value_rowid',
        'factory': RowPartition.from_value_rowids,
        'value_rowids': [-5, 2, 3, 4],
        'nrows': 10
    },
    {
        'descr': 'non-monotonic-increasing value_rowid',
        'factory': RowPartition.from_value_rowids,
        'value_rowids': [4, 3, 2, 1],
        'nrows': 10
    },
    {
        'descr': 'value_rowid > nrows',
        'factory': RowPartition.from_value_rowids,
        'value_rowids': [1, 2, 3, 4],
        'nrows': 2
    },

    # from_row_splits
    {
        'descr': 'bad rank for row_splits',
        'factory': RowPartition.from_row_splits,
        'row_splits': [[1, 2], [3, 4]]
    },
    {
        'descr': 'row_splits[0] != 0',
        'factory': RowPartition.from_row_splits,
        'row_splits': [2, 3, 4]
    },
    {
        'descr': 'non-monotonic-increasing row_splits',
        'factory': RowPartition.from_row_splits,
        'row_splits': [0, 3, 2, 4]
    },

    # from_row_lengths
    {
        'descr': 'bad rank for row_lengths',
        'factory': RowPartition.from_row_lengths,
        'row_lengths': [[1, 2], [1, 0]]
    },
    {
        'descr': 'negative row_lengths',
        'factory': RowPartition.from_row_lengths,
        'row_lengths': [3, -1, 2]
    },

    # from_row_starts
    {
        'descr': 'bad rank for row_starts',
        'factory': RowPartition.from_row_starts,
        'nvals': 2,
        'row_starts': [[1, 2], [3, 4]]
    },
    {
        'descr': 'row_starts[0] != 0',
        'factory': RowPartition.from_row_starts,
        'nvals': 5,
        'row_starts': [2, 3, 4]
    },
    {
        'descr': 'non-monotonic-increasing row_starts',
        'factory': RowPartition.from_row_starts,
        'nvals': 4,
        'row_starts': [0, 3, 2, 4]
    },
    {
        # The last start (5) is past nvals (4).
        'descr': 'row_starts[-1] > nvals',
        'factory': RowPartition.from_row_starts,
        'nvals': 4,
        'row_starts': [0, 2, 3, 5]
    },

    # from_row_limits
    {
        'descr': 'bad rank for row_limits',
        'factory': RowPartition.from_row_limits,
        'row_limits': [[1, 2], [3, 4]]
    },
    {
        'descr': 'row_limits[0] < 0',
        'factory': RowPartition.from_row_limits,
        'row_limits': [-1, 3, 4]
    },
    {
        'descr': 'non-monotonic-increasing row_limits',
        'factory': RowPartition.from_row_limits,
        'row_limits': [0, 3, 2, 4]
    },

    # from_uniform_row_length
    {
        'descr': 'rowlen * nrows != nvals (1)',
        'factory': RowPartition.from_uniform_row_length,
        'nvals': 5,
        'uniform_row_length': 3
    },
    {
        'descr': 'rowlen * nrows != nvals (2)',
        'factory': RowPartition.from_uniform_row_length,
        'nvals': 5,
        'uniform_row_length': 6
    },
    {
        'descr': 'rowlen * nrows != nvals (3)',
        'factory': RowPartition.from_uniform_row_length,
        'nvals': 6,
        'uniform_row_length': 3,
        'nrows': 3
    },
    {
        'descr': 'rowlen must be a scalar',
        'factory': RowPartition.from_uniform_row_length,
        'nvals': 4,
        'uniform_row_length': [2]
    },
    {
        'descr': 'rowlen must be nonnegative',
        'factory': RowPartition.from_uniform_row_length,
        'nvals': 4,
        'uniform_row_length': -1
    },
])
def testFactoryValidation(self, descr, factory, **kwargs):
  # Checks that each factory rejects the invalid arguments listed above.
  #
  # `descr` is a human-readable label for the case and is not read by the
  # test body; `factory` is called with the remaining keyword arguments.

  # When input tensors have shape information, some of these errors will be
  # detected statically.
  with self.assertRaises((errors.InvalidArgumentError, ValueError)):
    partition = factory(**kwargs)
    self.evaluate(partition.row_splits())

  # Remove shape information (by wrapping tensors in placeholders), and check
  # that we detect the errors when the graph is run.
  if not context.executing_eagerly():

    def wrap_arg(v):
      return array_ops.placeholder_with_default(
          constant_op.constant(v, dtype=dtypes.int64),
          tensor_shape.TensorShape(None))

    kwargs = {k: wrap_arg(v) for (k, v) in kwargs.items()}

    with self.assertRaises(errors.InvalidArgumentError):
      partition = factory(**kwargs)
      self.evaluate(partition.row_splits())
|
|
|
|
@parameterized.named_parameters([
    ('FromRowSplits',
     lambda: RowPartition.from_row_splits([0, 2, 8]),
     ['row_splits']),
    ('FromRowLengths',
     lambda: RowPartition.from_row_lengths([3, 0, 8]),
     ['row_splits', 'row_lengths']),
    ('FromValueRowIds',
     lambda: RowPartition.from_value_rowids([0, 0, 3, 4, 4, 4]),
     ['row_splits', 'value_rowids', 'row_lengths', 'nrows']),
    ('FromRowStarts',
     lambda: RowPartition.from_row_starts([0, 3, 7], nvals=10),
     ['row_splits']),
    ('FromRowLimits',
     lambda: RowPartition.from_row_limits([3, 7, 10]),
     ['row_splits']),
])
def testPrecomputedSplits(self, rp_factory, expected_encodings):
  # For each encoding, has_precomputed_*() must be True exactly when that
  # encoding's name is listed in `expected_encodings`.
  partition = rp_factory()
  checks = (
      ('row_splits', partition.has_precomputed_row_splits),
      ('row_lengths', partition.has_precomputed_row_lengths),
      ('value_rowids', partition.has_precomputed_value_rowids),
      ('nrows', partition.has_precomputed_nrows),
  )
  for encoding_name, is_precomputed in checks:
    self.assertEqual(is_precomputed(), encoding_name in expected_encodings)
|
|
|
|
def testWithPrecomputedSplits(self):
  # Starting from a row_splits-built partition, each with_precomputed_*()
  # call must return a partition reporting that encoding as precomputed.
  base = RowPartition.from_row_splits([0, 2, 8])

  self.assertTrue(
      base.with_precomputed_row_splits().has_precomputed_row_splits())

  self.assertFalse(base.has_precomputed_row_lengths())
  self.assertTrue(
      base.with_precomputed_row_lengths().has_precomputed_row_lengths())

  self.assertFalse(base.has_precomputed_value_rowids())
  self.assertTrue(
      base.with_precomputed_value_rowids().has_precomputed_value_rowids())

  self.assertFalse(base.has_precomputed_nrows())
  self.assertTrue(base.with_precomputed_nrows().has_precomputed_nrows())
|
|
|
|
@parameterized.named_parameters([
    dict(
        testcase_name='FromRowSplitsAndRowSplits',
        x=lambda: RowPartition.from_row_splits([0, 3, 8]),
        y=lambda: RowPartition.from_row_splits([0, 3, 8]),
        expected_encodings=['row_splits']),
    dict(
        testcase_name='FromRowSplitsAndUniformRowLength',
        x=lambda: RowPartition.from_row_splits([0, 3, 6]),
        y=lambda: RowPartition.from_uniform_row_length(3, nvals=6),
        expected_encodings=['row_splits', 'uniform_row_length', 'nrows']),
    dict(
        testcase_name='FromRowSplitsAndRowLengths',
        x=lambda: RowPartition.from_row_splits([0, 3, 8]),
        y=lambda: RowPartition.from_row_lengths([3, 5]),
        expected_encodings=['row_splits', 'row_lengths']),
    dict(
        testcase_name='FromRowSplitsAndValueRowIds',
        x=lambda: RowPartition.from_row_splits([0, 3, 8]),
        y=lambda: RowPartition.from_value_rowids([0, 0, 0, 1, 1, 1, 1, 1]),
        expected_encodings=[
            'row_splits', 'row_lengths', 'value_rowids', 'nrows'
        ]),
    dict(
        testcase_name='FromRowSplitsAndRowSplitsPlusNRows',
        x=lambda: RowPartition.from_row_splits([0, 3, 8]),
        y=lambda: (RowPartition.from_row_splits([0, 3, 8])
                   .with_precomputed_nrows()),
        expected_encodings=['row_splits', 'nrows']),
])
def testMergePrecomputedEncodings(self, x, y, expected_encodings):
  # `x` and `y` are zero-argument factories. Their partitions are merged
  # with validate=True and validate=False. The merged result must expose
  # exactly the encodings named in `expected_encodings`, and must agree with
  # each input on every encoding they both have.
  lhs = x()
  rhs = y()
  for validate in (True, False):
    merged = lhs.merge_precomputed_encodings(rhs, validate)

    # Which encodings the merged partition reports as available.
    self.assertEqual(merged.has_precomputed_row_splits(),
                     'row_splits' in expected_encodings)
    self.assertEqual(merged.has_precomputed_row_lengths(),
                     'row_lengths' in expected_encodings)
    self.assertEqual(merged.has_precomputed_value_rowids(),
                     'value_rowids' in expected_encodings)
    self.assertEqual(merged.has_precomputed_nrows(),
                     'nrows' in expected_encodings)
    self.assertEqual(merged.uniform_row_length() is not None,
                     'uniform_row_length' in expected_encodings)

    # Values of the encodings shared between each input and the result.
    for source in (lhs, rhs):
      if (source.has_precomputed_row_splits() and
          merged.has_precomputed_row_splits()):
        self.assertAllEqual(source.row_splits(), merged.row_splits())
      if (source.has_precomputed_row_lengths() and
          merged.has_precomputed_row_lengths()):
        self.assertAllEqual(source.row_lengths(), merged.row_lengths())
      if (source.has_precomputed_value_rowids() and
          merged.has_precomputed_value_rowids()):
        self.assertAllEqual(source.value_rowids(), merged.value_rowids())
      if source.has_precomputed_nrows() and merged.has_precomputed_nrows():
        self.assertAllEqual(source.nrows(), merged.nrows())
      if (source.uniform_row_length() is not None and
          merged.uniform_row_length() is not None):
        self.assertAllEqual(source.uniform_row_length(),
                            merged.uniform_row_length())
|
|
|
|
def testMergePrecomputedEncodingsFastPaths(self):
  base = RowPartition.from_row_splits([0, 3, 8, 8])

  # Merging a partition with itself returns that same object.
  self.assertIs(base.merge_precomputed_encodings(base), base)

  # A second partition built from the same row_splits tensor object also
  # merges to the original object.
  twin = RowPartition.from_row_splits(base.row_splits(), validate=False)
  self.assertIs(base.merge_precomputed_encodings(twin), base)
|
|
|
|
def testMergePrecomputedEncodingsWithMatchingTensors(self):
  # `richer` is built from value_rowids; `poorer` is built from the very
  # row_splits tensor that `richer` returns, so where their encodings
  # overlap they are the same tensor objects. Merging in either direction
  # is expected to return `richer` itself.
  richer = RowPartition.from_value_rowids([0, 0, 3, 4, 4, 4])
  poorer = RowPartition.from_row_splits(richer.row_splits(), validate=False)
  self.assertIs(richer.merge_precomputed_encodings(poorer), richer)
  self.assertIs(poorer.merge_precomputed_encodings(richer), richer)
  self.assertIsNot(richer, poorer)
|
|
|
|
@parameterized.named_parameters([
    dict(
        testcase_name='RowSplitMismatch',
        x=lambda: RowPartition.from_row_splits([0, 3, 8]),
        y=lambda: RowPartition.from_row_splits([0, 3, 8, 9]),
        message='incompatible row_splits'),
    dict(
        testcase_name='RowLengthMismatch',
        x=lambda: RowPartition.from_row_lengths([2, 0, 2]),
        y=lambda: RowPartition.from_row_lengths([2, 0, 2, 1]),
        message='incompatible row_splits'),  # row_splits is checked first
    dict(
        testcase_name='ValueRowIdMismatch',
        x=lambda: RowPartition.from_value_rowids([0, 3, 3, 4]),
        y=lambda: RowPartition.from_value_rowids([0, 3, 4]),
        message='incompatible value_rowids'),
])
def testMergePrecomputedEncodingStaticErrors(self, x, y, message):
  # Errors that are caught by static shape checks. Nothing is checked when
  # executing eagerly.
  if context.executing_eagerly():
    return
  lhs = x()
  rhs = y()
  # The merge must fail in both directions with the same message.
  for first, second in ((lhs, rhs), (rhs, lhs)):
    with self.assertRaisesRegex(ValueError, message):
      first.merge_precomputed_encodings(second).row_splits()
|
|
|
|
@parameterized.named_parameters([
    dict(
        testcase_name='NRowsMismatch',
        x=lambda: RowPartition.from_uniform_row_length(5, nvals=20),
        y=lambda: RowPartition.from_uniform_row_length(5, nvals=15),
        message='incompatible nrows'),
    dict(
        testcase_name='UniformRowLengthMismatch',
        x=lambda: RowPartition.from_uniform_row_length(5, nvals=20),
        y=lambda: RowPartition.from_uniform_row_length(2, nvals=8),
        message='incompatible uniform_row_length'),
    dict(
        testcase_name='RowSplitMismatch',
        x=lambda: RowPartition.from_row_splits([0, 3, 8]),
        y=lambda: RowPartition.from_row_splits([0, 5, 8]),
        message='incompatible row_splits'),
    dict(
        testcase_name='RowLengthMismatch',
        x=lambda: RowPartition.from_row_lengths([2, 0, 2]),
        y=lambda: RowPartition.from_row_lengths([0, 0, 2]),
        message='incompatible row_splits'),  # row_splits is checked first
    dict(
        testcase_name='ValueRowIdMismatch',
        x=lambda: RowPartition.from_value_rowids([0, 3, 3]),
        y=lambda: RowPartition.from_value_rowids([0, 0, 3]),
        message='incompatible row_splits'),  # row_splits is checked first
])
def testMergePrecomputedEncodingRuntimeErrors(self, x, y, message):
  # Errors that are caught by runtime value checks: evaluating the merged
  # row_splits must fail in both merge directions.
  lhs = x()
  rhs = y()
  for first, second in ((lhs, rhs), (rhs, lhs)):
    with self.assertRaisesRegex(errors.InvalidArgumentError, message):
      self.evaluate(first.merge_precomputed_encodings(second).row_splits())
|
|
|
|
|
|
@test_util.run_all_in_graph_and_eager_modes
|
|
class RowPartitionSpecTest(test_util.TensorFlowTestCase,
|
|
parameterized.TestCase):
|
|
|
|
def testDefaultConstruction(self):
  # A spec built with no arguments has no static nrows, nvals or
  # uniform_row_length, and uses int64 as its dtype.
  spec = RowPartitionSpec()
  # assertIsNone states the intent directly, instead of comparing to None
  # with `==`.
  self.assertIsNone(spec.nrows)
  self.assertIsNone(spec.nvals)
  self.assertIsNone(spec.uniform_row_length)
  self.assertEqual(spec.dtype, dtypes.int64)
|
|
|
|
@parameterized.parameters([
    (None, None, None, dtypes.int64, None, None, None, dtypes.int64),
    (5, None, None, dtypes.int32, 5, None, None, dtypes.int32),
    (None, 20, None, dtypes.int64, None, 20, None, dtypes.int64),
    (None, None, 8, dtypes.int64, None, None, 8, dtypes.int64),
    (5, None, 8, dtypes.int64, 5, 40, 8, dtypes.int64),  # nvals inferred
    (None, 20, 4, dtypes.int32, 5, 20, 4, dtypes.int32),  # nrows inferred
    (0, None, None, dtypes.int32, 0, 0, None, dtypes.int32),  # nvals inferred
    (None, None, 0, dtypes.int32, None, 0, 0, dtypes.int32),  # nvals inferred
])  # pyformat: disable
def testConstruction(self, nrows, nvals, uniform_row_length, dtype,
                     expected_nrows, expected_nvals,
                     expected_uniform_row_length, expected_dtype):
  # Each parameter tuple holds the four constructor arguments followed by
  # the four property values expected on the resulting spec.
  spec = RowPartitionSpec(nrows, nvals, uniform_row_length, dtype)
  property_checks = (
      (spec.nrows, expected_nrows),
      (spec.nvals, expected_nvals),
      (spec.uniform_row_length, expected_uniform_row_length),
      (spec.dtype, expected_dtype),
  )
  for actual_value, expected_value in property_checks:
    self.assertEqual(actual_value, expected_value)
|
|
|
|
@parameterized.parameters([
    dict(dtype=dtypes.float32, error='dtype must be tf.int32 or tf.int64'),
    dict(nrows=0, nvals=5, error='.* not compatible .*'),
    dict(uniform_row_length=0, nvals=5, error='.* not compatible .*'),
    dict(nvals=11, uniform_row_length=5, error='.* not compatible .*'),
    dict(
        nrows=8,
        nvals=10,
        uniform_row_length=5,
        error='.* not compatible .*'),
])
def testConstructionError(self,
                          nrows=None,
                          nvals=None,
                          uniform_row_length=None,
                          dtype=dtypes.int64,
                          error=None):
  # Constructing a spec from each argument set above must raise a
  # ValueError whose message matches the `error` regex.
  with self.assertRaisesRegex(ValueError, error):
    RowPartitionSpec(nrows, nvals, uniform_row_length, dtype)
|
|
|
|
def testValueType(self):
  # The spec's value_type is the RowPartition class.
  default_spec = RowPartitionSpec()
  self.assertEqual(default_spec.value_type, RowPartition)
|
|
|
|
@parameterized.parameters([
    dict(
        spec=RowPartitionSpec(),
        expected=(tensor_shape.TensorShape([None]),
                  tensor_shape.TensorShape([None]),
                  tensor_shape.TensorShape([None]),
                  dtypes.int64)),
    dict(
        spec=RowPartitionSpec(dtype=dtypes.int32),
        expected=(tensor_shape.TensorShape([None]),
                  tensor_shape.TensorShape([None]),
                  tensor_shape.TensorShape([None]),
                  dtypes.int32)),
    dict(
        spec=RowPartitionSpec(nrows=8, nvals=13),
        expected=(tensor_shape.TensorShape([8]),
                  tensor_shape.TensorShape([13]),
                  tensor_shape.TensorShape([None]),
                  dtypes.int64)),
    dict(
        spec=RowPartitionSpec(nrows=8, uniform_row_length=2),
        expected=(
            tensor_shape.TensorShape([8]),
            tensor_shape.TensorShape([16]),  # inferred
            tensor_shape.TensorShape([2]),
            dtypes.int64)),
])
def testSerialize(self, spec, expected):
  # TensorShape has an unconventional definition of equality, so we can't use
  # assertEqual directly here.  But repr() is deterministic and lossless for
  # the expected values, so we can use that instead.
  actual_repr = repr(spec._serialize())
  self.assertEqual(actual_repr, repr(expected))
|
|
|
|
@parameterized.parameters([
    dict(
        spec=RowPartitionSpec(),
        expected=tensor_spec.TensorSpec([None], dtypes.int64)),
    dict(
        spec=RowPartitionSpec(dtype=dtypes.int32),
        expected=tensor_spec.TensorSpec([None], dtypes.int32)),
    dict(
        spec=RowPartitionSpec(nrows=17, dtype=dtypes.int32),
        expected=tensor_spec.TensorSpec([18], dtypes.int32)),
    dict(
        spec=RowPartitionSpec(nvals=10, uniform_row_length=2),
        expected=tensor_spec.TensorSpec([6], dtypes.int64)),  # inferred nrow
])
def testComponentSpecs(self, spec, expected):
  # The component spec is a single TensorSpec. In the cases above with a
  # known row count, its length is that count plus one.
  component_specs = spec._component_specs
  self.assertEqual(component_specs, expected)
|
|
|
|
@parameterized.parameters([
    dict(
        rp_factory=lambda: RowPartition.from_row_splits([0, 3, 7]),
        components=[0, 3, 7]),
])
def testToFromComponents(self, rp_factory, components):
  # Round trip: partition -> components -> partition.
  original = rp_factory()
  type_spec = original._type_spec
  extracted = type_spec._to_components(original)
  self.assertAllEqual(extracted, components)
  rebuilt = type_spec._from_components(extracted)
  _assert_row_partition_equal(self, original, rebuilt)
|
|
|
|
@parameterized.parameters([
    (RowPartitionSpec(), RowPartitionSpec()),
    (RowPartitionSpec(nrows=8), RowPartitionSpec(nrows=8)),
    (RowPartitionSpec(nrows=8), RowPartitionSpec(nrows=None)),
    (RowPartitionSpec(nvals=8), RowPartitionSpec(nvals=8)),
    (RowPartitionSpec(nvals=8), RowPartitionSpec(nvals=None)),
    (RowPartitionSpec(uniform_row_length=8),
     RowPartitionSpec(uniform_row_length=8)),
    (RowPartitionSpec(uniform_row_length=8),
     RowPartitionSpec(uniform_row_length=None)),
    (RowPartitionSpec(nvals=12), RowPartitionSpec(uniform_row_length=3)),
    (RowPartitionSpec(nrows=12), RowPartitionSpec(uniform_row_length=72)),
    (RowPartitionSpec(nrows=5), RowPartitionSpec(nvals=15)),
    (RowPartitionSpec(nvals=0), RowPartitionSpec(nrows=0)),
    (RowPartitionSpec(nvals=0), RowPartitionSpec(uniform_row_length=0)),
])
def testIsCompatibleWith(self, spec1, spec2):
  # Every pair listed above must be reported as compatible.
  compatible = spec1.is_compatible_with(spec2)
  self.assertTrue(compatible)
|
|
|
|
@parameterized.parameters([
    (RowPartitionSpec(), RowPartitionSpec(dtype=dtypes.int32)),
    (RowPartitionSpec(nvals=5), RowPartitionSpec(uniform_row_length=3)),
    (RowPartitionSpec(nrows=7, nvals=12),
     RowPartitionSpec(uniform_row_length=3)),
    (RowPartitionSpec(nvals=5), RowPartitionSpec(nrows=0)),
    (RowPartitionSpec(nvals=5), RowPartitionSpec(uniform_row_length=0)),
])
def testIsNotCompatibleWith(self, spec1, spec2):
  # Every pair listed above must be reported as incompatible.
  compatible = spec1.is_compatible_with(spec2)
  self.assertFalse(compatible)
|
|
|
|
@parameterized.parameters([
    dict(
        spec1=RowPartitionSpec(nrows=8, nvals=3, dtype=dtypes.int32),
        spec2=RowPartitionSpec(nrows=8, nvals=3, dtype=dtypes.int32),
        expected=RowPartitionSpec(nrows=8, nvals=3, dtype=dtypes.int32)),
    dict(
        spec1=RowPartitionSpec(nrows=8, nvals=None),
        spec2=RowPartitionSpec(nrows=None, nvals=8),
        expected=RowPartitionSpec(nrows=None, nvals=None)),
    dict(
        spec1=RowPartitionSpec(nrows=8, nvals=33),
        spec2=RowPartitionSpec(nrows=3, nvals=13),
        expected=RowPartitionSpec(nrows=None, nvals=None)),
    dict(
        spec1=RowPartitionSpec(nrows=12, uniform_row_length=3),
        spec2=RowPartitionSpec(nrows=3, uniform_row_length=3),
        expected=RowPartitionSpec(nrows=None, uniform_row_length=3)),
    dict(
        spec1=RowPartitionSpec(5, 35, 7),
        spec2=RowPartitionSpec(8, 80, 10),
        expected=RowPartitionSpec(None, None, None)),
])
def testMostSpecificCompatibleType(self, spec1, spec2, expected):
  # In the cases above, fields on which the two specs differ are expected
  # to come back as None, and fields on which they agree are kept.
  merged_spec = spec1.most_specific_compatible_type(spec2)
  self.assertEqual(merged_spec, expected)
|
|
|
|
@parameterized.parameters([
    (RowPartitionSpec(), RowPartitionSpec(dtype=dtypes.int32)),
])
def testMostSpecificCompatibleTypeError(self, spec1, spec2):
  # The two specs in the case above differ in dtype; merging them must
  # raise.
  with self.assertRaisesRegex(ValueError, 'not compatible'):
    spec1.most_specific_compatible_type(spec2)
|
|
|
|
def testFromValue(self):
  # Pairs of (partition, spec expected from RowPartitionSpec.from_value).
  cases = [
      (RowPartition.from_row_splits([0, 2, 8, 8]),
       RowPartitionSpec(nrows=3)),
      (RowPartition.from_row_lengths([5, 3, 0, 2]),
       RowPartitionSpec(nrows=4)),
      (RowPartition.from_value_rowids([0, 2, 2, 8]),
       RowPartitionSpec(nrows=9, nvals=4)),
      (RowPartition.from_uniform_row_length(nvals=12, uniform_row_length=3),
       RowPartitionSpec(nvals=12, uniform_row_length=3)),
  ]
  for partition, expected_spec in cases:
    self.assertEqual(RowPartitionSpec.from_value(partition), expected_spec)
|
|
|
|
|
|
def _assert_row_partition_equal(test_class, actual, expected):
  """Asserts that two RowPartitions have matching precomputed encodings.

  First checks that `actual` and `expected` agree on which encodings are
  precomputed, and on whether uniform_row_length() is None. Then, for every
  encoding that `expected` has, checks that the values are equal.

  Args:
    test_class: The TensorFlowTestCase whose assert methods are used.
    actual: The RowPartition under test.
    expected: The RowPartition to compare against.
  """
  assert isinstance(test_class, test_util.TensorFlowTestCase)
  assert isinstance(actual, RowPartition)
  assert isinstance(expected, RowPartition)

  # (presence-check method name, accessor method name) for each encoding.
  encodings = (
      ('has_precomputed_row_splits', 'row_splits'),
      ('has_precomputed_row_lengths', 'row_lengths'),
      ('has_precomputed_value_rowids', 'value_rowids'),
      ('has_precomputed_nrows', 'nrows'),
  )

  # Both partitions must report the same set of available encodings.
  for has_name, _ in encodings:
    test_class.assertEqual(
        getattr(actual, has_name)(), getattr(expected, has_name)())
  test_class.assertEqual(actual.uniform_row_length() is None,
                         expected.uniform_row_length() is None)

  # Compare values only for the encodings that `expected` has.
  for has_name, accessor_name in encodings:
    if getattr(expected, has_name)():
      test_class.assertAllEqual(
          getattr(actual, accessor_name)(),
          getattr(expected, accessor_name)())
  if expected.uniform_row_length() is not None:
    test_class.assertAllEqual(actual.uniform_row_length(),
                              expected.uniform_row_length())
|
|
|
|
|
|
# Run the tests in this file when it is executed as a script.
if __name__ == '__main__':
  googletest.main()
|