* Add support for explicit management of precomputed encodings.

PiperOrigin-RevId: 300557760
Change-Id: I21a5bff5bec95d5fef20f9a6f292f54e650d0670
Edward Loper 2020-03-12 08:55:11 -07:00 committed by TensorFlower Gardener
parent da640e4e0c
commit b88330d404
2 changed files with 331 additions and 1 deletion


@@ -102,7 +102,8 @@ class RowPartition(composite_tensor.CompositeTensor):
avoid unnecessary recomputation in eager mode. (In graph mode, optimizations
such as common subexpression elimination will typically prevent these
unnecessary recomputations.) To check which encodings are precomputed, use
`RowPartition.has_precomputed_<encoding>`.
`RowPartition.has_precomputed_<encoding>`. To cache an additional
encoding, use `RowPartition.with_precomputed_<encoding>`.
"""
#=============================================================================
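A minimal usage sketch of the two accessor families (illustrative only; assumes `RowPartition` is imported from this module and eager execution):

rp = RowPartition.from_row_lengths([2, 0, 3])
rp.has_precomputed_row_splits()     # True: from_row_lengths also builds row_splits.
rp.has_precomputed_value_rowids()   # False: value_rowids has not been computed.
rp2 = rp.with_precomputed_value_rowids()
rp2.has_precomputed_value_rowids()  # True: the copy caches the extra encoding.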
@@ -897,6 +898,115 @@ class RowPartition(composite_tensor.CompositeTensor):
"""
return self._nrows is not None
def with_precomputed_row_splits(self):
"""Returns a copy of `self` with `row_splits` precomputed."""
return RowPartition(
row_splits=self.row_splits(),
row_lengths=self._row_lengths,
value_rowids=self._value_rowids,
nrows=self._nrows,
uniform_row_length=self._uniform_row_length,
internal=_row_partition_factory_key)
def with_precomputed_row_lengths(self):
"""Returns a copy of `self` with `row_lengths` precomputed."""
return RowPartition(
row_splits=self._row_splits,
row_lengths=self.row_lengths(),
value_rowids=self._value_rowids,
nrows=self._nrows,
uniform_row_length=self._uniform_row_length,
internal=_row_partition_factory_key)
def with_precomputed_value_rowids(self):
"""Returns a copy of `self` with `value_rowids` precomputed."""
return RowPartition(
row_splits=self._row_splits,
row_lengths=self._row_lengths,
value_rowids=self.value_rowids(),
nrows=self._nrows,
uniform_row_length=self._uniform_row_length,
internal=_row_partition_factory_key)
def with_precomputed_nrows(self):
"""Returns a copy of `self` with `nrows` precomputed."""
return RowPartition(
row_splits=self._row_splits,
row_lengths=self._row_lengths,
value_rowids=self._value_rowids,
nrows=self.nrows(),
uniform_row_length=self._uniform_row_length,
internal=_row_partition_factory_key)
def merge_precomputed_encodings(self, other, validate=True):
"""Returns a RowPartition that merges encodings from `self` and `other`.
Requires that `self` and `other` describe the same partition.
Args:
other: A `RowPartition` that encodes the same partition as `self`.
validate: If true, then add runtime checks to verify that `self` and
`other` encode the same row partition.
Returns:
A `RowPartition`.
"""
# pylint: disable=protected-access
if (self is other or # Fast path if row partitions are equal.
(self._row_splits is other._row_splits and
self._row_lengths is other._row_lengths and
self._value_rowids is other._value_rowids and
self._nrows is other._nrows and
self._uniform_row_length is other._uniform_row_length)):
return self
# Merge the component tensors. We only need to validate one encoding.
# We merge less-expensive encodings first (to avoid expensive validation).
nrows, nrows_validated = _merge_tensors(self._nrows, other._nrows, "nrows",
validate)
uniform_row_length, uniform_row_length_validated = _merge_tensors(
self._uniform_row_length, other._uniform_row_length,
"uniform_row_length", validate)
if uniform_row_length_validated and nrows_validated:
validate = False # Validation complete.
row_splits, row_splits_validated = _merge_tensors(self._row_splits,
other._row_splits,
"row_splits", validate)
if row_splits_validated:
validate = False # Validation complete.
row_lengths, row_lengths_validated = _merge_tensors(self._row_lengths,
other._row_lengths,
"row_lengths", validate)
if row_lengths_validated:
validate = False # Validation complete.
value_rowids, value_rowids_validated = _merge_tensors(
self._value_rowids, other._value_rowids, "value_rowids", validate)
if value_rowids_validated and nrows_validated:
validate = False # Validation complete.
# TODO(edloper): If we make the row_splits encoding optional, then there
# will be cases where we need to do validation at this point -- e.g. if
# self has only row_splits and other has only value_rowids. But for
# now, we are guaranteed to have done validation by this point.
# Avoid creating new RowPartition objects if we don't need to.
if (row_splits is self._row_splits and row_lengths is self._row_lengths and
value_rowids is self._value_rowids and nrows is self._nrows and
uniform_row_length is self._uniform_row_length):
return self
if (row_splits is other._row_splits and
row_lengths is other._row_lengths and
value_rowids is other._value_rowids and nrows is other._nrows and
uniform_row_length is other._uniform_row_length):
return other
return RowPartition(
row_splits=row_splits,
row_lengths=row_lengths,
value_rowids=value_rowids,
nrows=nrows,
uniform_row_length=uniform_row_length,
internal=_row_partition_factory_key)
#=============================================================================
# Composite Tensor
#=============================================================================
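An illustrative sketch of `merge_precomputed_encodings` (hypothetical values, mirroring the cases exercised in the tests below):

a = RowPartition.from_row_splits([0, 3, 8])
b = RowPartition.from_value_rowids([0, 0, 0, 1, 1, 1, 1, 1])
merged = a.merge_precomputed_encodings(b)
# `merged` describes the same partition as `a` and `b`, but carries the union
# of their precomputed encodings (row_splits, row_lengths, value_rowids, nrows).
# With validate=True (the default), runtime checks assert that the two
# partitions actually agree.
merged.has_precomputed_value_rowids()  # True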
@@ -1076,4 +1186,38 @@ def _cast_if_not_none(tensor, dtype):
return None if tensor is None else math_ops.cast(tensor, dtype)
def _merge_tensors(t1, t2, name, validate):
"""Merge two optional Tensors with equal values into a single Tensor.
Args:
t1: tf.Tensor or None
t2: tf.Tensor or None
name: A name for the tensors (for error messages)
validate: If true, then check that `t1` is compatible with `t2` (if both are
non-None).
Returns:
A pair `(merged_value, validated)`:
* `merged_value` is `t1` if it is not None; or `t2` otherwise.
* `validated` is true if we validated that t1 and t2 are equal (either
by adding a check, or because t1 is t2).
"""
if t1 is None:
return t2, False
elif t2 is None:
return t1, False
elif t1 is t2:
return t1, True
else:
err_msg = ("RowPartition.merge_precomuted_encodings: partitons "
"have incompatible %s" % name)
if not t1.shape.is_compatible_with(t2.shape):
raise ValueError(err_msg)
if validate:
checks = [check_ops.assert_equal(t1, t2, message=err_msg)]
return control_flow_ops.with_dependencies(checks, t1), True
else:
return t1, False
_row_partition_factory_key = object() # unique private object
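A small sketch of the `_merge_tensors` contract (hypothetical; the helper is private to this module, so the import path is shown only for illustration):

import tensorflow as tf
from tensorflow.python.ops.ragged.row_partition import _merge_tensors

t = tf.constant([0, 3, 8], dtype=tf.int64)
merged, validated = _merge_tensors(t, None, "row_splits", validate=True)
# merged is t, validated is False: with one side missing there is nothing to check.
merged, validated = _merge_tensors(t, t, "row_splits", validate=True)
# merged is t, validated is True: identical tensor objects are trivially equal.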


@@ -481,6 +481,192 @@ class RowPartitionTest(test_util.TensorFlowTestCase, parameterized.TestCase):
partition = factory(**kwargs)
self.evaluate(partition.row_splits())
@parameterized.named_parameters([
('FromRowSplits', lambda: RowPartition.from_row_splits([0, 2, 8]),
['row_splits']),
('FromRowLengths', lambda: RowPartition.from_row_lengths([3, 0, 8]),
['row_splits', 'row_lengths']),
('FromValueRowIds',
lambda: RowPartition.from_value_rowids([0, 0, 3, 4, 4, 4]),
['row_splits', 'value_rowids', 'row_lengths', 'nrows']),
('FromRowStarts',
lambda: RowPartition.from_row_starts([0, 3, 7], nvals=10),
['row_splits']),
('FromRowLimits', lambda: RowPartition.from_row_limits([3, 7, 10]),
['row_splits']),
])
def testPrecomputedSplits(self, rp_factory, expected_encodings):
rp = rp_factory()
self.assertEqual(rp.has_precomputed_row_splits(),
'row_splits' in expected_encodings)
self.assertEqual(rp.has_precomputed_row_lengths(),
'row_lengths' in expected_encodings)
self.assertEqual(rp.has_precomputed_value_rowids(),
'value_rowids' in expected_encodings)
self.assertEqual(rp.has_precomputed_nrows(), 'nrows' in expected_encodings)
def testWithPrecomputedSplits(self):
rp = RowPartition.from_row_splits([0, 2, 8])
rp_with_row_splits = rp.with_precomputed_row_splits()
self.assertTrue(rp_with_row_splits.has_precomputed_row_splits())
self.assertFalse(rp.has_precomputed_row_lengths())
rp_with_row_lengths = rp.with_precomputed_row_lengths()
self.assertTrue(rp_with_row_lengths.has_precomputed_row_lengths())
self.assertFalse(rp.has_precomputed_value_rowids())
rp_with_value_rowids = rp.with_precomputed_value_rowids()
self.assertTrue(rp_with_value_rowids.has_precomputed_value_rowids())
self.assertFalse(rp.has_precomputed_nrows())
rp_with_nrows = rp.with_precomputed_nrows()
self.assertTrue(rp_with_nrows.has_precomputed_nrows())
@parameterized.named_parameters([
dict(
testcase_name='FromRowSplitsAndRowSplits',
x=lambda: RowPartition.from_row_splits([0, 3, 8]),
y=lambda: RowPartition.from_row_splits([0, 3, 8]),
expected_encodings=['row_splits']),
dict(
testcase_name='FromRowSplitsAndUniformRowLength',
x=lambda: RowPartition.from_row_splits([0, 3, 6]),
y=lambda: RowPartition.from_uniform_row_length(3, nvals=6),
expected_encodings=['row_splits', 'uniform_row_length', 'nrows']),
dict(
testcase_name='FromRowSplitsAndRowLengths',
x=lambda: RowPartition.from_row_splits([0, 3, 8]),
y=lambda: RowPartition.from_row_lengths([3, 5]),
expected_encodings=['row_splits', 'row_lengths']),
dict(
testcase_name='FromRowSplitsAndValueRowIds',
x=lambda: RowPartition.from_row_splits([0, 3, 8]),
y=lambda: RowPartition.from_value_rowids([0, 0, 0, 1, 1, 1, 1, 1]),
expected_encodings=[
'row_splits', 'row_lengths', 'value_rowids', 'nrows'
]),
dict(
testcase_name='FromRowSplitsAndRowSplitsPlusNRows',
x=lambda: RowPartition.from_row_splits([0, 3, 8]),
y=lambda: RowPartition.from_row_splits([0, 3, 8]).
with_precomputed_nrows(),
expected_encodings=['row_splits', 'nrows']),
])
def testMergePrecomputedEncodings(self, x, y, expected_encodings):
x = x()
y = y()
for validate in (True, False):
result = x.merge_precomputed_encodings(y, validate)
self.assertEqual(result.has_precomputed_row_splits(),
'row_splits' in expected_encodings)
self.assertEqual(result.has_precomputed_row_lengths(),
'row_lengths' in expected_encodings)
self.assertEqual(result.has_precomputed_value_rowids(),
'value_rowids' in expected_encodings)
self.assertEqual(result.has_precomputed_nrows(),
'nrows' in expected_encodings)
self.assertEqual(result.uniform_row_length() is not None,
'uniform_row_length' in expected_encodings)
for r in (x, y):
if (r.has_precomputed_row_splits() and
result.has_precomputed_row_splits()):
self.assertAllEqual(r.row_splits(), result.row_splits())
if (r.has_precomputed_row_lengths() and
result.has_precomputed_row_lengths()):
self.assertAllEqual(r.row_lengths(), result.row_lengths())
if (r.has_precomputed_value_rowids() and
result.has_precomputed_value_rowids()):
self.assertAllEqual(r.value_rowids(), result.value_rowids())
if r.has_precomputed_nrows() and result.has_precomputed_nrows():
self.assertAllEqual(r.nrows(), result.nrows())
if (r.uniform_row_length() is not None and
result.uniform_row_length() is not None):
self.assertAllEqual(r.uniform_row_length(),
result.uniform_row_length())
def testMergePrecomputedEncodingsFastPaths(self):
# Same object: x gets returned as-is.
x = RowPartition.from_row_splits([0, 3, 8, 8])
self.assertIs(x.merge_precomputed_encodings(x), x)
# Same encoding tensor objects: x gets returned as-is.
y = RowPartition.from_row_splits(x.row_splits(), validate=False)
self.assertIs(x.merge_precomputed_encodings(y), x)
def testMergePrecomputedEncodingsWithMatchingTensors(self):
# The encoding tensors for `a` are a superset of the encoding tensors
# for `b`, and where they overlap, they are the same tensor objects.
a = RowPartition.from_value_rowids([0, 0, 3, 4, 4, 4])
b = RowPartition.from_row_splits(a.row_splits(), validate=False)
self.assertIs(a.merge_precomputed_encodings(b), a)
self.assertIs(b.merge_precomputed_encodings(a), a)
self.assertIsNot(a, b)
@parameterized.named_parameters([
dict(
testcase_name='RowSplitMismatch',
x=lambda: RowPartition.from_row_splits([0, 3, 8]),
y=lambda: RowPartition.from_row_splits([0, 3, 8, 9]),
message='incompatible row_splits'),
dict(
testcase_name='RowLengthMismatch',
x=lambda: RowPartition.from_row_lengths([2, 0, 2]),
y=lambda: RowPartition.from_row_lengths([2, 0, 2, 1]),
message='incompatible row_splits'), # row_splits is checked first
dict(
testcase_name='ValueRowIdMismatch',
x=lambda: RowPartition.from_value_rowids([0, 3, 3, 4]),
y=lambda: RowPartition.from_value_rowids([0, 3, 4]),
message='incompatible value_rowids'),
])
def testMergePrecomputedEncodingStaticErrors(self, x, y, message):
if context.executing_eagerly():
return
# Errors that are caught by static shape checks.
x = x()
y = y()
with self.assertRaisesRegexp(ValueError, message):
x.merge_precomputed_encodings(y).row_splits()
with self.assertRaisesRegexp(ValueError, message):
y.merge_precomputed_encodings(x).row_splits()
@parameterized.named_parameters([
dict(
testcase_name='NRowsMismatch',
x=lambda: RowPartition.from_uniform_row_length(5, nvals=20),
y=lambda: RowPartition.from_uniform_row_length(5, nvals=15),
message='incompatible nrows'),
dict(
testcase_name='UniformRowLengthMismatch',
x=lambda: RowPartition.from_uniform_row_length(5, nvals=20),
y=lambda: RowPartition.from_uniform_row_length(2, nvals=8),
message='incompatible uniform_row_length'),
dict(
testcase_name='RowSplitMismatch',
x=lambda: RowPartition.from_row_splits([0, 3, 8]),
y=lambda: RowPartition.from_row_splits([0, 5, 8]),
message='incompatible row_splits'),
dict(
testcase_name='RowLengthMismatch',
x=lambda: RowPartition.from_row_lengths([2, 0, 2]),
y=lambda: RowPartition.from_row_lengths([0, 0, 2]),
message='incompatible row_splits'), # row_splits is checked first
dict(
testcase_name='ValueRowIdMismatch',
x=lambda: RowPartition.from_value_rowids([0, 3, 3]),
y=lambda: RowPartition.from_value_rowids([0, 0, 3]),
message='incompatible row_splits'), # row_splits is checked first
])
def testMergePrecomputedEncodingRuntimeErrors(self, x, y, message):
# Errors that are caught by runtime value checks.
x = x()
y = y()
with self.assertRaisesRegexp(errors.InvalidArgumentError, message):
self.evaluate(x.merge_precomputed_encodings(y).row_splits())
with self.assertRaisesRegexp(errors.InvalidArgumentError, message):
self.evaluate(y.merge_precomputed_encodings(x).row_splits())
@test_util.run_all_in_graph_and_eager_modes
class RowPartitionSpecTest(test_util.TensorFlowTestCase,