From 6c26c995db67a70d95974fea71103712a12128bc Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 4 Mar 2020 12:48:00 -0800 Subject: [PATCH] Pull partition (row_splits, value_rowid, et cetera) of a RaggedTensor into a separate object. This allows one RowPartition to be shared by multiple RaggedTensor objects. Copied over tests from RaggedTensorTest to RowPartitionTest. PiperOrigin-RevId: 298916432 Change-Id: Ibf05c1b5969c0f4d3d21b301b5ee99e7ca1350e1 --- tensorflow/python/ops/ragged/BUILD | 56 +- .../python/ops/ragged/ragged_factory_ops.py | 2 +- .../ops/ragged/ragged_placeholder_op_test.py | 19 +- tensorflow/python/ops/ragged/ragged_tensor.py | 710 ++++++--------- .../python/ops/ragged/ragged_tensor_test.py | 315 ++++--- tensorflow/python/ops/ragged/row_partition.py | 843 ++++++++++++++++++ .../python/ops/ragged/row_partition_test.py | 559 ++++++++++++ .../golden/v1/tensorflow.-ragged-tensor.pbtxt | 6 +- .../golden/v2/tensorflow.-ragged-tensor.pbtxt | 6 +- 9 files changed, 1947 insertions(+), 569 deletions(-) create mode 100644 tensorflow/python/ops/ragged/row_partition.py create mode 100644 tensorflow/python/ops/ragged/row_partition_test.py diff --git a/tensorflow/python/ops/ragged/BUILD b/tensorflow/python/ops/ragged/BUILD index efef01de764..9504888d8fb 100644 --- a/tensorflow/python/ops/ragged/BUILD +++ b/tensorflow/python/ops/ragged/BUILD @@ -269,6 +269,27 @@ py_library( srcs_version = "PY2AND3", ) +py_library( + name = "row_partition", + srcs = ["row_partition.py"], + srcs_version = "PY2AND3", + deps = [ + ":segment_id_ops", + "//tensorflow/python:array_ops", + "//tensorflow/python:check_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:tensor_spec", + "//tensorflow/python:tensor_util", + "//tensorflow/python:util", + "//third_party/py/numpy", + ], 
+) + py_library( name = "ragged_tensor", srcs = ["ragged_tensor.py"], @@ -277,7 +298,7 @@ py_library( ":ragged_config", ":ragged_tensor_value", ":ragged_util", - ":segment_id_ops", + ":row_partition", "//tensorflow/python:array_ops", "//tensorflow/python:check_ops", "//tensorflow/python:composite_tensor", @@ -292,6 +313,7 @@ py_library( "//tensorflow/python:tensor_shape", "//tensorflow/python:tensor_spec", "//tensorflow/python:tensor_util", + "//tensorflow/python:tf2", "//tensorflow/python:type_spec", "//tensorflow/python:util", "//third_party/py/numpy", @@ -449,12 +471,44 @@ py_test( ":ragged_tensor_value", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", "//tensorflow/python:tensor_shape", + "//tensorflow/python:tensor_spec", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/eager:context", + "//third_party/py/numpy", + "@absl_py//absl/testing:parameterized", + ], +) + +py_test( + name = "row_partition_test", + size = "medium", + timeout = "long", + srcs = ["row_partition_test.py"], + python_version = "PY3", + shard_count = 4, + srcs_version = "PY2AND3", + tags = [ + "no_windows", + ], + deps = [ + ":ragged", # fixdeps: keep + ":row_partition", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:tensor_util", + "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/eager:context", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", diff --git a/tensorflow/python/ops/ragged/ragged_factory_ops.py b/tensorflow/python/ops/ragged/ragged_factory_ops.py index 
bb4337cc011..aa148ae7fe8 100644 --- a/tensorflow/python/ops/ragged/ragged_factory_ops.py +++ b/tensorflow/python/ops/ragged/ragged_factory_ops.py @@ -342,5 +342,5 @@ def placeholder(dtype, ragged_rank, value_shape=None, name=None): for i in reversed(range(ragged_rank)): row_splits = array_ops.placeholder(dtypes.int64, [None], "row_splits_%d" % i) - result = ragged_tensor.RaggedTensor(result, row_splits, internal=True) + result = ragged_tensor.RaggedTensor.from_row_splits(result, row_splits) return result diff --git a/tensorflow/python/ops/ragged/ragged_placeholder_op_test.py b/tensorflow/python/ops/ragged/ragged_placeholder_op_test.py index 0c5307031ac..d2261d408b3 100644 --- a/tensorflow/python/ops/ragged/ragged_placeholder_op_test.py +++ b/tensorflow/python/ops/ragged/ragged_placeholder_op_test.py @@ -37,22 +37,27 @@ class RaggedPlaceholderOpTest(test_util.TensorFlowTestCase, (dtypes.int32, 1, [], 'ph', 'tf.RaggedTensor(' 'values=Tensor("ph/flat_values:0", shape=(None,), dtype=int32), ' - 'row_splits=Tensor("ph/row_splits_0:0", shape=(None,), dtype=int64))'), + 'row_splits=Tensor("ph/RaggedFromRowSplits/control_dependency:0", ' + 'shape=(None,), dtype=int64))'), (dtypes.string, 1, [5], 'ph', 'tf.RaggedTensor(' 'values=Tensor("ph/flat_values:0", shape=(None, 5), dtype=string), ' - 'row_splits=Tensor("ph/row_splits_0:0", shape=(None,), dtype=int64))'), + 'row_splits=Tensor("ph/RaggedFromRowSplits/control_dependency:0", ' + 'shape=(None,), dtype=int64))'), (dtypes.float32, 2, [], 'ph', 'tf.RaggedTensor(values=tf.RaggedTensor(' 'values=Tensor("ph/flat_values:0", shape=(None,), dtype=float32), ' - 'row_splits=Tensor("ph/row_splits_1:0", shape=(None,), dtype=int64)), ' - 'row_splits=Tensor("ph/row_splits_0:0", shape=(None,), dtype=int64))'), + 'row_splits=Tensor("ph/RaggedFromRowSplits/control_dependency:0", ' + 'shape=(None,), dtype=int64)), ' + 'row_splits=Tensor("ph/RaggedFromRowSplits_1/control_dependency:0", ' + 'shape=(None,), dtype=int64))'), (dtypes.int32, 2, 
[3, 5], 'ph', 'tf.RaggedTensor(values=tf.RaggedTensor(' 'values=Tensor("ph/flat_values:0", shape=(None, 3, 5), dtype=int32), ' - 'row_splits=Tensor("ph/row_splits_1:0", shape=(None,), dtype=int64)), ' - 'row_splits=Tensor("ph/row_splits_0:0", shape=(None,), dtype=int64))'), - + 'row_splits=Tensor("ph/RaggedFromRowSplits/control_dependency:0", ' + 'shape=(None,), dtype=int64)), ' + 'row_splits=Tensor("ph/RaggedFromRowSplits_1/control_dependency:0", ' + 'shape=(None,), dtype=int64))'), ]) def testRaggedPlaceholder(self, dtype, ragged_rank, value_shape, name, expected): diff --git a/tensorflow/python/ops/ragged/ragged_tensor.py b/tensorflow/python/ops/ragged/ragged_tensor.py index 9ecb3be190c..06fddab18f0 100644 --- a/tensorflow/python/ops/ragged/ragged_tensor.py +++ b/tensorflow/python/ops/ragged/ragged_tensor.py @@ -42,11 +42,12 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops.ragged import ragged_config from tensorflow.python.ops.ragged import ragged_tensor_value from tensorflow.python.ops.ragged import ragged_util -from tensorflow.python.ops.ragged import segment_id_ops +from tensorflow.python.ops.ragged.row_partition import RowPartition from tensorflow.python.util.tf_export import tf_export # pylint: disable=protected-access _eval_using_default_session = ops._eval_using_default_session +_convert_row_partition = RowPartition._convert_row_partition # pylint: enable=protected-access #=============================================================================== @@ -112,7 +113,7 @@ class RaggedTensor(composite_tensor.CompositeTensor): ### Alternative Row-Partitioning Schemes - In addition to `row_splits`, ragged tensors provide support for four other + In addition to `row_splits`, ragged tensors provide support for five other row-partitioning schemes: * `row_lengths`: a vector with shape `[nrows]`, which specifies the length @@ -226,14 +227,7 @@ class RaggedTensor(composite_tensor.CompositeTensor): 
#============================================================================= # Constructor (private) #============================================================================= - def __init__(self, - values, - row_splits, - cached_row_lengths=None, - cached_value_rowids=None, - cached_nrows=None, - internal=False, - uniform_row_length=None): + def __init__(self, values, row_partition, internal=False): """Creates a `RaggedTensor` with a specified partitioning for `values`. This constructor is private -- please use one of the following ops to @@ -250,70 +244,86 @@ class RaggedTensor(composite_tensor.CompositeTensor): Args: values: A potentially ragged tensor of any dtype and shape `[nvals, ...]`. - row_splits: A 1-D integer tensor with shape `[nrows+1]`. - cached_row_lengths: A 1-D integer tensor with shape `[nrows]` - cached_value_rowids: A 1-D integer tensor with shape `[nvals]`. - cached_nrows: A 1-D integer scalar tensor. + row_partition: A `RowPartition` object, representing the arrangement of + the lists at the top level. internal: True if the constructor is being called by one of the factory methods. If false, an exception will be raised. - uniform_row_length: A scalar tensor. Raises: - TypeError: If a row partitioning tensor has an inappropriate dtype. - TypeError: If exactly one row partitioning argument was not specified. - ValueError: If a row partitioning tensor has an inappropriate shape. - ValueError: If multiple partitioning arguments are specified. - ValueError: If nrows is specified but value_rowids is not None. + ValueError: If internal = False. Note that this method is intended only + for internal use. + TypeError: If values is not a `RaggedTensor` or `Tensor`, or + row_partition is not a `RowPartition`. """ + if not internal: raise ValueError("RaggedTensor constructor is private; please use one " "of the factory methods instead (e.g., " "RaggedTensor.from_row_lengths())") - - # Validate the arguments. 
- if not isinstance(row_splits, ops.Tensor): - raise TypeError("Row-partitioning argument must be a Tensor, got %r" % - row_splits) if not isinstance(values, (RaggedTensor, ops.Tensor)): raise TypeError("values must be a Tensor or RaggedTensor, got %r" % values) - if row_splits.dtype not in (dtypes.int32, dtypes.int64): - raise ValueError("Row-partitioning argument must be int32 or int64") + if not isinstance(row_partition, RowPartition): + raise TypeError("row_partition must be a RowPartition, got %r" % + row_partition) - # Validate shapes & dtypes. - row_splits.shape.assert_has_rank(1) + # Validate shapes. + values = convert_to_tensor_or_ragged_tensor(values) values.shape.with_rank_at_least(1) - row_splits.set_shape([None]) if isinstance(values, RaggedTensor): - assert row_splits.dtype == values.row_splits.dtype + assert row_partition.dtype == values.row_partition.dtype self._values = values - self._row_splits = row_splits - - # Store any cached tensors. These are used to avoid unnecessary - # round-trip conversions when a RaggedTensor is constructed from - # lengths or rowids, and we later want those lengths/rowids back. 
- for tensor in [cached_row_lengths, cached_value_rowids, cached_nrows]: - if tensor is not None: - if not isinstance(tensor, ops.Tensor): - raise TypeError("Cached value must be a Tensor or None.") - elif tensor.dtype not in (dtypes.int32, dtypes.int64): - raise TypeError("Cached value must be int32 or int64.") - self._cached_row_lengths = cached_row_lengths - self._cached_value_rowids = cached_value_rowids - self._cached_nrows = cached_nrows - - if uniform_row_length is not None: - if not isinstance(uniform_row_length, ops.Tensor): - raise TypeError("uniform_row_length must be a Tensor or None.") - elif uniform_row_length.dtype not in (dtypes.int32, dtypes.int64): - raise TypeError("uniform_row_length must be int32 or int64.") - self._uniform_row_length = uniform_row_length + self._row_partition = row_partition #============================================================================= # Factory Methods #============================================================================= + @classmethod + def _from_row_partition(cls, values, row_partition, validate=True): + """Creates a `RaggedTensor` with a row partition. + + This is used as a way for RaggedTensors to share row partitions. + + The outer dimension of values must be equal to `partition.nvals()`. + + Args: + values: A potentially ragged tensor. + row_partition: a `RowPartition`: can be shared between tensors. + validate: If true, then use assertions to check that the arguments form a + valid `RaggedTensor`. + + Returns: + A `RaggedTensor`. `result.rank = values.rank + 1`. + `result.ragged_rank = values.ragged_rank + 1`. 
+ + Raises: + ValueError: If partition.nvals() != _nrows(values) + """ + if not isinstance(row_partition, RowPartition): + raise TypeError("row_partition must be a RowPartition") + if not isinstance(validate, bool): + raise TypeError("validate must have type bool") + values, row_partition = cls._convert_values_and_partition( + values, row_partition, "partition") + if validate: + msg = "Arguments to _from_row_partition do not form a valid RaggedTensor" + nvals = _nrows(values, row_partition.dtype) + checks = [ + check_ops.assert_equal( + row_partition.nvals(out_type=row_partition.dtype), + nvals, + message=msg), + ] + if not isinstance(values, RaggedTensor): + checks.append(check_ops.assert_rank_at_least(values, 1)) + row_partition = row_partition.with_dependencies(checks) + return cls( + values=values, + internal=True, + row_partition=row_partition) + @classmethod def from_value_rowids(cls, values, @@ -362,76 +372,16 @@ class RaggedTensor(composite_tensor.CompositeTensor): """ if not isinstance(validate, bool): raise TypeError("validate must have type bool") + with ops.name_scope(name, "RaggedFromValueRowIds", [values, value_rowids, nrows]): - values, value_rowids = cls._convert_values_and_row_partition( - values, value_rowids, "value_rowids") - if nrows is None: - const_rowids = tensor_util.constant_value(value_rowids) - if const_rowids is None: - nrows = array_ops.concat([value_rowids[-1:], [-1]], axis=0)[0] + 1 - const_nrows = None - else: - const_nrows = const_rowids[-1] + 1 if const_rowids.size > 0 else 0 - nrows = ops.convert_to_tensor(const_nrows, value_rowids.dtype, - name="nrows") - else: - nrows = ops.convert_to_tensor(nrows, value_rowids.dtype, "nrows") - const_nrows = tensor_util.constant_value(nrows) - if const_nrows is not None: - if const_nrows < 0: - raise ValueError("Expected nrows >= 0; got %d" % const_nrows) - const_rowids = tensor_util.constant_value(value_rowids) - if const_rowids is not None and const_rowids.size > 0: - if not const_nrows >= 
const_rowids[-1] + 1: - raise ValueError( - "Expected nrows >= value_rowids[-1] + 1; got nrows=%d, " - "value_rowids[-1]=%d" % (const_nrows, const_rowids[-1])) - - value_rowids.shape.assert_has_rank(1) - nrows.shape.assert_has_rank(0) - values.shape[:1].assert_is_compatible_with(value_rowids.shape) - - if validate: - msg = "Arguments to from_value_rowids do not form a valid RaggedTensor" - nvals1 = _nrows(values) - nvals2 = _nrows(value_rowids) - checks = [ - check_ops.assert_rank(value_rowids, 1, message=msg), - check_ops.assert_rank(nrows, 0, message=msg), - check_ops.assert_equal(nvals1, nvals2, message=msg), - check_ops.assert_non_negative(value_rowids[:1], message=msg), - _assert_monotonic_increasing(value_rowids, message=msg), - check_ops.assert_less(value_rowids[-1:], nrows, message=msg), - ] - if not isinstance(values, RaggedTensor): - checks.append(check_ops.assert_rank_at_least(values, 1)) - value_rowids = control_flow_ops.with_dependencies(checks, value_rowids) - - # Convert value_rowids & nrows to row_splits. - # Note: we don't use segment_ids_to_row_splits() here because we want - # to save the intermediate value `row_lengths`, so we can cache it. - # TODO(b/116708836) Upgrade bincount to accept int64 so we can skip the - # cast. 
- value_rowids_int32 = math_ops.cast(value_rowids, dtypes.int32) - nrows_int32 = math_ops.cast(nrows, dtypes.int32) - row_lengths = math_ops.bincount( - value_rowids_int32, - minlength=nrows_int32, - maxlength=nrows_int32, - dtype=value_rowids.dtype) - row_splits = array_ops.concat([[0], math_ops.cumsum(row_lengths)], axis=0) - if const_nrows is not None: - row_lengths.set_shape([const_nrows]) - row_splits.set_shape([const_nrows + 1]) - - return cls( - values, - row_splits, - cached_row_lengths=row_lengths, - cached_value_rowids=value_rowids, - cached_nrows=nrows, - internal=True) + row_partition = RowPartition.from_value_rowids( + value_rowids, + nrows=nrows, + name=name, + validate=validate, + preferred_dtype=_get_optional_partition_dtype(values)) + return cls._from_row_partition(values, row_partition, validate=validate) @classmethod def from_row_splits(cls, values, row_splits, name=None, validate=True): @@ -471,30 +421,14 @@ class RaggedTensor(composite_tensor.CompositeTensor): """ if not isinstance(validate, bool): raise TypeError("validate must have type bool") - if isinstance(row_splits, (list, tuple)) and not row_splits: - raise ValueError("row_splits tensor may not be empty.") - if isinstance(row_splits, tensor_spec.TensorSpec): - return cls(values=values, row_splits=row_splits, internal=True) with ops.name_scope(name, "RaggedFromRowSplits", [values, row_splits]): - values, row_splits = cls._convert_values_and_row_partition( - values, row_splits, "row_splits") - row_splits.shape.assert_has_rank(1) - - if validate: - msg = "Arguments to from_row_splits do not form a valid RaggedTensor" - nvals = _nrows(values, row_splits.dtype) - checks = [ - check_ops.assert_rank(row_splits, 1, message=msg), - _assert_zero(row_splits[0], message=msg), - _assert_monotonic_increasing(row_splits, message=msg), - check_ops.assert_equal(row_splits[-1], nvals, message=msg), - ] - if not isinstance(values, RaggedTensor): - checks.append(check_ops.assert_rank_at_least(values, 1)) - 
row_splits = control_flow_ops.with_dependencies(checks, row_splits) - - return cls(values=values, row_splits=row_splits, internal=True) + row_partition = RowPartition.from_row_splits( + row_splits, + name=name, + validate=validate, + preferred_dtype=_get_optional_partition_dtype(values)) + return cls._from_row_partition(values, row_partition, validate=validate) @classmethod def from_row_lengths(cls, values, row_lengths, name=None, validate=True): @@ -530,31 +464,14 @@ class RaggedTensor(composite_tensor.CompositeTensor): """ if not isinstance(validate, bool): raise TypeError("validate must have type bool") + with ops.name_scope(name, "RaggedFromRowLengths", [values, row_lengths]): - values, row_lengths = cls._convert_values_and_row_partition( - values, row_lengths, "row_lengths") - row_lengths.shape.assert_has_rank(1) - - if validate: - msg = "Arguments to from_row_lengths do not form a valid RaggedTensor" - nvals1 = math_ops.reduce_sum(row_lengths) - nvals2 = _nrows(values, row_lengths.dtype) - checks = [ - check_ops.assert_rank(row_lengths, 1, message=msg), - check_ops.assert_non_negative(row_lengths, message=msg), - check_ops.assert_equal(nvals1, nvals2, message=msg) - ] - if not isinstance(values, RaggedTensor): - checks.append(check_ops.assert_rank_at_least(values, 1)) - row_lengths = control_flow_ops.with_dependencies(checks, row_lengths) - - row_limits = math_ops.cumsum(row_lengths) - row_splits = array_ops.concat([[0], row_limits], axis=0) - return cls( - values=values, - row_splits=row_splits, - cached_row_lengths=row_lengths, - internal=True) + row_partition = RowPartition.from_row_lengths( + row_lengths, + name=name, + validate=validate, + preferred_dtype=_get_optional_partition_dtype(values)) + return cls._from_row_partition(values, row_partition, validate=validate) @classmethod def from_row_starts(cls, values, row_starts, name=None, validate=True): @@ -587,25 +504,14 @@ class RaggedTensor(composite_tensor.CompositeTensor): if not isinstance(validate, 
bool): raise TypeError("validate must have type bool") with ops.name_scope(name, "RaggedFromRowStarts", [values, row_starts]): - values, row_starts = cls._convert_values_and_row_partition( - values, row_starts, "row_starts") - row_starts.shape.assert_has_rank(1) - nvals = _nrows(values, row_starts.dtype) - - if validate: - msg = "Arguments to from_row_starts do not form a valid RaggedTensor" - checks = [ - check_ops.assert_rank(row_starts, 1, message=msg), - _assert_zero(row_starts[:1], message=msg), - _assert_monotonic_increasing(row_starts, message=msg), - check_ops.assert_less_equal(row_starts[-1:], nvals, message=msg), - ] - if not isinstance(values, RaggedTensor): - checks.append(check_ops.assert_rank_at_least(values, 1)) - row_starts = control_flow_ops.with_dependencies(checks, row_starts) - - row_splits = array_ops.concat([row_starts, [nvals]], axis=0) - return cls(values=values, row_splits=row_splits, internal=True) + values = convert_to_tensor_or_ragged_tensor(values) + row_partition = RowPartition.from_row_starts( + row_starts, + _nrows(values), + name=name, + validate=validate, + preferred_dtype=_get_optional_partition_dtype(values)) + return cls._from_row_partition(values, row_partition, validate=validate) @classmethod def from_row_limits(cls, values, row_limits, name=None, validate=True): @@ -637,26 +543,12 @@ class RaggedTensor(composite_tensor.CompositeTensor): if not isinstance(validate, bool): raise TypeError("validate must have type bool") with ops.name_scope(name, "RaggedFromRowLimits", [values, row_limits]): - values, row_limits = cls._convert_values_and_row_partition( - values, row_limits, "row_limits") - row_limits.shape.assert_has_rank(1) - - if validate: - msg = "Arguments to from_row_limits do not form a valid RaggedTensor" - nvals = _nrows(values, row_limits.dtype) - checks = [ - check_ops.assert_rank(row_limits, 1, message=msg), - check_ops.assert_non_negative(row_limits[:1], message=msg), - _assert_monotonic_increasing(row_limits, 
message=msg), - check_ops.assert_equal(row_limits[-1:], nvals, message=msg) - ] - if not isinstance(values, RaggedTensor): - checks.append(check_ops.assert_rank_at_least(values, 1)) - row_limits = control_flow_ops.with_dependencies(checks, row_limits) - - zero = array_ops.zeros([1], row_limits.dtype) - row_splits = array_ops.concat([zero, row_limits], axis=0) - return cls(values=values, row_splits=row_splits, internal=True) + row_partition = RowPartition.from_row_limits( + row_limits, + name=name, + validate=validate, + preferred_dtype=_get_optional_partition_dtype(values)) + return cls._from_row_partition(values, row_partition, validate=validate) @classmethod def from_uniform_row_length(cls, @@ -691,8 +583,8 @@ class RaggedTensor(composite_tensor.CompositeTensor): Args: values: A potentially ragged tensor with shape `[nvals, ...]`. - uniform_row_length: A scalar integer tensor. Must be nonnegative. - The size of the outer axis of `values` must be evenly divisible by + uniform_row_length: A scalar integer tensor. Must be nonnegative. The + size of the outer axis of `values` must be evenly divisible by `uniform_row_length`. nrows: The number of rows in the constructed RaggedTensor. If not specified, then it defaults to `nvals/uniform_row_length` (or `0` if @@ -719,85 +611,19 @@ class RaggedTensor(composite_tensor.CompositeTensor): raise TypeError("validate must have type bool") with ops.name_scope(name, "RaggedFromUniformRowLength", [values, uniform_row_length, nrows]): - values, uniform_row_length = cls._convert_values_and_row_partition( - values, uniform_row_length, "uniform_row_length") - uniform_row_length.shape.assert_has_rank(0) - - # Find nvals. 
- const_nvals = tensor_shape.dimension_at_index(values.shape, 0).value - if const_nvals is not None: - nvals = constant_op.constant(const_nvals, uniform_row_length.dtype) - elif isinstance(values, RaggedTensor): - nvals = values.nrows(out_type=uniform_row_length.dtype) - else: - nvals = array_ops.shape(values, out_type=uniform_row_length.dtype)[0] - - # Find nrows. - const_row_length = tensor_util.constant_value(uniform_row_length) - if nrows is None: - if const_row_length is None: - # Avoid division by zero if uniform_row_length==0 (and nvals==0). - rowlen_or_1 = control_flow_ops.cond( - math_ops.equal(uniform_row_length, 0), - lambda: constant_op.constant(1, uniform_row_length.dtype), - lambda: uniform_row_length) - nrows = nvals // rowlen_or_1 - elif const_row_length == 0: - nrows = 0 - else: - nrows = nvals // const_row_length - nrows = ops.convert_to_tensor( - nrows, uniform_row_length.dtype, name="nrows") - const_nrows = tensor_util.constant_value(nrows) - - # Find row_splits. - if const_nrows is not None and const_row_length is not None: - row_splits = [v * const_row_length for v in range(const_nrows + 1)] - row_splits = constant_op.constant(row_splits, uniform_row_length.dtype) - else: - row_splits = math_ops.range(nrows + 1) * uniform_row_length - - if validate: - checks = [] - - if (const_nrows is None or const_row_length is None or - const_nvals is None): - checks.append(check_ops.assert_equal( - nrows * uniform_row_length, - nvals, - ("uniform_row_length", uniform_row_length, "times nrows", - nrows, "must equal nvals", nvals))) - else: - if const_nrows * const_row_length != const_nvals: - raise ValueError( - "uniform_row_length=%d times nrows=%d must equal nvals=%d" - % (const_row_length, const_nrows, const_nvals)) - - if uniform_row_length.shape.rank is None: - checks.append( - check_ops.assert_rank( - uniform_row_length, 0, - message="uniform_row_length must be a scalar.")) - - const_row_length = tensor_util.constant_value(uniform_row_length) - if 
const_row_length is None: - checks.append( - check_ops.assert_greater_equal( - uniform_row_length, - constant_op.constant(0, uniform_row_length.dtype), - message="uniform_row_length must be >= 0.")) - else: - if const_row_length < 0: - raise ValueError("uniform_row_length must be >= 0.") - - row_splits = control_flow_ops.with_dependencies(checks, row_splits) - - return cls( - values=values, - row_splits=row_splits, - uniform_row_length=uniform_row_length, - cached_nrows=nrows, - internal=True) + values = convert_to_tensor_or_ragged_tensor(values) + uniform_row_length = _convert_row_partition( + uniform_row_length, "UniformRowLength", + _get_optional_partition_dtype(values)) + nvals = _nvals_uniform_row_length(values, uniform_row_length) + row_partition = RowPartition.from_uniform_row_length( + nvals, + uniform_row_length, + nrows=nrows, + name=name, + validate=validate, + preferred_dtype=_get_optional_partition_dtype(values)) + return cls._from_row_partition(values, row_partition, validate=validate) @classmethod def from_nested_value_rowids(cls, @@ -823,7 +649,6 @@ class RaggedTensor(composite_tensor.CompositeTensor): nested_nrows: A list of integer scalars. The `i`th scalar is used as the `nrows` for the `i`th ragged dimension. name: A name prefix for the RaggedTensor (optional). - validate: If true, then use assertions to check that the arguments form a valid `RaggedTensor`. Note: these assertions incur a runtime cost, since they must be checked for each tensor value. 
@@ -847,14 +672,13 @@ class RaggedTensor(composite_tensor.CompositeTensor): raise ValueError("nested_nrows must have the same length as " "nested_value_rowids") - with ops.name_scope( - name, "RaggedFromNestedValueRowIds", - [flat_values] + list(nested_value_rowids) + list(nested_nrows)): + with ops.name_scope(name, "RaggedFromNestedValueRowIds", [flat_values] + + list(nested_value_rowids) + list(nested_nrows)): result = flat_values for value_rowids, nrows in reversed( list(zip(nested_value_rowids, nested_nrows))): - result = cls.from_value_rowids(result, value_rowids, nrows, - validate=validate) + result = cls.from_value_rowids( + result, value_rowids, nrows, validate=validate) return result @classmethod @@ -936,7 +760,7 @@ class RaggedTensor(composite_tensor.CompositeTensor): return result @classmethod - def _convert_values_and_row_partition(cls, values, partition, name): + def _convert_values_and_partition(cls, values, row_partition, name): """Converts `values` and `partition` to Tensors. If `values` is a `RaggedTensor`, then converts `values` and `partition` @@ -948,39 +772,26 @@ class RaggedTensor(composite_tensor.CompositeTensor): Args: values: The `values` for the `RaggedTensor` being constructed. - partition: A row-partitioning tensor for the `RaggedTensor` being - constructed. I.e., one of: row_splits, row_lengths, row_starts, - row_limits, value_rowids. - name: The name of the row-partitioning tensor. + row_partition: A RowPartition object for the `RaggedTensor` being + constructed. + name: The name of the RowPartition object. Returns: A tuple (values, partition). 
""" + if not isinstance(row_partition, RowPartition): + raise ValueError("partition must be a RowPartition") if isinstance(values, RaggedTensor): - if isinstance(partition, ops.Tensor): - if partition.dtype not in (dtypes.int32, dtypes.int64): - raise ValueError("%s must have dtype int32 or int64" % name) - if values.row_splits.dtype != partition.dtype: - if not ragged_config.auto_cast_partition_dtype(): - raise ValueError("dtype mismatch: %s (%s) vs values.row_splits (%s)" - % (name, partition.dtype, values.row_splits.dtype)) - partition = math_ops.cast(partition, dtypes.int64) - values = values.with_row_splits_dtype(dtypes.int64) - else: - partition = ops.convert_to_tensor(partition, values.row_splits.dtype, - name=name) + if values.row_partition.dtype != row_partition.dtype: + if not ragged_config.auto_cast_partition_dtype(): + raise ValueError( + "dtype mismatch: %s (%s) vs values.partition (%s)" % + (name, row_partition.dtype, values.row_partition.dtype)) + values = values.with_row_splits_dtype(row_partition.dtype) else: values = ops.convert_to_tensor(values, name="values") - if isinstance(partition, np.ndarray) and partition.dtype == np.int32: - partition = ops.convert_to_tensor(partition, name=name) - else: - partition = ops.convert_to_tensor( - partition, preferred_dtype=dtypes.int64, - name=name) - if partition.dtype not in (dtypes.int32, dtypes.int64): - raise ValueError("%s must have dtype int32 or int64" % name) - return (values, partition) + return (values, row_partition) #============================================================================= # Accessors @@ -1008,10 +819,11 @@ class RaggedTensor(composite_tensor.CompositeTensor): TensorShape([2, None, 2]) """ - nrows = tensor_shape.dimension_at_index(self._row_splits.shape, 0) - 1 - if self._uniform_row_length is not None: - row_length = tensor_util.constant_value(self._uniform_row_length) + nrows = self._row_partition.nrows_as_dimension() + if self._row_partition.uniform_row_length() is not 
None: + row_length = tensor_util.constant_value( + self._row_partition.uniform_row_length()) else: row_length = None @@ -1020,6 +832,11 @@ class RaggedTensor(composite_tensor.CompositeTensor): return tensor_shape.TensorShape([nrows, row_length]).concatenate(value_shape) + @property + def row_partition(self): + """The row partition of the top level of the ragged tensor.""" + return self._row_partition + @property def ragged_rank(self): """The number of ragged dimensions in this ragged tensor. @@ -1076,7 +893,7 @@ class RaggedTensor(composite_tensor.CompositeTensor): tf.Tensor([0 4 4 7 8 8], shape=(6,), dtype=int64) """ - return self._row_splits + return self._row_partition.row_splits @property def uniform_row_length(self): @@ -1102,7 +919,7 @@ class RaggedTensor(composite_tensor.CompositeTensor): ragged tensor (for ragged tensors whose rows are uniform); or `None` (for ragged tensors whose rows are ragged). """ - return self._uniform_row_length + return self._row_partition.uniform_row_length() @property def flat_values(self): @@ -1189,11 +1006,7 @@ class RaggedTensor(composite_tensor.CompositeTensor): tf.Tensor([0 0 0 0 2 2 2 3], shape=(8,), dtype=int64) """ - if self._cached_value_rowids is not None: - return self._cached_value_rowids - - with ops.name_scope(name, "RaggedValueRowIds", [self]): - return segment_id_ops.row_splits_to_segment_ids(self.row_splits) + return self._row_partition.value_rowids(name=name) def nested_value_rowids(self, name=None): """Returns a tuple containing the value_rowids for all ragged dimensions. 
@@ -1252,18 +1065,7 @@ class RaggedTensor(composite_tensor.CompositeTensor): tf.Tensor(5, shape=(), dtype=int64) """ - if out_type is None: - out_type = self._row_splits.dtype - else: - out_type = dtypes.as_dtype(out_type) - if self._cached_nrows is not None: - return math_ops.cast(self._cached_nrows, out_type) - with ops.name_scope(name, "RaggedNRows", [self]): - nsplits = tensor_shape.dimension_at_index(self.row_splits.shape, 0) - if nsplits.value is None: - return array_ops.shape(self.row_splits, out_type=out_type)[0] - 1 - else: - return constant_op.constant(nsplits.value - 1, dtype=out_type) + return self._row_partition.nrows(out_type=out_type, name=name) def row_starts(self, name=None): """Returns the start indices for rows in this ragged tensor. @@ -1287,8 +1089,7 @@ class RaggedTensor(composite_tensor.CompositeTensor): tf.Tensor([0 4 4 7 8], shape=(5,), dtype=int64) """ - with ops.name_scope(name, "RaggedRowStarts", [self]): - return self.row_splits[:-1] + return self._row_partition.row_starts(name=name) def row_limits(self, name=None): """Returns the limit indices for rows in this ragged tensor. @@ -1312,8 +1113,7 @@ class RaggedTensor(composite_tensor.CompositeTensor): tf.Tensor([4 4 7 8 8], shape=(5,), dtype=int64) """ - with ops.name_scope(name, "RaggedRowLimits", [self]): - return self.row_splits[1:] + return self._row_partition.row_limits(name=name) def row_lengths(self, axis=1, name=None): """Returns the lengths of the rows in this ragged tensor. 
@@ -1342,8 +1142,11 @@ class RaggedTensor(composite_tensor.CompositeTensor): """ - if self._cached_row_lengths is not None: - return self._cached_row_lengths + if axis == 0: + return self._row_partition.nrows() + + if axis == 1: + return self._row_partition.row_lengths() with ops.name_scope(name, "RaggedRowLengths", [self]): axis = array_ops.get_positive_axis( @@ -1356,9 +1159,9 @@ class RaggedTensor(composite_tensor.CompositeTensor): elif isinstance(self.values, RaggedTensor): return self.with_values(self.values.row_lengths(axis - 1)) else: - shape = array_ops.shape(self.values, out_type=self._row_splits.dtype) + shape = array_ops.shape(self.values, out_type=self.row_partition.dtype) return self.with_values( - array_ops.ones(shape[:axis - 1], self._row_splits.dtype) * + array_ops.ones(shape[:axis - 1], self.row_partition.dtype) * shape[axis - 1]) def nested_row_lengths(self, name=None): @@ -1408,7 +1211,7 @@ class RaggedTensor(composite_tensor.CompositeTensor): """ if out_type is None: - out_type = self._row_splits.dtype + out_type = self.row_partition.dtype else: out_type = dtypes.as_dtype(out_type) with ops.name_scope(name, "RaggedBoundingBox", [self, axis]): @@ -1456,7 +1259,7 @@ class RaggedTensor(composite_tensor.CompositeTensor): new_values.shape.with_rank_at_least(1) self.values.shape[:1].assert_is_compatible_with(new_values.shape[:1]) if (isinstance(new_values, RaggedTensor) and - self._row_splits.dtype != new_values.row_splits.dtype): + self.row_partition.row_splits.dtype != new_values.row_splits.dtype): if not ragged_config.auto_cast_partition_dtype(): raise ValueError("self and new_values have mismatched row_splits " "dtypes; use RaggedTensor.with_row_splits_dtype() to " @@ -1464,13 +1267,7 @@ class RaggedTensor(composite_tensor.CompositeTensor): new_values = new_values.with_row_splits_dtype(dtypes.int64) return self.with_row_splits_dtype(dtypes.int64).with_values(new_values) return RaggedTensor( - values=new_values, - row_splits=self._row_splits, - 
cached_row_lengths=self._cached_row_lengths, - cached_value_rowids=self._cached_value_rowids, - cached_nrows=self._cached_nrows, - internal=True, - uniform_row_length=self._uniform_row_length) + values=new_values, row_partition=self._row_partition, internal=True) def with_flat_values(self, new_values): """Returns a copy of `self` with `flat_values` replaced by `new_value`. @@ -1480,8 +1277,8 @@ class RaggedTensor(composite_tensor.CompositeTensor): Args: new_values: Potentially ragged tensor that should replace - `self.flat_values`. Must have `rank > 0`, and must have the same - number of rows as `self.flat_values`. + `self.flat_values`. Must have `rank > 0`, and must have the same number + of rows as `self.flat_values`. Returns: A `RaggedTensor`. @@ -1509,30 +1306,19 @@ class RaggedTensor(composite_tensor.CompositeTensor): dtype = dtypes.as_dtype(dtype) if dtype not in (dtypes.int32, dtypes.int64): raise ValueError("dtype must be int32 or int64") - if self._row_splits.dtype == dtype: + if self._row_partition.dtype == dtype: return self - - row_splits = math_ops.cast(self._row_splits, dtype) - - values = self._values - if isinstance(values, RaggedTensor): - values = values.with_row_splits_dtype(dtype) - cached_row_lengths = self._cached_row_lengths - if cached_row_lengths is not None: - cached_row_lengths = math_ops.cast(cached_row_lengths, dtype) - cached_value_rowids = self._cached_value_rowids - if cached_value_rowids is not None: - cached_value_rowids = math_ops.cast(cached_value_rowids, dtype) - cached_nrows = self._cached_nrows - if cached_value_rowids is not None: - cached_value_rowids = math_ops.cast(cached_value_rowids, dtype) - uniform_row_length = self._uniform_row_length - if uniform_row_length is not None: - uniform_row_length = math_ops.cast(uniform_row_length, dtype) - - return RaggedTensor(values, row_splits, cached_row_lengths, - cached_value_rowids, cached_nrows, internal=True, - uniform_row_length=uniform_row_length) + current_values = 
self._values + if isinstance(current_values, RaggedTensor): + return RaggedTensor( + values=current_values.with_row_splits_dtype(dtype), + row_partition=self._row_partition.with_row_splits_dtype(dtype), + internal=True) + else: + return RaggedTensor( + values=current_values, + row_partition=self._row_partition.with_row_splits_dtype(dtype), + internal=True) def merge_dims(self, outer_axis, inner_axis): """Merges outer_axis...inner_axis into a single dimension. @@ -1558,8 +1344,8 @@ class RaggedTensor(composite_tensor.CompositeTensor): Args: outer_axis: `int`: The first dimension in the range of dimensions to merge. May be negative if `self.shape.rank` is statically known. - inner_axis: `int`: The last dimension in the range of dimensions to - merge. May be negative if `self.shape.rank` is statically known. + inner_axis: `int`: The last dimension in the range of dimensions to merge. + May be negative if `self.shape.rank` is statically known. Returns: A copy of this tensor, with the specified dimensions merged into a @@ -1630,9 +1416,9 @@ class RaggedTensor(composite_tensor.CompositeTensor): `Tensor` whose length is equal to `tensor.shape[0]` (the number of rows in `tensor`). If specified, then `output[row]` will contain `tensor[row][:lengths[row]]`. Negative lengths are treated as zero. You - may optionally pass a list or tuple of lengths to this argument, which - will be used as nested row lengths to construct a ragged tensor with - multiple ragged dimensions. + may optionally pass a list or tuple of lengths to this argument, which + will be used as nested row lengths to construct a ragged tensor with + multiple ragged dimensions. padding: An optional padding value. If specified, then any row suffix consisting entirely of `padding` will be excluded from the returned RaggedTensor. 
`padding` is a `Tensor` with the same dtype as `tensor` @@ -1655,8 +1441,8 @@ class RaggedTensor(composite_tensor.CompositeTensor): if not isinstance(ragged_rank, int): raise TypeError("ragged_rank expected int, got %r" % ragged_rank) if ragged_rank <= 0: - raise ValueError( - "ragged_rank must be greater than 0; got %s" % ragged_rank) + raise ValueError("ragged_rank must be greater than 0; got %s" % + ragged_rank) with ops.name_scope(name, "RaggedFromTensor", [tensor, lengths, padding]): tensor = ops.convert_to_tensor(tensor, name="tensor") @@ -1685,8 +1471,7 @@ class RaggedTensor(composite_tensor.CompositeTensor): ones_mask, lengths, validate=False) dense_ragged_mask = ragged_mask.to_tensor(default_value=False) masked_data = array_ops.boolean_mask(tensor, dense_ragged_mask) - return cls.from_nested_row_lengths( - masked_data, lengths, validate=False) + return cls.from_nested_row_lengths(masked_data, lengths, validate=False) # Handle ragged_rank>1 via recursion: # If the output should have multiple ragged dimensions, then first @@ -1706,8 +1491,8 @@ class RaggedTensor(composite_tensor.CompositeTensor): axis=0) dim_size = math_ops.cumprod(input_shape) flattened = array_ops.reshape(tensor, new_shape) - result = cls.from_tensor(flattened, lengths, padding, - row_splits_dtype=row_splits_dtype) + result = cls.from_tensor( + flattened, lengths, padding, row_splits_dtype=row_splits_dtype) for axis in range(ragged_rank - 1, 0, -1): dim_len = tensor_shape.dimension_at_index(tensor.shape, axis).value @@ -1736,7 +1521,7 @@ class RaggedTensor(composite_tensor.CompositeTensor): # If the padding isn't a scalar, then require that all values in the # padding match each item in the tensor. After this block of code, # `has_default.shape = tensor.shape[:2]`. 
(Unfortunately, we can't just - # use reduce_all for both cases, because when you pass an empty `axis` + # use reduce_all for both cases, becaue when you pass an empty `axis` # list to reduce_all, it reduces all axes; but we want it to reduce no # axes -- i.e., to be a no-op.) tensor_rank = array_ops.rank(tensor) @@ -1755,8 +1540,8 @@ class RaggedTensor(composite_tensor.CompositeTensor): has_nondefault = math_ops.logical_not(has_default) has_nondefault = math_ops.cast(has_nondefault, row_splits_dtype) length_for_nondefault_value = ( - has_nondefault * array_ops.expand_dims( - math_ops.range(1, ncols + 1), 0)) + has_nondefault * + array_ops.expand_dims(math_ops.range(1, ncols + 1), 0)) lengths = math_ops.reduce_max(length_for_nondefault_value, axis=1) if lengths is not None: @@ -1791,8 +1576,7 @@ class RaggedTensor(composite_tensor.CompositeTensor): else: ncols = input_shape[1] return RaggedTensor.from_uniform_row_length( - values=values, uniform_row_length=ncols, - nrows=nrows, validate=False) + values=values, uniform_row_length=ncols, nrows=nrows, validate=False) def to_tensor(self, default_value=None, name=None, shape=None): """Converts this `RaggedTensor` into a `tf.Tensor`. @@ -1968,10 +1752,10 @@ class RaggedTensor(composite_tensor.CompositeTensor): nested-batched) `RaggedTensor`. dtype: The dtype of the encoded `RaggedTensor`. output_ragged_rank: The expected ragged rank of the output `RaggedTensor`. - input_ragged_rank: The ragged rank of each encoded `RaggedTensor`. This - is optional and inferred dynamically if not provided. - row_splits_dtype: `dtype` for the RaggedTensor's `row_splits` tensor. - One of `tf.int32` or `tf.int64`. + input_ragged_rank: The ragged rank of each encoded `RaggedTensor`. This is + optional and inferred dynamically if not provided. + row_splits_dtype: `dtype` for the RaggedTensor's `row_splits` tensor. One + of `tf.int32` or `tf.int64`. name: A name prefix for the returned tensors (optional). 
Returns: @@ -2038,8 +1822,8 @@ class RaggedTensor(composite_tensor.CompositeTensor): if self._is_eager(): return "" % self.to_list() else: - return "tf.RaggedTensor(values=%s, row_splits=%s)" % (self._values, - self._row_splits) + return "tf.RaggedTensor(values=%s, row_splits=%s)" % ( + self._values, self.row_partition.row_splits) #============================================================================= # Eager Execution Mode @@ -2075,7 +1859,7 @@ class RaggedTensor(composite_tensor.CompositeTensor): if not self._is_eager(): raise ValueError("RaggedTensor.numpy() is only supported in eager mode.") values = self._values.numpy() - splits = self._row_splits.numpy() + splits = self._row_partition.row_splits.numpy() rows = [values[splits[i]:splits[i + 1]] for i in range(len(splits) - 1)] if not rows: return np.zeros((0, 0) + values.shape[1:], dtype=values.dtype) @@ -2152,7 +1936,7 @@ class RaggedTensor(composite_tensor.CompositeTensor): shape=self.shape, dtype=self.dtype, ragged_rank=self.ragged_rank, - row_splits_dtype=self._row_splits.dtype) + row_splits_dtype=self.row_partition.row_splits.dtype) def _shape_invariant_to_type_spec(self, shape): return RaggedTensorSpec(shape, self.dtype, self.ragged_rank, @@ -2176,6 +1960,7 @@ def match_row_splits_dtypes(*tensors, **kwargs): **kwargs: If 'return_dtype=True', then return a tuple (dtype, tensors), where `dtype` is the data type used by row-splits, and `tensors` is the converted list of `Tensors` and `RaggedTensors`. + Returns: The converted list of `Tensors` and `RaggedTensors`. 
""" @@ -2198,8 +1983,10 @@ def match_row_splits_dtypes(*tensors, **kwargs): "use RaggedTensor.with_row_splits_dtype() to convert " "them to compatible dtypes.") dtype = dtypes.int64 - tensors = tuple(t.with_row_splits_dtype(dtypes.int64) - if isinstance(t, RaggedTensor) else t for t in tensors) + tensors = tuple( + t.with_row_splits_dtype(dtypes.int64) if isinstance(t, RaggedTensor + ) else t + for t in tensors) elif has_int32: dtype = dtypes.int32 @@ -2225,18 +2012,21 @@ class RaggedTensorSpec(type_spec.BatchableTypeSpec): def value_type(self): return RaggedTensor if self._ragged_rank > 0 else ops.Tensor - def __init__(self, shape=None, dtype=dtypes.float32, ragged_rank=None, + def __init__(self, + shape=None, + dtype=dtypes.float32, + ragged_rank=None, row_splits_dtype=dtypes.int64): """Constructs a type specification for a `tf.RaggedTensor`. Args: - shape: The shape of the RaggedTensor, or `None` to allow any shape. If - a shape is specified, then all ragged dimensions must have size `None`. + shape: The shape of the RaggedTensor, or `None` to allow any shape. If a + shape is specified, then all ragged dimensions must have size `None`. dtype: `tf.DType` of values in the RaggedTensor. - ragged_rank: Python integer, the ragged rank of the RaggedTensor - to be described. Defaults to `shape.ndims - 1`. - row_splits_dtype: `dtype` for the RaggedTensor's `row_splits` tensor. - One of `tf.int32` or `tf.int64`. + ragged_rank: Python integer, the ragged rank of the RaggedTensor to be + described. Defaults to `shape.ndims - 1`. + row_splits_dtype: `dtype` for the RaggedTensor's `row_splits` tensor. One + of `tf.int32` or `tf.int64`. 
""" self._shape = tensor_shape.as_shape(shape) self._dtype = dtypes.as_dtype(dtype) @@ -2278,10 +2068,10 @@ class RaggedTensorSpec(type_spec.BatchableTypeSpec): outer_splits_shape = [None if outer_dim is None else outer_dim + 1] inner_splits_spec = tensor_spec.TensorSpec([None], self._row_splits_dtype) - specs = ( - [tensor_spec.TensorSpec(flat_values_shape, self._dtype), - tensor_spec.TensorSpec(outer_splits_shape, self._row_splits_dtype)] + - [inner_splits_spec for _ in range(self._ragged_rank - 1)]) + specs = ([ + tensor_spec.TensorSpec(flat_values_shape, self._dtype), + tensor_spec.TensorSpec(outer_splits_shape, self._row_splits_dtype) + ] + [inner_splits_spec for _ in range(self._ragged_rank - 1)]) return specs def _to_components(self, value): @@ -2301,7 +2091,10 @@ class RaggedTensorSpec(type_spec.BatchableTypeSpec): tensor_list = [ops.convert_to_tensor(t) for t in tensor_list] result = tensor_list[0] for row_splits in reversed(tensor_list[1:]): - result = RaggedTensor(result, row_splits, internal=True) + result = RaggedTensor( + result, + RowPartition.from_row_splits(row_splits, validate=False), + internal=True) return result # The RaggedTensorSpec tensor_list encoding uses to/from_variant ops @@ -2342,10 +2135,11 @@ class RaggedTensorSpec(type_spec.BatchableTypeSpec): def _from_compatible_tensor_list(self, tensor_list): if self._ragged_rank < 0: - raise ValueError( - "ragged_rank must be non-negative; got %s." % self._ragged_rank) + raise ValueError("ragged_rank must be non-negative; got %s." 
% + self._ragged_rank) result = RaggedTensor._from_variant( # pylint: disable=protected-access - tensor_list[0], dtype=self._dtype, + tensor_list[0], + dtype=self._dtype, row_splits_dtype=self._row_splits_dtype, output_ragged_rank=self._ragged_rank) if self._shape.ndims is not None: @@ -2384,10 +2178,11 @@ class RaggedTensorSpec(type_spec.BatchableTypeSpec): @classmethod def from_value(cls, value): - return cls(shape=value.shape, - dtype=value.values.dtype, - ragged_rank=value.ragged_rank, - row_splits_dtype=value.row_splits.dtype) + return cls( + shape=value.shape, + dtype=value.values.dtype, + ragged_rank=value.ragged_rank, + row_splits_dtype=value.row_splits.dtype) type_spec.register_type_spec_from_value_converter( @@ -2654,28 +2449,22 @@ def _prod(lst): return functools.reduce(operator.mul, lst, 1) -def _get_row_partition_type_tensor_pairs_tail(rt_value): - """Gets a list of the row partitions for rt_value. +def _get_row_partition_type_tensor_pairs_tail(partition): + """Gets a row partition type tensor pair for the tail. - If parent_indices are defined, then they are used. Otherwise, row_splits + If value_rowid is defined, then it is used. Otherwise, row_splits are used. - This assumes that rt_input is nested inside another RaggedTensor. If it is - a tensor, then return an empty list. - Args: - rt_value: a ragged tensor value. May be a tensor. + partition: a RowPartition. Returns: A list of (row_partition_type, row_partition_tensor) pairs. 
""" - if isinstance(rt_value, RaggedTensor): - tail = _get_row_partition_type_tensor_pairs_tail(rt_value.values) - if rt_value._cached_value_rowids is not None: # pylint: disable=protected-access - return [("VALUE_ROWIDS", rt_value.value_rowids())] + tail - else: - return [("ROW_SPLITS", rt_value.row_splits)] + tail - return [] + if partition.has_cached_value_rowids(): + return ("VALUE_ROWIDS", partition.value_rowids()) + else: + return ("ROW_SPLITS", partition.row_splits) def _get_row_partition_type_tensor_pairs(rt_input): @@ -2691,12 +2480,22 @@ def _get_row_partition_type_tensor_pairs(rt_input): Returns: A list of (row_partition_type, row_partition_tensor) pairs. """ - tail = _get_row_partition_type_tensor_pairs_tail(rt_input.values) - if rt_input._cached_value_rowids is not None: # pylint: disable=protected-access - return [("FIRST_DIM_SIZE", rt_input.nrows()), - ("VALUE_ROWIDS", rt_input.value_rowids())] + tail + partitions = _get_nested_row_partitions(rt_input) + tail = [_get_row_partition_type_tensor_pairs_tail(x) for x in partitions[1:]] + + if partitions[0]._cached_value_rowids is not None: # pylint: disable=protected-access + return [("FIRST_DIM_SIZE", partitions[0].nrows()), + ("VALUE_ROWIDS", partitions[0].value_rowids())] + tail else: - return [("ROW_SPLITS", rt_input.row_splits)] + tail + return [("ROW_SPLITS", partitions[0].row_splits)] + tail + + +def _get_nested_row_partitions(rt_input): + """Returns the ragged partitions from all levels of a tree.""" + if isinstance(rt_input.values, RaggedTensor): + return [rt_input.row_partition] + _get_nested_row_partitions( + rt_input.values) + return [rt_input.row_partition] def _shape_as_tensor(shape, dtype): @@ -2738,4 +2537,23 @@ def _shape_as_tensor(shape, dtype): return constant_op.constant(shape, dtype=dtype) +def _nvals_uniform_row_length(values, uniform_row_length): + """Get the number of values for uniform row length constructor.""" + const_nvals = tensor_shape.dimension_at_index(values.shape, 
0).value + if const_nvals is not None: + nvals = constant_op.constant(const_nvals, uniform_row_length.dtype) + elif isinstance(values, RaggedTensor): + nvals = values.nrows(out_type=uniform_row_length.dtype) + else: + nvals = array_ops.shape(values, out_type=uniform_row_length.dtype)[0] + return nvals + + +def _get_optional_partition_dtype(values): + """Returns the partition dtype, or None if None exists.""" + if isinstance(values, RaggedTensor): + return values.row_partition.dtype + return None + + ops.no_gradient("RaggedTensorToVariant") diff --git a/tensorflow/python/ops/ragged/ragged_tensor_test.py b/tensorflow/python/ops/ragged/ragged_tensor_test.py index abd991dfce6..8a39dfc8db9 100644 --- a/tensorflow/python/ops/ragged/ragged_tensor_test.py +++ b/tensorflow/python/ops/ragged/ragged_tensor_test.py @@ -40,6 +40,8 @@ from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.ops.ragged import ragged_tensor_value from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensor from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensorSpec +from tensorflow.python.ops.ragged.row_partition import RowPartition + from tensorflow.python.platform import googletest @@ -128,8 +130,7 @@ def int32array(values): @test_util.run_all_in_graph_and_eager_modes -class RaggedTensorTest(test_util.TensorFlowTestCase, - parameterized.TestCase): +class RaggedTensorTest(test_util.TensorFlowTestCase, parameterized.TestCase): longMessage = True # Property in unittest.Testcase. 
pylint: disable=invalid-name #============================================================================= @@ -161,25 +162,22 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, outer_rt = RaggedTensor.from_row_splits( values=inner_rt, row_splits=[0, 3, 3, 5]) self.assertEqual(outer_rt.ragged_rank, 2) - self.assertAllEqual( - outer_rt, - [[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]]) + self.assertAllEqual(outer_rt, + [[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]]) del inner_rt, outer_rt # From section: "Multiple Ragged Dimensions" rt = RaggedTensor.from_nested_row_splits( flat_values=[3, 1, 4, 1, 5, 9, 2, 6], nested_row_splits=([0, 3, 3, 5], [0, 4, 4, 7, 8, 8])) - self.assertAllEqual( - rt, [[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]]) + self.assertAllEqual(rt, [[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]]) del rt # From section: "Uniform Inner Dimensions" rt = RaggedTensor.from_row_splits( values=array_ops.ones([5, 3]), row_splits=[0, 2, 5]) self.assertAllEqual( - rt, - [[[1, 1, 1], [1, 1, 1]], [[1, 1, 1], [1, 1, 1], [1, 1, 1]]]) + rt, [[[1, 1, 1], [1, 1, 1]], [[1, 1, 1], [1, 1, 1], [1, 1, 1]]]) self.assertEqual(rt.shape.as_list(), [2, None, 3]) del rt @@ -223,42 +221,29 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, def testRaggedTensorConstruction(self): values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g']) row_splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64) - rt = RaggedTensor(values=values, row_splits=row_splits, internal=True) + rp = RowPartition(row_splits=row_splits, internal=True) + rt = RaggedTensor(values=values, row_partition=rp, internal=True) - self.assertAllEqual( - rt, - [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']]) + self.assertAllEqual(rt, + [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']]) def testRaggedTensorConstructionErrors(self): values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g']) row_splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64) + rp = 
RowPartition(row_splits=row_splits, internal=True) with self.assertRaisesRegexp(ValueError, 'RaggedTensor constructor is private'): - RaggedTensor(values=values, row_splits=row_splits) + RaggedTensor(values=values, row_partition=rp) with self.assertRaisesRegexp(TypeError, 'values must be a Tensor or RaggedTensor'): - RaggedTensor(values=range(7), row_splits=row_splits, internal=True) + RaggedTensor(values=range(7), row_partition=rp, internal=True) with self.assertRaisesRegexp(TypeError, - 'Row-partitioning argument must be a Tensor'): - RaggedTensor(values=values, row_splits=[0, 2, 2, 5, 6, 7], internal=True) - - with self.assertRaisesRegexp(ValueError, - r'Shape \(6, 1\) must have rank 1'): - RaggedTensor( - values=values, - row_splits=array_ops.expand_dims(row_splits, 1), - internal=True) - - with self.assertRaisesRegexp(TypeError, - 'Cached value must be a Tensor or None.'): - RaggedTensor( - values=values, - row_splits=row_splits, - cached_row_lengths=[2, 3, 4], - internal=True) + 'row_partition must be a RowPartition'): + RaggedTensor(values=values, row_partition=[0, 2, 2, 5, 6, 7], + internal=True) #============================================================================= # RaggedTensor Factory Ops @@ -282,9 +267,8 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, self.assertIs(rt_value_rowids, value_rowids) # cached_value_rowids self.assertAllEqual(rt_value_rowids, value_rowids) self.assertAllEqual(rt_nrows, 5) - self.assertAllEqual( - rt, - [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']]) + self.assertAllEqual(rt, + [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']]) def testFromValueRowIdsWithDerivedNRowsDynamic(self): # nrows is not known at graph creation time. 
@@ -308,17 +292,16 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, self.assertIs(rt_value_rowids, value_rowids) # cached_value_rowids self.assertAllEqual(rt_value_rowids, value_rowids) self.assertAllEqual(rt_nrows, 5) - self.assertAllEqual( - rt, - [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']]) + self.assertAllEqual(rt, + [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']]) def testFromValueRowIdsWithExplicitNRows(self): values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g']) value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64) nrows = constant_op.constant(7, dtypes.int64) - rt = RaggedTensor.from_value_rowids(values, value_rowids, nrows, - validate=False) + rt = RaggedTensor.from_value_rowids( + values, value_rowids, nrows, validate=False) self.assertEqual(rt.dtype, dtypes.string) self.assertEqual(rt.shape.as_list(), [7, None]) self.assertEqual(rt.ragged_rank, 1) @@ -331,16 +314,15 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, self.assertIs(rt_value_rowids, value_rowids) # cached_value_rowids self.assertIs(rt_nrows, nrows) # cached_nrows self.assertAllEqual( - rt, - [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g'], [], []]) + rt, [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g'], [], []]) def testFromValueRowIdsWithExplicitNRowsEqualToDefault(self): values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g']) value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64) nrows = constant_op.constant(5, dtypes.int64) - rt = RaggedTensor.from_value_rowids(values, value_rowids, nrows, - validate=False) + rt = RaggedTensor.from_value_rowids( + values, value_rowids, nrows, validate=False) self.assertEqual(rt.dtype, dtypes.string) self.assertEqual(rt.shape.as_list(), [5, None]) self.assertEqual(rt.ragged_rank, 1) @@ -354,9 +336,8 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, self.assertIs(rt_nrows, nrows) # cached_nrows self.assertAllEqual(rt_value_rowids, value_rowids) 
self.assertAllEqual(rt_nrows, nrows) - self.assertAllEqual( - rt, - [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']]) + self.assertAllEqual(rt, + [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']]) def testFromValueRowIdsWithEmptyValues(self): rt = RaggedTensor.from_value_rowids([], []) @@ -385,9 +366,8 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, self.assertIs(rt_values, values) self.assertIs(rt_row_splits, row_splits) self.assertAllEqual(rt_nrows, 5) - self.assertAllEqual( - rt, - [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']]) + self.assertAllEqual(rt, + [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']]) def testFromRowSplitsWithDifferentSplitTypes(self): values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g']) @@ -428,9 +408,8 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, self.assertIs(rt_values, values) self.assertAllEqual(rt_nrows, 5) self.assertAllEqual(rt_row_starts, row_starts) - self.assertAllEqual( - rt, - [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']]) + self.assertAllEqual(rt, + [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']]) def testFromRowLimits(self): values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g']) @@ -448,9 +427,8 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, self.assertIs(rt_values, values) self.assertAllEqual(rt_nrows, 5) self.assertAllEqual(rt_row_limits, row_limits) - self.assertAllEqual( - rt, - [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']]) + self.assertAllEqual(rt, + [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']]) def testFromRowLengths(self): values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g']) @@ -469,21 +447,27 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, self.assertIs(rt_row_lengths, row_lengths) # cached_nrows self.assertAllEqual(rt_nrows, 5) self.assertAllEqual(rt_row_lengths, row_lengths) - self.assertAllEqual( - rt, - [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']]) + 
self.assertAllEqual(rt, + [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']]) + + def testFromRowLengthsInt32(self): + rt = RaggedTensor.from_row_lengths([1, 2, 3, 4], + constant_op.constant([1, 0, 3], + dtype=dtypes.int32)) + rt2 = RaggedTensor.from_row_lengths(rt, [2, 1, 0]) + self.assertAllEqual([2, 1, 0], rt2.row_lengths()) def testFromUniformRowLength(self): values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16] a1 = RaggedTensor.from_uniform_row_length(values, 2) a2 = RaggedTensor.from_uniform_row_length(values, 2, 8) - self.assertAllEqual(a1, [[1, 2], [3, 4], [5, 6], [7, 8], - [9, 10], [11, 12], [13, 14], [15, 16]]) + self.assertAllEqual( + a1, + [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]]) self.assertAllEqual(a1, a2) self.assertEqual(a1.shape.as_list(), [8, 2]) self.assertEqual(a2.shape.as_list(), [8, 2]) - self.assertAllEqual(a1.uniform_row_length, 2) b1 = RaggedTensor.from_uniform_row_length(a1, 2) b2 = RaggedTensor.from_uniform_row_length(a1, 2, 4) @@ -492,7 +476,6 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, self.assertAllEqual(b1, b2) self.assertEqual(b1.shape.as_list(), [4, 2, 2]) self.assertEqual(b2.shape.as_list(), [4, 2, 2]) - self.assertAllEqual(b1.uniform_row_length, 2) c1 = RaggedTensor.from_uniform_row_length(b1, 2) c2 = RaggedTensor.from_uniform_row_length(b1, 2, 2) @@ -501,13 +484,11 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, self.assertAllEqual(c1, c2) self.assertEqual(c1.shape.as_list(), [2, 2, 2, 2]) self.assertEqual(c2.shape.as_list(), [2, 2, 2, 2]) - self.assertAllEqual(c1.uniform_row_length, 2) def testFromUniformRowLengthWithEmptyValues(self): empty_values = [] a = RaggedTensor.from_uniform_row_length(empty_values, 0, nrows=10) self.assertEqual(a.shape.as_list(), [10, 0]) - self.assertAllEqual(a.uniform_row_length, 0) b = RaggedTensor.from_uniform_row_length(a, 2) self.assertEqual(b.shape.as_list(), [5, 2, 0]) @@ -570,8 +551,7 @@ class 
RaggedTensorTest(test_util.TensorFlowTestCase, self.assertAllEqual(rt_value_rowids, nested_value_rowids[0]) self.assertAllEqual(rt_values_value_rowids, nested_value_rowids[1]) self.assertAllEqual( - rt, - [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]]) + rt, [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]]) def testFromNestedValueRowIdsWithExplicitNRows(self): values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g']) @@ -602,9 +582,8 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, self.assertAllEqual(rt_values_value_rowids, nested_value_rowids[1]) self.assertAllEqual(rt_nrows, nrows[0]) self.assertAllEqual(rt_values_nrows, nrows[1]) - self.assertAllEqual( - rt, [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], - [[b'f'], [b'g'], []], [], []]) + self.assertAllEqual(rt, [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], + [[b'f'], [b'g'], []], [], []]) def testFromNestedValueRowIdsWithExplicitNRowsMismatch(self): values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g']) @@ -635,8 +614,8 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64) ] - rt = RaggedTensor.from_nested_row_splits(flat_values, nested_row_splits, - validate=False) + rt = RaggedTensor.from_nested_row_splits( + flat_values, nested_row_splits, validate=False) self.assertEqual(rt.dtype, dtypes.string) self.assertEqual(rt.shape.as_list(), [4, None, None]) self.assertEqual(rt.ragged_rank, 2) @@ -650,8 +629,34 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, self.assertIs(rt_row_splits, nested_row_splits[0]) self.assertIs(rt_values_row_splits, nested_row_splits[1]) self.assertAllEqual( - rt, - [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]]) + rt, [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]]) + + def testWithRowSplits(self): + flat_values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g']) + nested_row_splits = [ + constant_op.constant([0, 2, 3, 
3, 5], dtypes.int64), + constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64) + ] + + rt = RaggedTensor.from_nested_row_splits( + flat_values, nested_row_splits, validate=False) + + rt = rt.with_row_splits_dtype(dtypes.int32) + + self.assertEqual(rt.dtype, dtypes.string) + self.assertEqual(rt.shape.as_list(), [4, None, None]) + self.assertEqual(rt.ragged_rank, 2) + + rt_values = rt.values + rt_row_splits = rt.row_splits + rt_values_values = rt_values.values + rt_values_row_splits = rt_values.row_splits + + self.assertAllEqual(rt_values_values, flat_values) + self.assertAllEqual(rt_row_splits, nested_row_splits[0]) + self.assertAllEqual(rt_values_row_splits, nested_row_splits[1]) + self.assertAllEqual( + rt, [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]]) def testFromNestedRowSplitsWithNonListInput(self): with self.assertRaisesRegexp(TypeError, @@ -747,16 +752,13 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, rt2 = RaggedTensor.from_value_rowids(values, value_rowids) for rt in [rt1, rt2]: - self.assertAllEqual( - rt, - [[[0, 1], [2, 3]], [], [[4, 5], [6, 7], [8, 9]], [[10, 11]], - [[12, 13]]]) + self.assertAllEqual(rt, [[[0, 1], [2, 3]], [], [[4, 5], [6, 7], [8, 9]], + [[10, 11]], [[12, 13]]]) self.assertAllEqual( rt.values, [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13]]) self.assertEqual(rt.values.shape.dims[0].value, 7) - self.assertAllEqual( - rt.value_rowids(), [0, 0, 2, 2, 2, 3, 4]) + self.assertAllEqual(rt.value_rowids(), [0, 0, 2, 2, 2, 3, 4]) self.assertAllEqual(rt.nrows(), 5) self.assertAllEqual(rt.row_splits, [0, 2, 2, 5, 6, 7]) self.assertAllEqual(rt.row_starts(), [0, 2, 2, 5, 6]) @@ -786,11 +788,9 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, for rt in [rt1, rt2]: self.assertAllEqual( - rt, - [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]]) + rt, [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]]) self.assertAllEqual( - rt.values, - [[b'a', b'b'], [], [b'c', b'd', b'e'], 
[b'f'], [b'g']]) + rt.values, [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']]) self.assertEqual(rt.values.shape.dims[0].value, 5) self.assertAllEqual(rt.value_rowids(), [0, 0, 1, 3, 3]) self.assertAllEqual(rt.nrows(), 4) @@ -798,9 +798,8 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, self.assertAllEqual(rt.row_starts(), [0, 2, 3, 3]) self.assertAllEqual(rt.row_limits(), [2, 3, 3, 5]) self.assertAllEqual(rt.row_lengths(), [2, 1, 0, 2]) - self.assertAllEqual( - rt.flat_values, - [b'a', b'b', b'c', b'd', b'e', b'f', b'g']) + self.assertAllEqual(rt.flat_values, + [b'a', b'b', b'c', b'd', b'e', b'f', b'g']) self.assertLen(rt.nested_row_splits, 2) self.assertAllEqual(rt.nested_row_splits[0], [0, 2, 3, 3, 5]) self.assertAllEqual(rt.nested_row_splits[1], [0, 2, 2, 5, 6, 7]) @@ -1024,8 +1023,8 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, 'slice offsets must be integers or None'), # Tests for other errors - (SLICE_BUILDER[..., 0, 0, 0], IndexError, - 'Too many indices for RaggedTensor'), + (SLICE_BUILDER[..., 0, 0, + 0], IndexError, 'Too many indices for RaggedTensor'), ) def testRaggedTensorGetItemErrorsWithRaggedRank1(self, slice_spec, expected, message): @@ -1106,9 +1105,8 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, [[v[::-2] for v in row] for row in EXAMPLE_RAGGED_TENSOR_4D]), (SLICE_BUILDER[..., ::-1, :], [[v[::-1] for v in row] for row in EXAMPLE_RAGGED_TENSOR_4D]), - (SLICE_BUILDER[..., ::-1], - [[[v[::-1] for v in col] for col in row] - for row in EXAMPLE_RAGGED_TENSOR_4D]), + (SLICE_BUILDER[..., ::-1], [[[v[::-1] for v in col] for col in row] + for row in EXAMPLE_RAGGED_TENSOR_4D]), ) def testRaggedTensorGetItemWithRaggedRank2(self, slice_spec, expected): """Test that rt.__getitem__(slice_spec) == expected.""" @@ -1212,11 +1210,9 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, rt_newaxis4 = rt[:, :, :, :, array_ops.newaxis] self.assertAllEqual( - rt, - [[[[b'a', b'b'], [b'c', b'd']], [], [[b'e', b'f']]], []]) + rt, 
[[[[b'a', b'b'], [b'c', b'd']], [], [[b'e', b'f']]], []]) self.assertAllEqual( - rt_newaxis0, - [[[[[b'a', b'b'], [b'c', b'd']], [], [[b'e', b'f']]], []]]) + rt_newaxis0, [[[[[b'a', b'b'], [b'c', b'd']], [], [[b'e', b'f']]], []]]) self.assertAllEqual( rt_newaxis1, [[[[[b'a', b'b'], [b'c', b'd']], [], [[b'e', b'f']]]], [[]]]) @@ -1330,9 +1326,10 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, else: expected_repr = ( 'tf.RaggedTensor(values=Tensor("RaggedFromRowSplits/values:0", ' - 'shape=(7,), dtype=string), row_splits=' - 'Tensor("RaggedFromRowSplits/row_splits:0", ' - 'shape=(6,), dtype={}))').format(splits_type) + 'shape=(7,), dtype=string), ' + 'row_splits=Tensor(' + '"RaggedFromRowSplits/RowPartitionFromRowSplits/row_splits:0",' + ' shape=(6,), dtype={}))').format(splits_type) self.assertEqual(repr(rt), expected_repr) self.assertEqual(str(rt), expected_repr) @@ -1362,15 +1359,11 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, rt2_times_10 = rt2.with_flat_values(rt2.flat_values * 10) rt1_expanded = rt1.with_values(array_ops.expand_dims(rt1.values, axis=1)) - self.assertAllEqual( - rt1_plus_10, - [[11, 12], [13, 14, 15], [16], [], [17]]) - self.assertAllEqual( - rt2_times_10, - [[[10, 20], [30, 40, 50]], [[60]], [], [[], [70]]]) - self.assertAllEqual( - rt1_expanded, - [[[1], [2]], [[3], [4], [5]], [[6]], [], [[7]]]) + self.assertAllEqual(rt1_plus_10, [[11, 12], [13, 14, 15], [16], [], [17]]) + self.assertAllEqual(rt2_times_10, + [[[10, 20], [30, 40, 50]], [[60]], [], [[], [70]]]) + self.assertAllEqual(rt1_expanded, + [[[1], [2]], [[3], [4], [5]], [[6]], [], [[7]]]) #============================================================================= # Session.run @@ -1465,6 +1458,99 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, ragged_math_ops.reduce_sum(a) self.assertLen(a.consumers(), 1) + @parameterized.parameters([ + { + 'descr': 'from_value_rowids', + 'factory': RaggedTensor.from_value_rowids, + 'test': RaggedTensor.value_rowids, + 
'values': { + 'values': [1, 2, 3, 4, 5, 6], + 'value_rowids': [0, 0, 1, 1, 2, 2], + }, + 'tensor_field': 'value_rowids', + 'value_rowids': [0, 1, 2], + 'nrows': 10 + }, + { + 'descr': 'from_row_splits', + 'factory': RaggedTensor.from_row_splits, + # row_splits is a property, not a function. + 'test': (lambda rt: rt.row_splits), + 'values': { + 'values': [1, 2, 3, 4, 5, 6], + 'row_splits': [0, 2, 4, 6], + }, + 'tensor_field': 'row_splits', + 'row_splits': [0, 1, 2, 3] + }, + { + 'descr': 'from_row_lengths', + 'factory': RaggedTensor.from_row_lengths, + 'test': RaggedTensor.row_lengths, + 'values': { + 'values': [1, 2, 3, 4, 5, 6], + 'row_lengths': [2, 2, 2], + }, + 'tensor_field': 'row_lengths', + 'row_lengths': [1, 1, 1], + }, + # from_row_starts + { + 'descr': 'from_row_starts', + 'factory': RaggedTensor.from_row_starts, + 'test': RaggedTensor.row_starts, + 'values': { + 'values': [1, 2, 3, 4, 5, 6], + 'row_starts': [0, 2, 4] + }, + 'tensor_field': 'row_starts', + 'row_starts': [0, 1, 2] + }, + # from_row_limits + { + 'descr': 'from_row_limits', + 'factory': RaggedTensor.from_row_limits, + 'test': RaggedTensor.row_limits, + 'values': { + 'values': [1, 2, 3, 4, 5, 6], + 'row_limits': [2, 4, 6] + }, + 'tensor_field': 'row_limits', + 'row_limits': [3] + }, + # from_uniform_row_length + { + 'descr': 'from_uniform_row_length', + 'factory': RaggedTensor.from_uniform_row_length, + # One cannot extract uniform_row_length or nvals, so we return + # nvals//nrows = uniform_row_length, where nvals = 3 + 'test': (lambda rt: 3 // (rt.shape[0])), + 'values': { + 'values': [1, 2, 3, 4, 5, 6], + 'uniform_row_length': 2 + }, + 'tensor_field': 'uniform_row_length', + 'uniform_row_length': 3 + }, + ]) + def testFactoryTypePreference(self, descr, test, factory, values, + tensor_field, **kwargs): + # When input tensors have shape information, some of these errors will be + # detected statically. 
+ def op_cast(k, v): + if k == tensor_field: + return constant_op.constant(v, dtype=dtypes.int32) + else: + return v + + value_copy = {k: op_cast(k, v) for k, v in values.items()} + rt = factory(**value_copy) + + kw_copy = {k: v for k, v in kwargs.items()} + kw_copy['values'] = rt + rt2 = factory(**kw_copy) + self.assertAllEqual(kwargs[tensor_field], test(rt2)) + @parameterized.parameters([ # from_value_rowids { @@ -1557,7 +1643,7 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, 'row_lengths': [[1, 2], [1, 0]] }, { - 'descr': 'negative row_lengths', + 'descr': 'negatve row_lengths', 'factory': RaggedTensor.from_row_lengths, 'values': [1, 2, 3, 4], 'row_lengths': [3, -1, 2] @@ -1678,18 +1764,21 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, with self.assertRaises((errors.InvalidArgumentError, ValueError)): self.evaluate(factory(**kwargs)) - # Remove shape information (by wraping tensors in placeholders), and check + # Remove shape information (by wrapping tensors in placeholders), and check # that we detect the errors when the graph is run. 
if not context.executing_eagerly(): + def wrap_arg(v): return array_ops.placeholder_with_default( constant_op.constant(v, dtype=dtypes.int64), tensor_shape.TensorShape(None)) + kwargs = dict((k, wrap_arg(v)) for (k, v) in kwargs.items()) with self.assertRaises(errors.InvalidArgumentError): self.evaluate(factory(**kwargs)) + #============================================================================= # RaggedTensor Variant conversion #============================================================================= @@ -2059,8 +2148,10 @@ class RaggedTensorSpecTest(test_util.TensorFlowTestCase, self.assertAllEqual(rt1, [[1, 2], [3]]) spec2 = RaggedTensorSpec(ragged_rank=2, dtype=dtypes.int32) - rt2 = spec2._from_components([np.array([1, 2, 3]), np.array([0, 2, 3]), - np.array([0, 0, 2, 3])]) + rt2 = spec2._from_components( + [np.array([1, 2, 3]), + np.array([0, 2, 3]), + np.array([0, 0, 2, 3])]) self.assertIsInstance(rt2, ragged_tensor_value.RaggedTensorValue) self.assertAllEqual(rt2, [[[], [1, 2]], [[3]]]) diff --git a/tensorflow/python/ops/ragged/row_partition.py b/tensorflow/python/ops/ragged/row_partition.py new file mode 100644 index 00000000000..db31af15455 --- /dev/null +++ b/tensorflow/python/ops/ragged/row_partition.py @@ -0,0 +1,843 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""An internal class for representing the partition in a ragged tensor.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_spec +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.ragged import segment_id_ops + +# pylint: disable=protected-access +_eval_using_default_session = ops._eval_using_default_session +# pylint: enable=protected-access + +#=============================================================================== +# RowPartition +#=============================================================================== + + +class RowPartition(object): + """Represents the partition of a ragged tensor. + + In particular, this provides a ragged representation to a flat list, + or a deeper ragged representation of a ragged tensor. However, it does + not store the values or the substructure: only the top-level representation + is represented. + + The canonical representation of a partition is row_splits, which indicates how + the flat values are divided into rows. In particular, the values for row + `rt[i]` are stored in the slice + `rt.values[rt.row_splits[i]:rt.row_splits[i+1]]`. + + ### Alternative Row-Partitioning Schemes + + In addition to `row_splits`, row partitions provide support for five other + partitioning schemes: + + * `row_lengths`: a vector with shape `[nrows]`, which specifies the length + of each row. 
+ + * `value_rowids` and `nrows`: `value_rowids` is a vector with shape + `[nvals]`, corresponding one-to-one with `values`, which specifies + each value's row index. In particular, the row `rt[row]` consists of the + values `rt.values[j]` where `value_rowids[j]==row`. `nrows` is an + integer scalar that specifies the number of rows in the + `RowPartition`. (`nrows` is used to indicate trailing empty rows.) + + * `row_starts` (and nvals): a vector with shape `[nrows]`, which specifies + the start offset of each row. Equivalent to `row_splits[:-1]`. + + * `row_limits`: a vector with shape `[nrows]`, which specifies the stop + offset of each row. Equivalent to `row_splits[1:]`. + + * `uniform_row_length` (and nvals): A scalar tensor, specifying the length + of every row. This row-partitioning scheme may only be used if all rows + have the same length. + + For examples, please see the documentation on RaggedTensor. + """ + + #============================================================================= + # Constructor (private) + #============================================================================= + def __init__(self, + row_splits, + cached_row_lengths=None, + cached_value_rowids=None, + cached_nrows=None, + internal=False, + uniform_row_length=None): + """Creates a `RowPartition` with a specified partitioning for `values`. + + This constructor is private -- please use one of the following ops to + build `RowPartition`s: + + * `RowPartition.from_row_lengths` + * `RowPartition.from_value_rowids` + * `RowPartition.from_row_splits` + * `RowPartition.from_row_starts` + * `RowPartition.from_row_limits` + + Args: + row_splits: A 1-D integer tensor with shape `[nrows+1]`. + cached_row_lengths: A 1-D integer tensor with shape `[nrows]` + cached_value_rowids: A 1-D integer tensor with shape `[nvals]`. + cached_nrows: A 1-D integer scalar tensor. + internal: True if the constructor is being called by one of the factory + methods. If false, an exception will be raised. 
+ uniform_row_length: A scalar tensor. + + Raises: + TypeError: If a row partitioning tensor has an inappropriate dtype. + TypeError: If exactly one row partitioning argument was not specified. + ValueError: If a row partitioning tensor has an inappropriate shape. + ValueError: If multiple partitioning arguments are specified. + ValueError: If nrows is specified but value_rowids is not None. + """ + if not internal: + raise ValueError("RaggedTensor constructor is private; please use one " + "of the factory methods instead (e.g., " + "RaggedTensor.from_row_lengths())") + + # Validate the arguments. + if not isinstance(row_splits, ops.Tensor): + raise TypeError("Row-partitioning argument must be a Tensor, got %r" % + row_splits) + if row_splits.dtype not in (dtypes.int32, dtypes.int64): + raise ValueError("Row-partitioning argument must be int32 or int64") + + # Validate shapes & dtypes. + row_splits.shape.assert_has_rank(1) + row_splits.set_shape([None]) + self._row_splits = row_splits + + # Store any cached tensors. These are used to avoid unnecessary + # round-trip conversions when a RaggedTensor is constructed from + # lengths or rowids, and we later want those lengths/rowids back. 
+ for tensor in [cached_row_lengths, cached_value_rowids, cached_nrows]: + if tensor is not None: + if not isinstance(tensor, ops.Tensor): + raise TypeError("Cached value must be a Tensor or None.") + elif tensor.dtype not in (dtypes.int32, dtypes.int64): + raise TypeError("Cached value must be int32 or int64.") + self._cached_row_lengths = cached_row_lengths + self._cached_value_rowids = cached_value_rowids + self._cached_nrows = cached_nrows + + if uniform_row_length is not None: + if not isinstance(uniform_row_length, ops.Tensor): + raise TypeError("uniform_row_length must be a Tensor or None.") + elif uniform_row_length.dtype not in (dtypes.int32, dtypes.int64): + raise TypeError("uniform_row_length must be int32 or int64.") + self._uniform_row_length = uniform_row_length + + #============================================================================= + # Factory Methods + #============================================================================= + + @classmethod + def from_value_rowids(cls, + value_rowids, + nrows=None, + name=None, + validate=True, + preferred_dtype=None): + """Creates a `RowPartition` with rows partitioned by `value_rowids`. + + The implied `RaggedTensor` corresponds with the python list defined by: + + ```python + result = [[values[i] for i in range(len(values)) if value_rowids[i] == row] + for row in range(nrows)] + ``` + + Args: + value_rowids: A 1-D integer tensor with shape `[nvals]`, which corresponds + one-to-one with `values`, and specifies each value's row index. Must be + nonnegative, and must be sorted in ascending order. + nrows: An integer scalar specifying the number of rows. This should be + specified if the `RaggedTensor` may containing empty training rows. Must + be greater than `value_rowids[-1]` (or zero if `value_rowids` is empty). + Defaults to `value_rowids[-1]` (or zero if `value_rowids` is empty). + name: A name prefix for the RaggedTensor (optional). 
+ validate: If true, then use assertions to check that the arguments form a + valid `RowPartition`. + preferred_dtype: The dtype to encode value_rowids if it doesn't already + have one. The default is tf.int64. + + Returns: + A `RowPartition`. + + Raises: + ValueError: If `nrows` is incompatible with `value_rowids`. + + #### Example: + + >>> print(RowPartition.from_value_rowids( + ... value_rowids=[0, 0, 0, 0, 2, 2, 2, 3], + ... nrows=4)) + tf.RowPartition(row_splits=tf.Tensor([0 4 4 7 8], shape=(5,), dtype=int64)) + """ + if not isinstance(validate, bool): + raise TypeError("validate must have type bool") + with ops.name_scope(name, "RowPartitionFromValueRowIds", + [value_rowids, nrows]): + value_rowids = cls._convert_row_partition(value_rowids, "value_rowids", + preferred_dtype) + if nrows is None: + const_rowids = tensor_util.constant_value(value_rowids) + if const_rowids is None: + nrows = array_ops.concat([value_rowids[-1:], [-1]], axis=0)[0] + 1 + const_nrows = None + else: + const_nrows = const_rowids[-1] + 1 if const_rowids.size > 0 else 0 + nrows = ops.convert_to_tensor( + const_nrows, value_rowids.dtype, name="nrows") + else: + nrows = ops.convert_to_tensor(nrows, value_rowids.dtype, "nrows") + const_nrows = tensor_util.constant_value(nrows) + if const_nrows is not None: + if const_nrows < 0: + raise ValueError("Expected nrows >= 0; got %d" % const_nrows) + const_rowids = tensor_util.constant_value(value_rowids) + if const_rowids is not None and const_rowids.size > 0: + if not const_nrows >= const_rowids[-1] + 1: + raise ValueError( + "Expected nrows >= value_rowids[-1] + 1; got nrows=%d, " + "value_rowids[-1]=%d" % (const_nrows, const_rowids[-1])) + + value_rowids.shape.assert_has_rank(1) + nrows.shape.assert_has_rank(0) + + if validate: + msg = ("Arguments to from_value_rowids do not form a valid " + "RowPartition") + checks = [ + check_ops.assert_rank(value_rowids, 1, message=msg), + check_ops.assert_rank(nrows, 0, message=msg), + 
check_ops.assert_non_negative(value_rowids[:1], message=msg), + _assert_monotonic_increasing(value_rowids, message=msg), + check_ops.assert_less(value_rowids[-1:], nrows, message=msg), + ] + value_rowids = control_flow_ops.with_dependencies(checks, value_rowids) + + # Convert value_rowids & nrows to row_splits. + # Note: we don't use segment_ids_to_row_splits() here because we want + # to save the intermediate value `row_lengths`, so we can cache it. + # TODO(b/116708836) Upgrade bincount to accept int64 so we can skip the + # cast. + value_rowids_int32 = math_ops.cast(value_rowids, dtypes.int32) + nrows_int32 = math_ops.cast(nrows, dtypes.int32) + row_lengths = math_ops.bincount( + value_rowids_int32, + minlength=nrows_int32, + maxlength=nrows_int32, + dtype=value_rowids.dtype) + row_splits = array_ops.concat([[0], math_ops.cumsum(row_lengths)], axis=0) + if const_nrows is not None: + row_lengths.set_shape([const_nrows]) + row_splits.set_shape([const_nrows + 1]) + + return cls( + row_splits, + cached_row_lengths=row_lengths, + cached_value_rowids=value_rowids, + cached_nrows=nrows, + internal=True) + + @classmethod + def from_row_splits(cls, + row_splits, + name=None, + validate=True, + preferred_dtype=None): + """Creates a `RowPartition` with rows partitioned by `row_splits`. + + A `RaggedTensor` constructed with this corresponds with the python list + defined by: + + ```python + result = [values[row_splits[i]:row_splits[i + 1]] + for i in range(len(row_splits) - 1)] + ``` + + Args: + row_splits: A 1-D integer tensor with shape `[nrows+1]`. Must not be + empty, and must be sorted in ascending order. `row_splits[0]` must be + zero. + name: A name prefix for the RaggedTensor (optional). + validate: If true, then use assertions to check that the arguments form a + valid `RowPartition`. + preferred_dtype: If row_splits has an unspecified type, use this one. If + preferred_dtype is None, defaults to dtypes.int64. + + Returns: + A `RowPartition`. 
+ + Raises: + ValueError: If `row_splits` is an empty list. + + """ + if not isinstance(validate, bool): + raise TypeError("validate must have type bool") + if isinstance(row_splits, (list, tuple)) and not row_splits: + raise ValueError("row_splits tensor may not be empty.") + if isinstance(row_splits, tensor_spec.TensorSpec): + return cls(row_splits=row_splits, internal=True) + + with ops.name_scope(name, "RowPartitionFromRowSplits", [row_splits]): + row_splits = cls._convert_row_partition(row_splits, "row_splits", + preferred_dtype) + row_splits.shape.assert_has_rank(1) + + if validate: + msg = "Arguments to from_row_splits do not form a valid RaggedTensor:" + checks = [ + check_ops.assert_rank(row_splits, 1, message=(msg + "rank")), + _assert_zero(row_splits[0], message=(msg + "zero")), + _assert_monotonic_increasing( + row_splits, message=(msg + "monotonic")), + ] + row_splits = control_flow_ops.with_dependencies(checks, row_splits) + + return cls(row_splits=row_splits, internal=True) + + @classmethod + def from_row_lengths(cls, + row_lengths, + name=None, + validate=True, + preferred_dtype=None): + """Creates a `RowPartition` with rows partitioned by `row_lengths`. + + A `RaggedTensor` constructed with this corresponds with the python list + defined by: + + ```python + result = [[values.pop(0) for i in range(length)] + for length in row_lengths] + ``` + + Args: + row_lengths: A 1-D integer tensor with shape `[nrows]`. Must be + nonnegative. + name: A name prefix for the RowPartition (optional). + validate: If true, then use assertions to check that the arguments form a + valid `RowPartition`. + preferred_dtype: If row_lengths has an unspecified type, use this one. If + preferred_dtype is None, defaults to dtypes.int64. + + Returns: + A `RowPartition`. 
+ """ + if not isinstance(validate, bool): + raise TypeError("validate must have type bool") + with ops.name_scope(name, "RowPartitionFromRowLengths", [row_lengths]): + row_lengths = cls._convert_row_partition(row_lengths, "row_lengths", + preferred_dtype) + row_lengths.shape.assert_has_rank(1) + + if validate: + msg = "Arguments to from_row_lengths do not form a valid RowPartition" + checks = [ + check_ops.assert_rank(row_lengths, 1, message=msg), + check_ops.assert_non_negative(row_lengths, message=msg), + ] + row_lengths = control_flow_ops.with_dependencies(checks, row_lengths) + + row_limits = math_ops.cumsum(row_lengths) + row_splits = array_ops.concat([[0], row_limits], axis=0) + return cls( + row_splits=row_splits, cached_row_lengths=row_lengths, internal=True) + + @classmethod + def from_row_starts(cls, + row_starts, + nvals, + name=None, + validate=True, + preferred_dtype=None): + """Creates a `RowPartition` with rows partitioned by `row_starts`. + + Equivalent to: `from_row_splits(concat([row_starts, nvals]))`. + + Args: + row_starts: A 1-D integer tensor with shape `[nrows]`. Must be + nonnegative and sorted in ascending order. If `nrows>0`, then + `row_starts[0]` must be zero. + nvals: A scalar tensor indicating the number of values. + name: A name prefix for the RowPartition (optional). + validate: If true, then use assertions to check that the arguments form a + valid `RowPartition`. + preferred_dtype: If row_limits has an unspecified type, use this one. If + preferred_dtype is None, defaults to dtypes.int64. + + Returns: + A `RowPartition`. 
+ """ + if not isinstance(validate, bool): + raise TypeError("validate must have type bool") + with ops.name_scope(name, "RowPartitionFromRowStarts", [row_starts]): + row_starts = cls._convert_row_partition(row_starts, "row_starts", + preferred_dtype) + row_starts.shape.assert_has_rank(1) + nvals = math_ops.cast(nvals, row_starts.dtype) + if validate: + msg = "Arguments to from_row_starts do not form a valid RaggedTensor" + checks = [ + check_ops.assert_rank(row_starts, 1, message=msg), + _assert_zero(row_starts[:1], message=msg), + _assert_monotonic_increasing(row_starts, message=msg), + check_ops.assert_less_equal(row_starts[-1:], nvals, message=msg), + ] + row_starts = control_flow_ops.with_dependencies(checks, row_starts) + + row_splits = array_ops.concat([row_starts, [nvals]], axis=0) + return cls(row_splits=row_splits, internal=True) + + def has_cached_value_rowids(self): + return self._cached_value_rowids is not None + + @classmethod + def from_row_limits(cls, + row_limits, + name=None, + validate=True, + preferred_dtype=None): + """Creates a `RowPartition` with rows partitioned by `row_limits`. + + Equivalent to: `from_row_splits(values, concat([0, row_limits]))`. + + Args: + row_limits: A 1-D integer tensor with shape `[nrows]`. Must be sorted in + ascending order. + name: A name prefix for the RaggedTensor (optional). + validate: If true, then use assertions to check that the arguments form a + valid `RowPartition`. + preferred_dtype: If row_limits has an unspecified type, use this one. If + preferred_dtype is None, defaults to dtypes.int64. + Returns: + A `RowPartition`. 
+ """ + if not isinstance(validate, bool): + raise TypeError("validate must have type bool") + with ops.name_scope(name, "RowPartitionFromRowLimits", [row_limits]): + row_limits = cls._convert_row_partition(row_limits, "row_limits", + preferred_dtype) + row_limits.shape.assert_has_rank(1) + + if validate: + msg = "Arguments to from_row_limits do not form a valid RaggedTensor" + checks = [ + check_ops.assert_rank(row_limits, 1, message=msg), + check_ops.assert_non_negative(row_limits[:1], message=msg), + _assert_monotonic_increasing(row_limits, message=msg), + ] + row_limits = control_flow_ops.with_dependencies(checks, row_limits) + + zero = array_ops.zeros([1], row_limits.dtype) + row_splits = array_ops.concat([zero, row_limits], axis=0) + return cls(row_splits=row_splits, internal=True) + + @classmethod + def from_uniform_row_length(cls, + nvals, + uniform_row_length, + nrows=None, + validate=True, + name=None, + preferred_dtype=None): + """Creates a `RowPartition` with rows partitioned by `uniform_row_length`. + + A `RaggedTensor` constructed with this corresponds with the python list + defined by (assuming uniform_row_length and nvals nonzero): + + ```python + result = [[values.pop(0) for _ in range(uniform_row_length)] + for _ in range(nrows)] + ``` + + + Note that `rt1` only contains one ragged dimension (the innermost + dimension). In contrast, if `from_row_splits` is used to construct a similar + `RaggedTensor`, then that `RaggedTensor` will have two ragged dimensions: + + Args: + nvals: a non-negative scalar integer tensor for the number of values. + uniform_row_length: A scalar integer tensor. Must be nonnegative. The + size of the outer axis of `values` must be evenly divisible by + `uniform_row_length`. + nrows: The number of rows in the constructed RaggedTensor. If not + specified, then it defaults to `nvals/uniform_row_length` (or `0` if + `uniform_row_length==0`). `nrows` only needs to be specified if + `uniform_row_length` might be zero. 
`uniform_row_length*nrows` must be + `nvals`. + validate: If true, then use assertions to check that the arguments form a + valid `RaggedTensor`. + name: A name prefix for the RaggedTensor (optional) + preferred_dtype: if uniform_row_length has no dtype, use this one. + + Returns: + A `RowPartition`. + """ + if not isinstance(validate, bool): + raise TypeError("validate must have type bool") + with ops.name_scope(name, "RowPartitionFromUniformRowLength", + [uniform_row_length, nrows]): + uniform_row_length = cls._convert_row_partition(uniform_row_length, + "uniform_row_length", + preferred_dtype) + uniform_row_length.shape.assert_has_rank(0) + + # Find nrows. + const_row_length = tensor_util.constant_value(uniform_row_length) + if nrows is None: + if const_row_length is None: + # Avoid division by zero if uniform_row_length==0 (and nvals==0). + rowlen_or_1 = control_flow_ops.cond( + math_ops.equal(uniform_row_length, 0), + lambda: constant_op.constant(1, uniform_row_length.dtype), + lambda: uniform_row_length) + nrows = nvals // rowlen_or_1 + elif const_row_length == 0: + nrows = 0 + else: + nrows = nvals // const_row_length + nrows = ops.convert_to_tensor( + nrows, uniform_row_length.dtype, name="nrows") + const_nrows = tensor_util.constant_value(nrows) + const_nvals = tensor_util.constant_value(nvals) + + # Find row_splits. 
+ if const_nrows is not None and const_row_length is not None: + row_splits = [v * const_row_length for v in range(const_nrows + 1)] + row_splits = constant_op.constant(row_splits, uniform_row_length.dtype) + else: + row_splits = math_ops.range(nrows + 1) * uniform_row_length + + if validate: + checks = [] + + if (const_nrows is None or const_row_length is None or + const_nvals is None): + checks.append( + check_ops.assert_equal( + nrows * uniform_row_length, nvals, + ("uniform_row_length", uniform_row_length, "times nrows", + nrows, "must equal nvals", nvals))) + else: + if const_nrows * const_row_length != const_nvals: + raise ValueError( + "uniform_row_length=%d times nrows=%d must equal nvals=%d" % + (const_row_length, const_nrows, const_nvals)) + + if uniform_row_length.shape.rank is None: + checks.append( + check_ops.assert_rank( + uniform_row_length, + 0, + message="uniform_row_length must be a scalar.")) + + const_row_length = tensor_util.constant_value(uniform_row_length) + if const_row_length is None: + checks.append( + check_ops.assert_greater_equal( + uniform_row_length, + constant_op.constant(0, uniform_row_length.dtype), + message="uniform_row_length must be >= 0.")) + else: + if const_row_length < 0: + raise ValueError("uniform_row_length must be >= 0.") + + row_splits = control_flow_ops.with_dependencies(checks, row_splits) + + return cls( + row_splits=row_splits, + uniform_row_length=uniform_row_length, + cached_nrows=nrows, + internal=True) + + @classmethod + def _convert_row_partition(cls, partition, name, preferred_dtype): + """Converts `partition` to Tensors. + + Args: + partition: A row-partitioning tensor for the `RowPartition` being + constructed. I.e., one of: row_splits, row_lengths, row_starts, + row_limits, value_rowids. + name: The name of the row-partitioning tensor. + preferred_dtype: If partition has no dtype, give it this one. If + no dtype is specified, use dtypes.int64. + + Returns: + A tensor equivalent to partition. 
+ + Raises: + ValueError: if dtype is not int32 or int64. + """ + if preferred_dtype is None: + preferred_dtype = dtypes.int64 + if isinstance(partition, np.ndarray) and partition.dtype == np.int32: + partition = ops.convert_to_tensor(partition, name=name) + else: + partition = ops.convert_to_tensor( + partition, preferred_dtype=preferred_dtype, name=name) + if partition.dtype not in (dtypes.int32, dtypes.int64): + raise ValueError("%s must have dtype int32 or int64" % name) + + return partition + + def with_dependencies(self, dependencies): + """Returns a new RowPartition equal to self with control dependencies. + + Specifically, self._row_splits is gated by the given control dependencies. + Used to add sanity checks to the constructors. + + Args: + dependencies: a list of tensors to use as dependencies. + + Returns: + A new RowPartition object. + """ + new_row_splits = control_flow_ops.with_dependencies(dependencies, + self._row_splits) + return RowPartition( + row_splits=new_row_splits, + cached_row_lengths=self._cached_row_lengths, + cached_value_rowids=self._cached_value_rowids, + cached_nrows=self._cached_nrows, + internal=True, + uniform_row_length=self._uniform_row_length) + + #============================================================================= + # Accessors + #============================================================================= + + @property + def dtype(self): + """The `DType` of the row partition.""" + return self._row_splits.dtype + + @property + def row_splits(self): + """The row-split indices for this row partition. + + `rt.row_splits` specifies where the values for each row begin and end in + `rt.values`. In particular, the values for row `rt[i]` are stored in + the slice `rt.values[rt.row_splits[i]:rt.row_splits[i+1]]`. + + Returns: + A 1-D integer `Tensor` with shape `[self.nrows+1]`. + The returned tensor is non-empty, and is sorted in ascending order. 
+ `self.row_splits[0]` is zero, and `self.row_splits[-1]` is equal to + `self.values.shape[0]`. + + """ + return self._row_splits + + def value_rowids(self, name=None): + """Returns the row indices for this row partition. + + Returns a vector with a number of entries equal to nvals, where + the ith value in the tensor indicates the row of the ith value. + + Args: + name: A name prefix for the returned tensor (optional). + + Returns: + A 1-D integer `Tensor` with shape `self.values.shape[:1]`. + The returned tensor is nonnegative, and is sorted in ascending order. + + """ + if self._cached_value_rowids is not None: + return self._cached_value_rowids + + with ops.name_scope(name, "RaggedValueRowIds", [self]): + return segment_id_ops.row_splits_to_segment_ids(self.row_splits) + + def nrows_as_dimension(self): + """Returns the first dimension of the shape as a `tf.Dimension`.""" + return tensor_shape.dimension_at_index(self._row_splits.shape, 0) - 1 + + def nvals(self, out_type=None, name=None): + """Returns the number of values in this row partition. + + Specifically, should be equal to the outermost dimension of the + values associated with this row partition. + + Args: + out_type: `dtype` for the returned tensor. Defaults to + `self.row_splits.dtype`. + name: A name prefix for the returned tensor (optional). + + Returns: + the number of values in this row partition as a tensor scalar. + """ + if out_type is None: + return self.row_splits[-1] + else: + out_type = dtypes.as_dtype(out_type) + return math_ops.cast(self.row_splits[-1], name=name, dtype=out_type) + + def nrows(self, out_type=None, name=None): + """Returns the number of rows in this ragged tensor. + + I.e., the size of the outermost dimension of the tensor. + + Args: + out_type: `dtype` for the returned tensor. Defaults to + `self.row_splits.dtype`. + name: A name prefix for the returned tensor (optional). + + Returns: + A scalar `Tensor` with dtype `out_type`. 
+
+    """
+    if out_type is None:
+      out_type = self._row_splits.dtype
+    else:
+      out_type = dtypes.as_dtype(out_type)
+    if self._cached_nrows is not None:
+      return math_ops.cast(self._cached_nrows, out_type)
+    with ops.name_scope(name, "RaggedNRows", [self]):
+      nsplits = tensor_shape.dimension_at_index(self.row_splits.shape, 0)
+      if nsplits.value is None:
+        return array_ops.shape(self.row_splits, out_type=out_type)[0] - 1
+      else:
+        return constant_op.constant(nsplits.value - 1, dtype=out_type)
+
+  def uniform_row_length(self):
+    """Returns the uniform row length, or `None` if unspecified."""
+    return self._uniform_row_length
+
+  def row_starts(self, name=None):
+    """Returns the start indices for rows in this row partition.
+
+    These indices specify where the values for each row begin in
+    `self.values`. `rt.row_starts()` is equal to `rt.row_splits[:-1]`.
+
+    Args:
+      name: A name prefix for the returned tensor (optional).
+
+    Returns:
+      A 1-D integer Tensor with shape `[nrows]`.
+      The returned tensor is nonnegative, and is sorted in ascending order.
+    """
+    with ops.name_scope(name, "RaggedRowStarts", [self]):
+      return self.row_splits[:-1]
+
+  def row_limits(self, name=None):
+    """Returns the limit indices for rows in this row partition.
+
+    These indices specify where the values for each row end in
+    `self.values`. `rt.row_limits()` is equal to `rt.row_splits[1:]`.
+
+    Args:
+      name: A name prefix for the returned tensor (optional).
+
+    Returns:
+      A 1-D integer Tensor with shape `[nrows]`.
+      The returned tensor is nonnegative, and is sorted in ascending order.
+    """
+    with ops.name_scope(name, "RaggedRowLimits", [self]):
+      return self.row_splits[1:]
+
+  def row_lengths(self, name=None):
+    if self._cached_row_lengths is not None:
+      return self._cached_row_lengths
+    splits = self.row_splits
+    with ops.name_scope(name, "RaggedRowLengths", [self]):
+      return splits[1:] - splits[:-1]
+
+  #=============================================================================
+  # Transformation
+  #=============================================================================
+
+  def with_row_splits_dtype(self, dtype):
+    """Returns a copy of this RowPartition with the given `row_splits` dtype.
+
+    Any cached partitioning tensors (row lengths, value rowids, nrows, and
+    the uniform row length, when present) are cast to the same dtype.
+
+    Args:
+      dtype: The dtype for `row_splits`. One of `tf.int32` or `tf.int64`.
+
+    Returns:
+      A copy of this RowPartition, with the `row_splits` cast to the given
+      type.
+    """
+    dtype = dtypes.as_dtype(dtype)
+    if dtype not in (dtypes.int32, dtypes.int64):
+      raise ValueError("dtype must be int32 or int64")
+    if self._row_splits.dtype == dtype:
+      return self
+
+    row_splits = math_ops.cast(self._row_splits, dtype)
+
+    cached_row_lengths = self._cached_row_lengths
+    if cached_row_lengths is not None:
+      cached_row_lengths = math_ops.cast(cached_row_lengths, dtype)
+    cached_value_rowids = self._cached_value_rowids
+    if cached_value_rowids is not None:
+      cached_value_rowids = math_ops.cast(cached_value_rowids, dtype)
+    cached_nrows = self._cached_nrows
+    if cached_nrows is not None:
+      cached_nrows = math_ops.cast(cached_nrows, dtype)
+    uniform_row_length = self._uniform_row_length
+    if uniform_row_length is not None:
+      uniform_row_length = math_ops.cast(uniform_row_length, dtype)
+
+    return RowPartition(
+        row_splits,
+        cached_row_lengths,
+        cached_value_rowids,
+        cached_nrows,
+        internal=True,
+        uniform_row_length=uniform_row_length)
+
+#============================================================================= +# String Encoding +#============================================================================= + + def __repr__(self): + return "tf.RowPartition(row_splits=%s)" % (self._row_splits) + + +#=============================================================================== +# Helper Functions +#=============================================================================== + + +def _assert_monotonic_increasing(tensor, message=None): + return check_ops.assert_non_negative( + tensor[1:] - tensor[:-1], message=message) + + +def _assert_zero(tensor, message=None): + return check_ops.assert_equal( + tensor, constant_op.constant(0, dtype=tensor.dtype), message=message) diff --git a/tensorflow/python/ops/ragged/row_partition_test.py b/tensorflow/python/ops/ragged/row_partition_test.py new file mode 100644 index 00000000000..e429facd21d --- /dev/null +++ b/tensorflow/python/ops/ragged/row_partition_test.py @@ -0,0 +1,559 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ==============================================================================
+"""Tests for third_party.tensorflow.python.ops.ragged.row_partition."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops.ragged.row_partition import RowPartition
+from tensorflow.python.platform import googletest
+
+
+class _SliceBuilder(object):
+  """Helper to construct arguments for __getitem__.
+
+  Usage: _SliceBuilder()[<expr>] returns the slice spec Python generates for <expr>.
+  """
+
+  def __getitem__(self, slice_spec):
+    return slice_spec
+
+
+SLICE_BUILDER = _SliceBuilder()
+
+
+def _make_tensor_slice_spec(slice_spec, use_constant=True):
+  """Wraps all integers in an extended slice spec w/ a tensor.
+
+  This function is used to help test slicing when the slice spec contains
+  tensors, rather than integers.
+
+  Args:
+    slice_spec: The extended slice spec.
+    use_constant: If true, then wrap each integer with a tf.constant. If false,
+      then wrap each integer with a tf.placeholder.
+
+  Returns:
+    A copy of slice_spec, but with each integer i replaced with tf.constant(i).
+ """ + + def make_piece_scalar(piece): + if isinstance(piece, int): + scalar = constant_op.constant(piece) + if use_constant: + return scalar + else: + return array_ops.placeholder_with_default(scalar, []) + elif isinstance(piece, slice): + return slice( + make_piece_scalar(piece.start), make_piece_scalar(piece.stop), + make_piece_scalar(piece.step)) + else: + return piece + + if isinstance(slice_spec, tuple): + return tuple(make_piece_scalar(piece) for piece in slice_spec) + else: + return make_piece_scalar(slice_spec) + + +# Example 2D ragged tensor value with one ragged dimension and with scalar +# values, expressed as nested python lists and as splits+values. +EXAMPLE_RAGGED_TENSOR_2D = [[b'a', b'b'], [b'c', b'd', b'e'], [b'f'], [], + [b'g']] +EXAMPLE_RAGGED_TENSOR_2D_SPLITS = [0, 2, 5, 6, 6, 7] +EXAMPLE_RAGGED_TENSOR_2D_VALUES = ['a', 'b', 'c', 'd', 'e', 'f', 'g'] + +# Example 4D ragged tensor value, with two ragged dimensions and with values +# whose shape is [2], expressed as nested python lists and as splits+values. +EXAMPLE_RAGGED_TENSOR_4D = [ + [ # rt[0] + [[1, 2], [3, 4], [5, 6]], # rt[0][0] + [[7, 8], [9, 10], [11, 12]]], # rt[0][1] + [], # rt[1] + [ # rt[2] + [[13, 14], [15, 16], [17, 18]]], # rt[2][0] + [ # rt[3] + [[19, 20]]] # rt[3][0] +] # pyformat: disable +EXAMPLE_RAGGED_TENSOR_4D_SPLITS1 = [0, 2, 2, 3, 4] +EXAMPLE_RAGGED_TENSOR_4D_SPLITS2 = [0, 3, 6, 9, 10] +EXAMPLE_RAGGED_TENSOR_4D_VALUES = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], + [11, 12], [13, 14], [15, 16], [17, 18], + [19, 20]] + +# Example 3D ragged tensor with uniform_row_lengths. 
+EXAMPLE_RAGGED_TENSOR_3D = [[[1, 2, 3], [4], [5, 6]], [[], [7, 8, 9], []]] +EXAMPLE_RAGGED_TENSOR_3D_ROWLEN = 3 +EXAMPLE_RAGGED_TENSOR_3D_SPLITS = [0, 3, 4, 6, 6, 9, 9] +EXAMPLE_RAGGED_TENSOR_3D_VALUES = [1, 2, 3, 4, 5, 6, 7, 8, 9] + + +def int32array(values): + return np.array(values, dtype=np.int32) + + +@test_util.run_all_in_graph_and_eager_modes +class RowPartitionTest(test_util.TensorFlowTestCase, parameterized.TestCase): + longMessage = True # Property in unittest.Testcase. pylint: disable=invalid-name + + #============================================================================= + # RaggedTensor class docstring examples + #============================================================================= + + def testClassDocStringExamples(self): + # From section: "Component Tensors" + rt = RowPartition.from_row_splits(row_splits=[0, 4, 4, 7, 8, 8]) + self.assertAllEqual(rt.row_splits, [0, 4, 4, 7, 8, 8]) + del rt + + # From section: "Alternative Row-Partitioning Schemes" + rt1 = RowPartition.from_row_splits(row_splits=[0, 4, 4, 7, 8, 8]) + rt2 = RowPartition.from_row_lengths(row_lengths=[4, 0, 3, 1, 0]) + rt3 = RowPartition.from_value_rowids( + value_rowids=[0, 0, 0, 0, 2, 2, 2, 3], nrows=5) + rt4 = RowPartition.from_row_starts(row_starts=[0, 4, 4, 7, 8], nvals=8) + rt5 = RowPartition.from_row_limits(row_limits=[4, 4, 7, 8, 8]) + for rt in (rt1, rt2, rt3, rt4, rt5): + self.assertAllEqual(rt.row_splits, [0, 4, 4, 7, 8, 8]) + del rt1, rt2, rt3, rt4, rt5 + + # From section: "Multiple Ragged Dimensions" + inner_rt = RowPartition.from_row_splits(row_splits=[0, 4, 4, 7, 8, 8]) + outer_rt = RowPartition.from_row_splits(row_splits=[0, 3, 3, 5]) + del inner_rt, outer_rt + + #============================================================================= + # RaggedTensor Constructor (private) + #============================================================================= + + def testRaggedTensorConstruction(self): + row_splits = constant_op.constant([0, 2, 2, 5, 6, 7], 
dtypes.int64) + rt = RowPartition(row_splits=row_splits, internal=True) + self.assertAllEqual(rt.row_splits, [0, 2, 2, 5, 6, 7]) + + def testRaggedTensorConstructionErrors(self): + row_splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64) + + with self.assertRaisesRegexp(ValueError, + 'RaggedTensor constructor is private'): + RowPartition(row_splits=row_splits) + + with self.assertRaisesRegexp(TypeError, + 'Row-partitioning argument must be a Tensor'): + RowPartition(row_splits=[0, 2, 2, 5, 6, 7], internal=True) + + with self.assertRaisesRegexp(ValueError, + r'Shape \(6, 1\) must have rank 1'): + RowPartition( + row_splits=array_ops.expand_dims(row_splits, 1), internal=True) + + with self.assertRaisesRegexp(TypeError, + 'Cached value must be a Tensor or None.'): + RowPartition( + row_splits=row_splits, cached_row_lengths=[2, 3, 4], internal=True) + + #============================================================================= + # RaggedTensor Factory Ops + #============================================================================= + + def testFromValueRowIdsWithDerivedNRows(self): + # nrows is known at graph creation time. + value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64) + # TODO(martinz): add nrows + rt = RowPartition.from_value_rowids(value_rowids, validate=False) + self.assertEqual(rt.dtype, dtypes.int64) + + rt_row_splits = rt.row_splits + rt_value_rowids = rt.value_rowids() + rt_nrows = rt.nrows() + + self.assertIs(rt_value_rowids, value_rowids) # cached_value_rowids + self.assertAllEqual(rt_value_rowids, value_rowids) + self.assertAllEqual(rt_nrows, 5) + self.assertAllEqual(rt_row_splits, [0, 2, 2, 5, 6, 7]) + + def testFromValueRowIdsWithDerivedNRowsDynamic(self): + # nrows is not known at graph creation time. 
+ value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64) + value_rowids = array_ops.placeholder_with_default(value_rowids, shape=None) + + rt = RowPartition.from_value_rowids(value_rowids, validate=False) + + rt_value_rowids = rt.value_rowids() + rt_nrows = rt.nrows() + + self.assertIs(rt_value_rowids, value_rowids) # cached_value_rowids + self.assertAllEqual(rt_value_rowids, value_rowids) + self.assertAllEqual(rt_nrows, 5) + + def testFromValueRowIdsWithExplicitNRows(self): + value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64) + nrows = constant_op.constant(7, dtypes.int64) + + rt = RowPartition.from_value_rowids(value_rowids, nrows, validate=False) + + rt_value_rowids = rt.value_rowids() + rt_nrows = rt.nrows() + rt_row_splits = rt.row_splits + + self.assertIs(rt_value_rowids, value_rowids) # cached_value_rowids + self.assertIs(rt_nrows, nrows) # cached_nrows + self.assertAllEqual(rt_row_splits, [0, 2, 2, 5, 6, 7, 7, 7]) + + def testFromValueRowIdsWithExplicitNRowsEqualToDefault(self): + value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64) + nrows = constant_op.constant(5, dtypes.int64) + + rt = RowPartition.from_value_rowids(value_rowids, nrows, validate=False) + + rt_value_rowids = rt.value_rowids() + rt_nrows = rt.nrows() + rt_row_splits = rt.row_splits + + self.assertIs(rt_value_rowids, value_rowids) # cached_value_rowids + self.assertIs(rt_nrows, nrows) # cached_nrows + self.assertAllEqual(rt_value_rowids, value_rowids) + self.assertAllEqual(rt_nrows, nrows) + self.assertAllEqual(rt_row_splits, [0, 2, 2, 5, 6, 7]) + + def testFromValueRowIdsWithEmptyValues(self): + rt = RowPartition.from_value_rowids([]) + rt_nrows = rt.nrows() + self.assertEqual(rt.dtype, dtypes.int64) + self.assertEqual(rt.value_rowids().shape.as_list(), [0]) + self.assertAllEqual(rt_nrows, 0) + + def testFromRowSplits(self): + row_splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64) + + rt = 
RowPartition.from_row_splits(row_splits, validate=False) + self.assertEqual(rt.dtype, dtypes.int64) + + rt_row_splits = rt.row_splits + rt_nrows = rt.nrows() + + self.assertIs(rt_row_splits, row_splits) + self.assertAllEqual(rt_nrows, 5) + + def testFromRowSplitsWithDifferentSplitTypes(self): + splits1 = [0, 2, 2, 5, 6, 7] + splits2 = np.array([0, 2, 2, 5, 6, 7], np.int64) + splits3 = np.array([0, 2, 2, 5, 6, 7], np.int32) + splits4 = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64) + splits5 = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int32) + rt1 = RowPartition.from_row_splits(splits1) + rt2 = RowPartition.from_row_splits(splits2) + rt3 = RowPartition.from_row_splits(splits3) + rt4 = RowPartition.from_row_splits(splits4) + rt5 = RowPartition.from_row_splits(splits5) + self.assertEqual(rt1.row_splits.dtype, dtypes.int64) + self.assertEqual(rt2.row_splits.dtype, dtypes.int64) + self.assertEqual(rt3.row_splits.dtype, dtypes.int32) + self.assertEqual(rt4.row_splits.dtype, dtypes.int64) + self.assertEqual(rt5.row_splits.dtype, dtypes.int32) + + def testFromRowSplitsWithEmptySplits(self): + err_msg = 'row_splits tensor may not be empty' + with self.assertRaisesRegexp(ValueError, err_msg): + RowPartition.from_row_splits([], []) + + def testFromRowStarts(self): + nvals = constant_op.constant(7) + row_starts = constant_op.constant([0, 2, 2, 5, 6], dtypes.int64) + + rt = RowPartition.from_row_starts(row_starts, nvals, validate=False) + self.assertEqual(rt.dtype, dtypes.int64) + + rt_row_starts = rt.row_starts() + rt_row_splits = rt.row_splits + rt_nrows = rt.nrows() + + self.assertAllEqual(rt_nrows, 5) + self.assertAllEqual(rt_row_starts, row_starts) + self.assertAllEqual(rt_row_splits, [0, 2, 2, 5, 6, 7]) + + def testFromRowLimits(self): + row_limits = constant_op.constant([2, 2, 5, 6, 7], dtypes.int64) + + rt = RowPartition.from_row_limits(row_limits, validate=False) + self.assertEqual(rt.dtype, dtypes.int64) + + rt_row_limits = rt.row_limits() + 
rt_row_splits = rt.row_splits + rt_nrows = rt.nrows() + + self.assertAllEqual(rt_nrows, 5) + self.assertAllEqual(rt_row_limits, row_limits) + self.assertAllEqual(rt_row_splits, [0, 2, 2, 5, 6, 7]) + + def testFromRowLengths(self): + row_lengths = constant_op.constant([2, 0, 3, 1, 1], dtypes.int64) + + rt = RowPartition.from_row_lengths(row_lengths, validate=False) + self.assertEqual(rt.dtype, dtypes.int64) + + rt_row_lengths = rt.row_lengths() + rt_nrows = rt.nrows() + + self.assertIs(rt_row_lengths, row_lengths) # cached_nrows + self.assertAllEqual(rt_nrows, 5) + self.assertAllEqual(rt_row_lengths, row_lengths) + + def testFromUniformRowLength(self): + nvals = 16 + a1 = RowPartition.from_uniform_row_length(nvals, 2) + self.assertAllEqual(a1.uniform_row_length(), 2) + self.assertAllEqual(a1.nrows(), 8) + + def testFromUniformRowLengthWithEmptyValues(self): + a = RowPartition.from_uniform_row_length( + nvals=0, uniform_row_length=0, nrows=10) + self.assertEqual(self.evaluate(a.nvals()), 0) + self.assertEqual(self.evaluate(a.nrows()), 10) + + def testFromUniformRowLengthWithPlaceholders1(self): + nvals = array_ops.placeholder_with_default( + constant_op.constant(6, dtype=dtypes.int64), None) + rt1 = RowPartition.from_uniform_row_length(nvals, 3) + const_nvals1 = self.evaluate(rt1.nvals()) + self.assertEqual(const_nvals1, 6) + + def testFromUniformRowLengthWithPlaceholders2(self): + nvals = array_ops.placeholder_with_default(6, None) + ph_rowlen = array_ops.placeholder_with_default(3, None) + rt2 = RowPartition.from_uniform_row_length(nvals, ph_rowlen) + const_nvals2 = self.evaluate(rt2.nvals()) + self.assertEqual(const_nvals2, 6) + + def testFromValueRowIdsWithBadNRows(self): + value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64) + nrows = constant_op.constant(5, dtypes.int64) + + with self.assertRaisesRegexp(ValueError, r'Expected nrows >= 0; got -2'): + RowPartition.from_value_rowids( + 
value_rowids=array_ops.placeholder_with_default(value_rowids, None), + nrows=-2) + + with self.assertRaisesRegexp( + ValueError, r'Expected nrows >= value_rowids\[-1\] \+ 1; got nrows=2, ' + r'value_rowids\[-1\]=4'): + RowPartition.from_value_rowids(value_rowids=value_rowids, nrows=2) + + with self.assertRaisesRegexp( + ValueError, r'Expected nrows >= value_rowids\[-1\] \+ 1; got nrows=4, ' + r'value_rowids\[-1\]=4'): + RowPartition.from_value_rowids(value_rowids=value_rowids, nrows=4) + + with self.assertRaisesRegexp(ValueError, + r'Shape \(7, 1\) must have rank 1'): + RowPartition.from_value_rowids( + value_rowids=array_ops.expand_dims(value_rowids, 1), nrows=nrows) + + with self.assertRaisesRegexp(ValueError, r'Shape \(1,\) must have rank 0'): + RowPartition.from_value_rowids( + value_rowids=value_rowids, nrows=array_ops.expand_dims(nrows, 0)) + + #============================================================================= + # RowPartition.__str__ + #============================================================================= + def testRowPartitionStr(self): + row_splits = [0, 2, 5, 6, 6, 7] + rt = RowPartition.from_row_splits(row_splits, validate=False) + splits_type = 'int64' + if context.executing_eagerly(): + expected_repr = ('tf.RowPartition(row_splits=tf.Tensor([0 2 5 6 6 7], ' + 'shape=(6,), dtype=int64))') + else: + expected_repr = ('tf.RowPartition(row_splits=' + 'Tensor("RowPartitionFromRowSplits/row_splits:0", ' + 'shape=(6,), dtype={}))').format(splits_type) + self.assertEqual(repr(rt), expected_repr) + self.assertEqual(str(rt), expected_repr) + + @parameterized.parameters([ + # from_value_rowids + { + 'descr': 'bad rank for value_rowids', + 'factory': RowPartition.from_value_rowids, + 'value_rowids': [[1, 2], [3, 4]], + 'nrows': 10 + }, + { + 'descr': 'bad rank for nrows', + 'factory': RowPartition.from_value_rowids, + 'value_rowids': [1, 2, 3, 4], + 'nrows': [10] + }, + { + 'descr': 'negative value_rowid', + 'factory': 
RowPartition.from_value_rowids, + 'value_rowids': [-5, 2, 3, 4], + 'nrows': 10 + }, + { + 'descr': 'non-monotonic-increasing value_rowid', + 'factory': RowPartition.from_value_rowids, + 'value_rowids': [4, 3, 2, 1], + 'nrows': 10 + }, + { + 'descr': 'value_rowid > nrows', + 'factory': RowPartition.from_value_rowids, + 'value_rowids': [1, 2, 3, 4], + 'nrows': 2 + }, + + # from_row_splits + { + 'descr': 'bad rank for row_splits', + 'factory': RowPartition.from_row_splits, + 'row_splits': [[1, 2], [3, 4]] + }, + { + 'descr': 'row_splits[0] != 0', + 'factory': RowPartition.from_row_splits, + 'row_splits': [2, 3, 4] + }, + { + 'descr': 'non-monotonic-increasing row_splits', + 'factory': RowPartition.from_row_splits, + 'row_splits': [0, 3, 2, 4] + }, + + # from_row_lengths + { + 'descr': 'bad rank for row_lengths', + 'factory': RowPartition.from_row_lengths, + 'row_lengths': [[1, 2], [1, 0]] + }, + { + 'descr': 'negatve row_lengths', + 'factory': RowPartition.from_row_lengths, + 'row_lengths': [3, -1, 2] + }, + + # from_row_starts + { + 'descr': 'bad rank for row_starts', + 'factory': RowPartition.from_row_starts, + 'nvals': 2, + 'row_starts': [[1, 2], [3, 4]] + }, + { + 'descr': 'row_starts[0] != 0', + 'factory': RowPartition.from_row_starts, + 'nvals': 5, + 'row_starts': [2, 3, 4] + }, + { + 'descr': 'non-monotonic-increasing row_starts', + 'factory': RowPartition.from_row_starts, + 'nvals': 4, + 'row_starts': [0, 3, 2, 4] + }, + { + 'descr': 'row_starts[0] > nvals', + 'factory': RowPartition.from_row_starts, + 'nvals': 4, + 'row_starts': [0, 2, 3, 5] + }, + + # from_row_limits + { + 'descr': 'bad rank for row_limits', + 'factory': RowPartition.from_row_limits, + 'row_limits': [[1, 2], [3, 4]] + }, + { + 'descr': 'row_limits[0] < 0', + 'factory': RowPartition.from_row_limits, + 'row_limits': [-1, 3, 4] + }, + { + 'descr': 'non-monotonic-increasing row_limits', + 'factory': RowPartition.from_row_limits, + 'row_limits': [0, 3, 2, 4] + }, + + # from_uniform_row_length + { 
+ 'descr': 'rowlen * nrows != nvals (1)', + 'factory': RowPartition.from_uniform_row_length, + 'nvals': 5, + 'uniform_row_length': 3 + }, + { + 'descr': 'rowlen * nrows != nvals (2)', + 'factory': RowPartition.from_uniform_row_length, + 'nvals': 5, + 'uniform_row_length': 6 + }, + { + 'descr': 'rowlen * nrows != nvals (3)', + 'factory': RowPartition.from_uniform_row_length, + 'nvals': 6, + 'uniform_row_length': 3, + 'nrows': 3 + }, + { + 'descr': 'rowlen must be a scalar', + 'factory': RowPartition.from_uniform_row_length, + 'nvals': 4, + 'uniform_row_length': [2] + }, + { + 'descr': 'rowlen must be nonnegative', + 'factory': RowPartition.from_uniform_row_length, + 'nvals': 4, + 'uniform_row_length': -1 + }, + ]) + def testFactoryValidation(self, descr, factory, **kwargs): + # When input tensors have shape information, some of these errors will be + # detected statically. + with self.assertRaises((errors.InvalidArgumentError, ValueError)): + partition = factory(**kwargs) + self.evaluate(partition.row_splits) + + # Remove shape information (by wrapping tensors in placeholders), and check + # that we detect the errors when the graph is run. 
+ if not context.executing_eagerly(): + + def wrap_arg(v): + return array_ops.placeholder_with_default( + constant_op.constant(v, dtype=dtypes.int64), + tensor_shape.TensorShape(None)) + + kwargs = dict((k, wrap_arg(v)) for (k, v) in kwargs.items()) + + with self.assertRaises(errors.InvalidArgumentError): + partition = factory(**kwargs) + self.evaluate(partition.row_splits) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-ragged-tensor.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-ragged-tensor.pbtxt index c64909d45f5..f65245a988c 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.-ragged-tensor.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-ragged-tensor.pbtxt @@ -19,6 +19,10 @@ tf_class { name: "ragged_rank" mtype: "" } + member { + name: "row_partition" + mtype: "" + } member { name: "row_splits" mtype: "" @@ -37,7 +41,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'values\', \'row_splits\', \'cached_row_lengths\', \'cached_value_rowids\', \'cached_nrows\', \'internal\', \'uniform_row_length\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], " + argspec: "args=[\'self\', \'values\', \'row_partition\', \'internal\'], varargs=None, keywords=None, defaults=[\'False\'], " } member_method { name: "bounding_shape" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-ragged-tensor.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-ragged-tensor.pbtxt index c64909d45f5..f65245a988c 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.-ragged-tensor.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.-ragged-tensor.pbtxt @@ -19,6 +19,10 @@ tf_class { name: "ragged_rank" mtype: "" } + member { + name: "row_partition" + mtype: "" + } member { name: "row_splits" mtype: "" @@ -37,7 +41,7 @@ tf_class { } member_method { name: "__init__" - argspec: "args=[\'self\', \'values\', \'row_splits\', \'cached_row_lengths\', 
\'cached_value_rowids\', \'cached_nrows\', \'internal\', \'uniform_row_length\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], " + argspec: "args=[\'self\', \'values\', \'row_partition\', \'internal\'], varargs=None, keywords=None, defaults=[\'False\'], " } member_method { name: "bounding_shape"