* Add support for explicit management of precomputed encodings.
PiperOrigin-RevId: 300557760 Change-Id: I21a5bff5bec95d5fef20f9a6f292f54e650d0670
parent da640e4e0c
commit b88330d404
@@ -102,7 +102,8 @@ class RowPartition(composite_tensor.CompositeTensor):
  avoid unnecessary recomputation in eager mode. (In graph mode, optimizations
  such as common subexpression elimination will typically prevent these
  unnecessary recomputations.) To check which encodings are precomputed, use
  `RowPartition.has_precomputed_<encoding>`. To cache an additional
  encoding, use `RowPartition.with_precomputed_<encoding>`.
  """

  #=============================================================================
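For reference, a minimal usage sketch of the encoding-management API described in the docstring above (illustrative values; the with_precomputed_* methods are the ones added by this commit, has_precomputed_* already exists):

    rp = RowPartition.from_row_splits([0, 2, 8])
    rp.has_precomputed_row_lengths()         # False: only row_splits is cached.
    rp2 = rp.with_precomputed_row_lengths()  # Copy with row_lengths also cached.
    rp2.has_precomputed_row_lengths()        # True; rp itself is unchanged.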
@@ -897,6 +898,115 @@ class RowPartition(composite_tensor.CompositeTensor):
    """
    return self._nrows is not None

  def with_precomputed_row_splits(self):
    """Returns a copy of `self` with `row_splits` precomputed."""
    return RowPartition(
        row_splits=self.row_splits(),
        row_lengths=self._row_lengths,
        value_rowids=self._value_rowids,
        nrows=self._nrows,
        uniform_row_length=self._uniform_row_length,
        internal=_row_partition_factory_key)

  def with_precomputed_row_lengths(self):
    """Returns a copy of `self` with `row_lengths` precomputed."""
    return RowPartition(
        row_splits=self._row_splits,
        row_lengths=self.row_lengths(),
        value_rowids=self._value_rowids,
        nrows=self._nrows,
        uniform_row_length=self._uniform_row_length,
        internal=_row_partition_factory_key)

  def with_precomputed_value_rowids(self):
    """Returns a copy of `self` with `value_rowids` precomputed."""
    return RowPartition(
        row_splits=self._row_splits,
        row_lengths=self._row_lengths,
        value_rowids=self.value_rowids(),
        nrows=self._nrows,
        uniform_row_length=self._uniform_row_length,
        internal=_row_partition_factory_key)

  def with_precomputed_nrows(self):
    """Returns a copy of `self` with `nrows` precomputed."""
    return RowPartition(
        row_splits=self._row_splits,
        row_lengths=self._row_lengths,
        value_rowids=self._value_rowids,
        nrows=self.nrows(),
        uniform_row_length=self._uniform_row_length,
        internal=_row_partition_factory_key)

  def merge_precomputed_encodings(self, other, validate=True):
    """Returns a RowPartition that merges encodings from `self` and `other`.

    Requires that `self` and `other` describe the same partition.

    Args:
      other: A `RowPartition` that encodes the same partition as `self`.
      validate: If true, then add runtime checks to verify that `self` and
        `other` encode the same row partition.

    Returns:
      A `RowPartition`.
    """
    # pylint: disable=protected-access
    if (self is other or  # Fast path if row partitions are equal.
        (self._row_splits is other._row_splits and
         self._row_lengths is other._row_lengths and
         self._value_rowids is other._value_rowids and
         self._nrows is other._nrows and
         self._uniform_row_length is other._uniform_row_length)):
      return self

    # Merge the component tensors.  We only need to validate one encoding.
    # We merge less-expensive encodings first (to avoid expensive validation).
    nrows, nrows_validated = _merge_tensors(self._nrows, other._nrows, "nrows",
                                            validate)
    uniform_row_length, uniform_row_length_validated = _merge_tensors(
        self._uniform_row_length, other._uniform_row_length,
        "uniform_row_length", validate)
    if uniform_row_length_validated and nrows_validated:
      validate = False  # Validation complete.
    row_splits, row_splits_validated = _merge_tensors(self._row_splits,
                                                      other._row_splits,
                                                      "row_splits", validate)
    if row_splits_validated:
      validate = False  # Validation complete.
    row_lengths, row_lengths_validated = _merge_tensors(self._row_lengths,
                                                        other._row_lengths,
                                                        "row_lengths", validate)
    if row_lengths_validated:
      validate = False  # Validation complete.
    value_rowids, value_rowids_validated = _merge_tensors(
        self._value_rowids, other._value_rowids, "value_rowids", validate)
    if value_rowids_validated and nrows_validated:
      validate = False  # Validation complete.
    # TODO(edloper): If we make the row_splits encoding optional, then there
    # will be cases where we need to do validation at this point -- e.g. if
    # self has only row_splits and other has only value_rowids.  But for
    # now, we are guaranteed to have done validation by this point.

    # Avoid creating new RowPartition objects if we don't need to.
    if (row_splits is self._row_splits and row_lengths is self._row_lengths and
        value_rowids is self._value_rowids and nrows is self._nrows and
        uniform_row_length is self._uniform_row_length):
      return self
    if (row_splits is other._row_splits and
        row_lengths is other._row_lengths and
        value_rowids is other._value_rowids and nrows is other._nrows and
        uniform_row_length is other._uniform_row_length):
      return other

    return RowPartition(
        row_splits=row_splits,
        row_lengths=row_lengths,
        value_rowids=value_rowids,
        nrows=nrows,
        uniform_row_length=uniform_row_length,
        internal=_row_partition_factory_key)

  #=============================================================================
  # Composite Tensor
  #=============================================================================
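A brief sketch of how the merge behaves, using illustrative values consistent with the tests below (both partitions must describe the same rows):

    a = RowPartition.from_row_splits([0, 3, 8])   # caches row_splits only
    b = RowPartition.from_row_lengths([3, 5])     # caches row_lengths (and row_splits); same partition
    merged = a.merge_precomputed_encodings(b)     # validate=True adds runtime equality checks
    merged.has_precomputed_row_splits()           # True
    merged.has_precomputed_row_lengths()          # True
    a.merge_precomputed_encodings(b, validate=False)  # same merge, without the runtime checks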
@@ -1076,4 +1186,38 @@ def _cast_if_not_none(tensor, dtype):
  return None if tensor is None else math_ops.cast(tensor, dtype)


def _merge_tensors(t1, t2, name, validate):
  """Merge two optional Tensors with equal values into a single Tensor.

  Args:
    t1: tf.Tensor or None
    t2: tf.Tensor or None
    name: A name for the tensors (for error messages)
    validate: If true, then check that `t1` is compatible with `t2` (if both are
      non-None).

  Returns:
    A pair `(merged_value, validated)`:
      * `merged_value` is `t1` if it is not None; or `t2` otherwise.
      * `validated` is true if we validated that t1 and t2 are equal (either
        by adding a check, or because t1 is t2).
  """
  if t1 is None:
    return t2, False
  elif t2 is None:
    return t1, False
  elif t1 is t2:
    return t1, True
  else:
err_msg = ("RowPartition.merge_precomuted_encodings: partitons "
|
||||
"have incompatible %s" % name)
    if not t1.shape.is_compatible_with(t2.shape):
      raise ValueError(err_msg)
    if validate:
      checks = [check_ops.assert_equal(t1, t2, message=err_msg)]
      return control_flow_ops.with_dependencies(checks, t1), True
    else:
      return t1, False


_row_partition_factory_key = object()  # unique private object

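An illustrative sketch of the _merge_tensors contract (hypothetical values; assumes direct access to this private helper in row_partition.py and the usual constant_op / dtypes imports):

    t = constant_op.constant([0, 2, 8], dtypes.int64)
    _merge_tensors(t, None, "row_splits", validate=True)  # -> (t, False): nothing to compare
    _merge_tensors(t, t, "row_splits", validate=True)     # -> (t, True): same object, no check needed
    u = constant_op.constant([0, 2, 8], dtypes.int64)
    _merge_tensors(t, u, "row_splits", validate=True)     # -> (t gated on an assert_equal check, True)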
@@ -481,6 +481,192 @@ class RowPartitionTest(test_util.TensorFlowTestCase, parameterized.TestCase):
    partition = factory(**kwargs)
    self.evaluate(partition.row_splits())

  @parameterized.named_parameters([
      ('FromRowSplits', lambda: RowPartition.from_row_splits([0, 2, 8]),
       ['row_splits']),
      ('FromRowLengths', lambda: RowPartition.from_row_lengths([3, 0, 8]),
       ['row_splits', 'row_lengths']),
      ('FromValueRowIds',
       lambda: RowPartition.from_value_rowids([0, 0, 3, 4, 4, 4]),
       ['row_splits', 'value_rowids', 'row_lengths', 'nrows']),
      ('FromRowStarts',
       lambda: RowPartition.from_row_starts([0, 3, 7], nvals=10),
       ['row_splits']),
      ('FromRowLimits', lambda: RowPartition.from_row_limits([3, 7, 10]),
       ['row_splits']),
  ])
  def testPrecomputedSplits(self, rp_factory, expected_encodings):
    rp = rp_factory()
    self.assertEqual(rp.has_precomputed_row_splits(),
                     'row_splits' in expected_encodings)
    self.assertEqual(rp.has_precomputed_row_lengths(),
                     'row_lengths' in expected_encodings)
    self.assertEqual(rp.has_precomputed_value_rowids(),
                     'value_rowids' in expected_encodings)
    self.assertEqual(rp.has_precomputed_nrows(), 'nrows' in expected_encodings)

  def testWithPrecomputedSplits(self):
    rp = RowPartition.from_row_splits([0, 2, 8])

    rp_with_row_splits = rp.with_precomputed_row_splits()
    self.assertTrue(rp_with_row_splits.has_precomputed_row_splits())

    self.assertFalse(rp.has_precomputed_row_lengths())
    rp_with_row_lengths = rp.with_precomputed_row_lengths()
    self.assertTrue(rp_with_row_lengths.has_precomputed_row_lengths())

    self.assertFalse(rp.has_precomputed_value_rowids())
    rp_with_value_rowids = rp.with_precomputed_value_rowids()
    self.assertTrue(rp_with_value_rowids.has_precomputed_value_rowids())

    self.assertFalse(rp.has_precomputed_nrows())
    rp_with_nrows = rp.with_precomputed_nrows()
    self.assertTrue(rp_with_nrows.has_precomputed_nrows())

  @parameterized.named_parameters([
      dict(
          testcase_name='FromRowSplitsAndRowSplits',
          x=lambda: RowPartition.from_row_splits([0, 3, 8]),
          y=lambda: RowPartition.from_row_splits([0, 3, 8]),
          expected_encodings=['row_splits']),
      dict(
          testcase_name='FromRowSplitsAndUniformRowLength',
          x=lambda: RowPartition.from_row_splits([0, 3, 6]),
          y=lambda: RowPartition.from_uniform_row_length(3, nvals=6),
          expected_encodings=['row_splits', 'uniform_row_length', 'nrows']),
      dict(
          testcase_name='FromRowSplitsAndRowLengths',
          x=lambda: RowPartition.from_row_splits([0, 3, 8]),
          y=lambda: RowPartition.from_row_lengths([3, 5]),
          expected_encodings=['row_splits', 'row_lengths']),
      dict(
          testcase_name='FromRowSplitsAndValueRowIds',
          x=lambda: RowPartition.from_row_splits([0, 3, 8]),
          y=lambda: RowPartition.from_value_rowids([0, 0, 0, 1, 1, 1, 1, 1]),
          expected_encodings=[
              'row_splits', 'row_lengths', 'value_rowids', 'nrows'
          ]),
      dict(
          testcase_name='FromRowSplitsAndRowSplitsPlusNRows',
          x=lambda: RowPartition.from_row_splits([0, 3, 8]),
          y=lambda: RowPartition.from_row_splits([0, 3, 8]).
          with_precomputed_nrows(),
          expected_encodings=['row_splits', 'nrows']),
  ])
  def testMergePrecomputedEncodings(self, x, y, expected_encodings):
    x = x()
    y = y()
    for validate in (True, False):
      result = x.merge_precomputed_encodings(y, validate)
      self.assertEqual(result.has_precomputed_row_splits(),
                       'row_splits' in expected_encodings)
      self.assertEqual(result.has_precomputed_row_lengths(),
                       'row_lengths' in expected_encodings)
      self.assertEqual(result.has_precomputed_value_rowids(),
                       'value_rowids' in expected_encodings)
      self.assertEqual(result.has_precomputed_nrows(),
                       'nrows' in expected_encodings)
      self.assertEqual(result.uniform_row_length() is not None,
                       'uniform_row_length' in expected_encodings)
      for r in (x, y):
        if (r.has_precomputed_row_splits() and
            result.has_precomputed_row_splits()):
          self.assertAllEqual(r.row_splits(), result.row_splits())
        if (r.has_precomputed_row_lengths() and
            result.has_precomputed_row_lengths()):
          self.assertAllEqual(r.row_lengths(), result.row_lengths())
        if (r.has_precomputed_value_rowids() and
            result.has_precomputed_value_rowids()):
          self.assertAllEqual(r.value_rowids(), result.value_rowids())
        if r.has_precomputed_nrows() and result.has_precomputed_nrows():
          self.assertAllEqual(r.nrows(), result.nrows())
        if (r.uniform_row_length() is not None and
            result.uniform_row_length() is not None):
          self.assertAllEqual(r.uniform_row_length(),
                              result.uniform_row_length())

  def testMergePrecomputedEncodingsFastPaths(self):
    # Same object: x gets returned as-is.
    x = RowPartition.from_row_splits([0, 3, 8, 8])
    self.assertIs(x.merge_precomputed_encodings(x), x)

    # Same encoding tensor objects: x gets returned as-is.
    y = RowPartition.from_row_splits(x.row_splits(), validate=False)
    self.assertIs(x.merge_precomputed_encodings(y), x)

  def testMergePrecomputedEncodingsWithMatchingTensors(self):
    # The encoding tensors for `a` are a superset of the encoding tensors
    # for `b`, and where they overlap, they are the same tensor objects.
    a = RowPartition.from_value_rowids([0, 0, 3, 4, 4, 4])
    b = RowPartition.from_row_splits(a.row_splits(), validate=False)
    self.assertIs(a.merge_precomputed_encodings(b), a)
    self.assertIs(b.merge_precomputed_encodings(a), a)
    self.assertIsNot(a, b)

  @parameterized.named_parameters([
      dict(
          testcase_name='RowSplitMismatch',
          x=lambda: RowPartition.from_row_splits([0, 3, 8]),
          y=lambda: RowPartition.from_row_splits([0, 3, 8, 9]),
          message='incompatible row_splits'),
      dict(
          testcase_name='RowLengthMismatch',
          x=lambda: RowPartition.from_row_lengths([2, 0, 2]),
          y=lambda: RowPartition.from_row_lengths([2, 0, 2, 1]),
          message='incompatible row_splits'),  # row_splits is checked first
      dict(
          testcase_name='ValueRowIdMismatch',
          x=lambda: RowPartition.from_value_rowids([0, 3, 3, 4]),
          y=lambda: RowPartition.from_value_rowids([0, 3, 4]),
          message='incompatible value_rowids'),
  ])
  def testMergePrecomputedEncodingStaticErrors(self, x, y, message):
    if context.executing_eagerly():
      return
    # Errors that are caught by static shape checks.
    x = x()
    y = y()
    with self.assertRaisesRegexp(ValueError, message):
      x.merge_precomputed_encodings(y).row_splits()
    with self.assertRaisesRegexp(ValueError, message):
      y.merge_precomputed_encodings(x).row_splits()

  @parameterized.named_parameters([
      dict(
          testcase_name='NRowsMismatch',
          x=lambda: RowPartition.from_uniform_row_length(5, nvals=20),
          y=lambda: RowPartition.from_uniform_row_length(5, nvals=15),
          message='incompatible nrows'),
      dict(
          testcase_name='UniformRowLengthMismatch',
          x=lambda: RowPartition.from_uniform_row_length(5, nvals=20),
          y=lambda: RowPartition.from_uniform_row_length(2, nvals=8),
          message='incompatible uniform_row_length'),
      dict(
          testcase_name='RowSplitMismatch',
          x=lambda: RowPartition.from_row_splits([0, 3, 8]),
          y=lambda: RowPartition.from_row_splits([0, 5, 8]),
          message='incompatible row_splits'),
      dict(
          testcase_name='RowLengthMismatch',
          x=lambda: RowPartition.from_row_lengths([2, 0, 2]),
          y=lambda: RowPartition.from_row_lengths([0, 0, 2]),
          message='incompatible row_splits'),  # row_splits is checked first
      dict(
          testcase_name='ValueRowIdMismatch',
          x=lambda: RowPartition.from_value_rowids([0, 3, 3]),
          y=lambda: RowPartition.from_value_rowids([0, 0, 3]),
          message='incompatible row_splits'),  # row_splits is checked first
  ])
  def testMergePrecomputedEncodingRuntimeErrors(self, x, y, message):
    # Errors that are caught by runtime value checks.
    x = x()
    y = y()
    with self.assertRaisesRegexp(errors.InvalidArgumentError, message):
      self.evaluate(x.merge_precomputed_encodings(y).row_splits())
    with self.assertRaisesRegexp(errors.InvalidArgumentError, message):
      self.evaluate(y.merge_precomputed_encodings(x).row_splits())


@test_util.run_all_in_graph_and_eager_modes
class RowPartitionSpecTest(test_util.TensorFlowTestCase,