Make shape, dtype, ragged_rank, and row_splits_dtype public in tf.RaggedTensorSpec.

PiperOrigin-RevId: 319112320
Change-Id: I11948b2d437ea68622117df54fda64b2045a10fe
This commit is contained in:
A. Unique TensorFlower 2020-06-30 15:34:12 -07:00 committed by TensorFlower Gardener
parent e8ecb2c200
commit c0bb8efaf1
4 changed files with 159 additions and 38 deletions

View File

@ -827,11 +827,22 @@ class RaggedTensor(composite_tensor.CompositeTensor,
@property
def ragged_rank(self):
"""The number of ragged dimensions in this ragged tensor.
"""The number of times the RaggedTensor's flat_values is partitioned.
Examples:
>>> values = tf.ragged.constant([[1, 2, 3], [4], [5, 6], [7, 8, 9, 10]])
>>> values.ragged_rank
1
>>> rt = tf.RaggedTensor.from_uniform_row_length(values, 2)
>>> rt.ragged_rank
2
Returns:
A Python `int` indicating the number of ragged dimensions in this ragged
tensor. The outermost dimension is not considered ragged.
A Python `int` indicating the number of times the underlying `flat_values`
Tensor has been partitioned to add a new dimension.
I.e., `tf.rank(rt) = tf.rank(rt.flat_values) + rt.ragged_rank`.
"""
values_is_ragged = isinstance(self._values, RaggedTensor)
return self._values.ragged_rank + 1 if values_is_ragged else 1
@ -2119,6 +2130,80 @@ class RaggedTensorSpec(type_spec.BatchableTypeSpec):
__slots__ = ["_shape", "_dtype", "_ragged_rank", "_row_splits_dtype"]
@property
def dtype(self):
"""The `tf.dtypes.DType` specified by this type for the RaggedTensor.
Examples:
>>> rt = tf.ragged.constant([["a"], ["b", "c"]], dtype=tf.string)
>>> tf.type_spec_from_value(rt).dtype
tf.string
Returns:
A `tf.dtypes.DType` of the values in the RaggedTensor.
"""
return self._dtype
@property
def shape(self):
"""The statically known shape of the RaggedTensor.
Examples:
>>> rt = tf.ragged.constant([[0], [1, 2]])
>>> tf.type_spec_from_value(rt).shape
TensorShape([2, None])
>>> rt = tf.ragged.constant([[[0, 1]], [[1, 2], [3, 4]]], ragged_rank=1)
>>> tf.type_spec_from_value(rt).shape
TensorShape([2, None, 2])
Returns:
A `tf.TensorShape` containing the statically known shape of the
RaggedTensor. Ragged dimensions have a size of `None`.
"""
return self._shape
@property
def ragged_rank(self):
"""The number of times the RaggedTensor's flat_values is partitioned.
Defaults to `shape.ndims - 1`.
Examples:
>>> values = tf.ragged.constant([[1, 2, 3], [4], [5, 6], [7, 8, 9, 10]])
>>> tf.type_spec_from_value(values).ragged_rank
1
>>> rt1 = tf.RaggedTensor.from_uniform_row_length(values, 2)
>>> tf.type_spec_from_value(rt1).ragged_rank
2
Returns:
A Python `int` indicating the number of times the underlying `flat_values`
Tensor has been partitioned to add a new dimension.
I.e., `tf.rank(rt) = tf.rank(rt.flat_values) + rt.ragged_rank`.
"""
return self._ragged_rank
@property
def row_splits_dtype(self):
"""The `tf.dtypes.DType` of the RaggedTensor's `row_splits`.
Examples:
>>> rt = tf.ragged.constant([[1, 2, 3], [4]], row_splits_dtype=tf.int64)
>>> tf.type_spec_from_value(rt).row_splits_dtype
tf.int64
Returns:
A `tf.dtypes.DType` for the RaggedTensor's `row_splits` tensor. One
of `tf.int32` or `tf.int64`.
"""
return self._row_splits_dtype
@property
def value_type(self):
return RaggedTensor if self._ragged_rank > 0 else ops.Tensor
@ -2134,8 +2219,8 @@ class RaggedTensorSpec(type_spec.BatchableTypeSpec):
shape: The shape of the RaggedTensor, or `None` to allow any shape. If a
shape is specified, then all ragged dimensions must have size `None`.
dtype: `tf.DType` of values in the RaggedTensor.
ragged_rank: Python integer, the ragged rank of the RaggedTensor to be
described. Defaults to `shape.ndims - 1`.
ragged_rank: Python integer, the number of times the RaggedTensor's
flat_values is partitioned. Defaults to `shape.ndims - 1`.
row_splits_dtype: `dtype` for the RaggedTensor's `row_splits` tensor. One
of `tf.int32` or `tf.int64`.
"""

View File

@ -151,18 +151,18 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
row_splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
rp = RowPartition.from_row_splits(row_splits)
with self.assertRaisesRegexp(ValueError,
with self.assertRaisesRegex(ValueError,
'RaggedTensor constructor is private'):
RaggedTensor(values=values, row_partition=rp)
with self.assertRaisesRegexp(TypeError,
with self.assertRaisesRegex(TypeError,
'values must be a Tensor or RaggedTensor'):
RaggedTensor(values=range(7), row_partition=rp, internal=True)
with self.assertRaisesRegexp(TypeError,
with self.assertRaisesRegex(TypeError,
'row_partition must be a RowPartition'):
RaggedTensor(values=values, row_partition=[0, 2, 2, 5, 6, 7],
internal=True)
RaggedTensor(
values=values, row_partition=[0, 2, 2, 5, 6, 7], internal=True)
#=============================================================================
# RaggedTensor Factory Ops
@ -308,7 +308,7 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
def testFromRowSplitsWithEmptySplits(self):
err_msg = 'row_splits tensor may not be empty'
with self.assertRaisesRegexp(ValueError, err_msg):
with self.assertRaisesRegex(ValueError, err_msg):
RaggedTensor.from_row_splits([], [])
def testFromRowStarts(self):
@ -511,17 +511,17 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
]
nrows = [constant_op.constant(6, dtypes.int64)]
with self.assertRaisesRegexp(
with self.assertRaisesRegex(
ValueError, 'nested_nrows must have the same '
'length as nested_value_rowids'):
RaggedTensor.from_nested_value_rowids(values, nested_value_rowids, nrows)
def testFromNestedValueRowIdsWithNonListInput(self):
with self.assertRaisesRegexp(
with self.assertRaisesRegex(
TypeError, 'nested_value_rowids must be a list of Tensors'):
RaggedTensor.from_nested_value_rowids(
[1, 2, 3], constant_op.constant([[0, 1, 2], [0, 1, 2]], dtypes.int64))
with self.assertRaisesRegexp(TypeError,
with self.assertRaisesRegex(TypeError,
'nested_nrows must be a list of Tensors'):
RaggedTensor.from_nested_value_rowids([1, 2, 3], [[0, 1, 2], [0, 1, 2]],
constant_op.constant([3, 3]))
@ -578,7 +578,7 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
rt, [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]])
def testFromNestedRowSplitsWithNonListInput(self):
with self.assertRaisesRegexp(TypeError,
with self.assertRaisesRegex(TypeError,
'nested_row_splits must be a list of Tensors'):
RaggedTensor.from_nested_row_splits(
[1, 2], constant_op.constant([[0, 1, 2], [0, 1, 2]], dtypes.int64))
@ -588,32 +588,31 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
nrows = constant_op.constant(5, dtypes.int64)
with self.assertRaisesRegexp(ValueError, r'Expected nrows >= 0; got -2'):
with self.assertRaisesRegex(ValueError, r'Expected nrows >= 0; got -2'):
RaggedTensor.from_value_rowids(
values=values,
value_rowids=array_ops.placeholder_with_default(value_rowids, None),
nrows=-2)
with self.assertRaisesRegexp(
with self.assertRaisesRegex(
ValueError, r'Expected nrows >= value_rowids\[-1\] \+ 1; got nrows=2, '
r'value_rowids\[-1\]=4'):
RaggedTensor.from_value_rowids(
values=values, value_rowids=value_rowids, nrows=2)
with self.assertRaisesRegexp(
with self.assertRaisesRegex(
ValueError, r'Expected nrows >= value_rowids\[-1\] \+ 1; got nrows=4, '
r'value_rowids\[-1\]=4'):
RaggedTensor.from_value_rowids(
values=values, value_rowids=value_rowids, nrows=4)
with self.assertRaisesRegexp(ValueError,
r'Shape \(7, 1\) must have rank 1'):
with self.assertRaisesRegex(ValueError, r'Shape \(7, 1\) must have rank 1'):
RaggedTensor.from_value_rowids(
values=values,
value_rowids=array_ops.expand_dims(value_rowids, 1),
nrows=nrows)
with self.assertRaisesRegexp(ValueError, r'Shape \(1,\) must have rank 0'):
with self.assertRaisesRegex(ValueError, r'Shape \(1,\) must have rank 0'):
RaggedTensor.from_value_rowids(
values=values,
value_rowids=value_rowids,
@ -632,9 +631,9 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
values = constant_op.constant([1, 2, 3], dtypes.int64)
with ops.Graph().as_default():
splits = constant_op.constant([0, 2, 3], dtypes.int64)
self.assertRaisesRegexp(ValueError,
'.* must be from the same graph as .*',
RaggedTensor.from_row_splits, values, splits)
with self.assertRaisesRegex(ValueError,
'.* must be from the same graph as .*'):
RaggedTensor.from_row_splits(values, splits)
#=============================================================================
# Ragged Value & Row-Partitioning Tensor Accessors
@ -755,7 +754,7 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
if not context.executing_eagerly():
rt5 = RaggedTensor.from_row_splits(
array_ops.placeholder(dtype=dtypes.string), [0, 2, 3, 5])
self.assertEqual(rt5.shape.ndims, None)
self.assertIsNone(rt5.shape.ndims)
rt6 = RaggedTensor.from_row_splits(
[1, 2, 3], array_ops.placeholder(dtype=dtypes.int64))
@ -1419,7 +1418,7 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
rt = ragged_factory_ops.constant([[0], [1], [2], [3]])
batched_variant = rt._to_variant(batched_input=True)
nested_batched_variant = array_ops.reshape(batched_variant, [2, 2])
with self.assertRaisesRegexp(ValueError,
with self.assertRaisesRegex(ValueError,
'output_ragged_rank must be equal to'):
RaggedTensor._from_variant(
nested_batched_variant,
@ -1481,7 +1480,7 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
self.assertNumpyObjectTensorsRecursivelyEqual(
expected, actual, 'Expected %r, got %r' % (expected, actual))
else:
with self.assertRaisesRegexp(ValueError, 'only supported in eager mode'):
with self.assertRaisesRegex(ValueError, 'only supported in eager mode'):
rt.numpy()
@parameterized.parameters([
@ -1563,22 +1562,27 @@ class RaggedTensorSpecTest(test_util.TensorFlowTestCase,
def testConstruction(self):
spec1 = RaggedTensorSpec(ragged_rank=1)
self.assertEqual(spec1._shape.rank, None)
self.assertIsNone(spec1._shape.rank)
self.assertEqual(spec1._dtype, dtypes.float32)
self.assertEqual(spec1._row_splits_dtype, dtypes.int64)
self.assertEqual(spec1._ragged_rank, 1)
self.assertIsNone(spec1.shape.rank)
self.assertEqual(spec1.dtype, dtypes.float32)
self.assertEqual(spec1.row_splits_dtype, dtypes.int64)
self.assertEqual(spec1.ragged_rank, 1)
spec2 = RaggedTensorSpec(shape=[None, None, None])
self.assertEqual(spec2._shape.as_list(), [None, None, None])
self.assertEqual(spec2._dtype, dtypes.float32)
self.assertEqual(spec2._row_splits_dtype, dtypes.int64)
self.assertEqual(spec2._ragged_rank, 2)
with self.assertRaisesRegexp(ValueError, 'Must specify ragged_rank'):
with self.assertRaisesRegex(ValueError, 'Must specify ragged_rank'):
RaggedTensorSpec()
with self.assertRaisesRegexp(TypeError, 'ragged_rank must be an int'):
with self.assertRaisesRegex(TypeError, 'ragged_rank must be an int'):
RaggedTensorSpec(ragged_rank=constant_op.constant(1))
with self.assertRaisesRegexp(ValueError,
with self.assertRaisesRegex(ValueError,
'ragged_rank must be less than rank'):
RaggedTensorSpec(ragged_rank=2, shape=[None, None])

View File

@ -4,6 +4,22 @@ tf_class {
is_instance: "<class \'tensorflow.python.framework.type_spec.BatchableTypeSpec\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
is_instance: "<type \'object\'>"
member {
name: "dtype"
mtype: "<type \'property\'>"
}
member {
name: "ragged_rank"
mtype: "<type \'property\'>"
}
member {
name: "row_splits_dtype"
mtype: "<type \'property\'>"
}
member {
name: "shape"
mtype: "<type \'property\'>"
}
member {
name: "value_type"
mtype: "<type \'property\'>"

View File

@ -4,6 +4,22 @@ tf_class {
is_instance: "<class \'tensorflow.python.framework.type_spec.BatchableTypeSpec\'>"
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
is_instance: "<type \'object\'>"
member {
name: "dtype"
mtype: "<type \'property\'>"
}
member {
name: "ragged_rank"
mtype: "<type \'property\'>"
}
member {
name: "row_splits_dtype"
mtype: "<type \'property\'>"
}
member {
name: "shape"
mtype: "<type \'property\'>"
}
member {
name: "value_type"
mtype: "<type \'property\'>"