Make shape, dtype, and ragged_rank public in tf.RaggedTensorSpec.
PiperOrigin-RevId: 319112320 Change-Id: I11948b2d437ea68622117df54fda64b2045a10fe
This commit is contained in:
parent
e8ecb2c200
commit
c0bb8efaf1
tensorflow
python/ops/ragged
tools/api/golden
@ -827,11 +827,22 @@ class RaggedTensor(composite_tensor.CompositeTensor,
|
||||
|
||||
@property
|
||||
def ragged_rank(self):
|
||||
"""The number of ragged dimensions in this ragged tensor.
|
||||
"""The number of times the RaggedTensor's flat_values is partitioned.
|
||||
|
||||
Examples:
|
||||
|
||||
>>> values = tf.ragged.constant([[1, 2, 3], [4], [5, 6], [7, 8, 9, 10]])
|
||||
>>> values.ragged_rank
|
||||
1
|
||||
|
||||
>>> rt = tf.RaggedTensor.from_uniform_row_length(values, 2)
|
||||
>>> rt.ragged_rank
|
||||
2
|
||||
|
||||
Returns:
|
||||
A Python `int` indicating the number of ragged dimensions in this ragged
|
||||
tensor. The outermost dimension is not considered ragged.
|
||||
A Python `int` indicating the number of times the underlying `flat_values`
|
||||
Tensor has been partitioned to add a new dimension.
|
||||
I.e., `tf.rank(rt) = tf.rank(rt.flat_values) + rt.ragged_rank`.
|
||||
"""
|
||||
values_is_ragged = isinstance(self._values, RaggedTensor)
|
||||
return self._values.ragged_rank + 1 if values_is_ragged else 1
|
||||
@ -2119,6 +2130,80 @@ class RaggedTensorSpec(type_spec.BatchableTypeSpec):
|
||||
|
||||
__slots__ = ["_shape", "_dtype", "_ragged_rank", "_row_splits_dtype"]
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
"""The `tf.dtypes.DType` specified by this type for the RaggedTensor.
|
||||
|
||||
Examples:
|
||||
|
||||
>>> rt = tf.ragged.constant([["a"], ["b", "c"]], dtype=tf.string)
|
||||
>>> tf.type_spec_from_value(rt).dtype
|
||||
tf.string
|
||||
|
||||
Returns:
|
||||
A `tf.dtypes.DType` of the values in the RaggedTensor.
|
||||
"""
|
||||
return self._dtype
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
"""The statically known shape of the RaggedTensor.
|
||||
|
||||
Examples:
|
||||
|
||||
>>> rt = tf.ragged.constant([[0], [1, 2]])
|
||||
>>> tf.type_spec_from_value(rt).shape
|
||||
TensorShape([2, None])
|
||||
|
||||
>>> rt = tf.ragged.constant([[[0, 1]], [[1, 2], [3, 4]]], ragged_rank=1)
|
||||
>>> tf.type_spec_from_value(rt).shape
|
||||
TensorShape([2, None, 2])
|
||||
|
||||
Returns:
|
||||
A `tf.TensorShape` containing the statically known shape of the
|
||||
RaggedTensor. Ragged dimensions have a size of `None`.
|
||||
"""
|
||||
return self._shape
|
||||
|
||||
@property
|
||||
def ragged_rank(self):
|
||||
"""The number of times the RaggedTensor's flat_values is partitioned.
|
||||
|
||||
Defaults to `shape.ndims - 1`.
|
||||
|
||||
Examples:
|
||||
|
||||
>>> values = tf.ragged.constant([[1, 2, 3], [4], [5, 6], [7, 8, 9, 10]])
|
||||
>>> tf.type_spec_from_value(values).ragged_rank
|
||||
1
|
||||
|
||||
>>> rt1 = tf.RaggedTensor.from_uniform_row_length(values, 2)
|
||||
>>> tf.type_spec_from_value(rt1).ragged_rank
|
||||
2
|
||||
|
||||
Returns:
|
||||
A Python `int` indicating the number of times the underlying `flat_values`
|
||||
Tensor has been partitioned to add a new dimension.
|
||||
I.e., `tf.rank(rt) = tf.rank(rt.flat_values) + rt.ragged_rank`.
|
||||
"""
|
||||
return self._ragged_rank
|
||||
|
||||
@property
|
||||
def row_splits_dtype(self):
|
||||
"""The `tf.dtypes.DType` of the the RaggedTensor's `row_splits`.
|
||||
|
||||
Examples:
|
||||
|
||||
>>> rt = tf.ragged.constant([[1, 2, 3], [4]], row_splits_dtype=tf.int64)
|
||||
>>> tf.type_spec_from_value(rt).row_splits_dtype
|
||||
tf.int64
|
||||
|
||||
Returns:
|
||||
A `tf.dtypes.DType` for the RaggedTensor's `row_splits` tensor. One
|
||||
of `tf.int32` or `tf.int64`.
|
||||
"""
|
||||
return self._row_splits_dtype
|
||||
|
||||
@property
|
||||
def value_type(self):
|
||||
return RaggedTensor if self._ragged_rank > 0 else ops.Tensor
|
||||
@ -2134,8 +2219,8 @@ class RaggedTensorSpec(type_spec.BatchableTypeSpec):
|
||||
shape: The shape of the RaggedTensor, or `None` to allow any shape. If a
|
||||
shape is specified, then all ragged dimensions must have size `None`.
|
||||
dtype: `tf.DType` of values in the RaggedTensor.
|
||||
ragged_rank: Python integer, the ragged rank of the RaggedTensor to be
|
||||
described. Defaults to `shape.ndims - 1`.
|
||||
ragged_rank: Python integer, the number of times the RaggedTensor's
|
||||
flat_values is partitioned. Defaults to `shape.ndims - 1`.
|
||||
row_splits_dtype: `dtype` for the RaggedTensor's `row_splits` tensor. One
|
||||
of `tf.int32` or `tf.int64`.
|
||||
"""
|
||||
|
@ -151,18 +151,18 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
row_splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
|
||||
rp = RowPartition.from_row_splits(row_splits)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
'RaggedTensor constructor is private'):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
'RaggedTensor constructor is private'):
|
||||
RaggedTensor(values=values, row_partition=rp)
|
||||
|
||||
with self.assertRaisesRegexp(TypeError,
|
||||
'values must be a Tensor or RaggedTensor'):
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
'values must be a Tensor or RaggedTensor'):
|
||||
RaggedTensor(values=range(7), row_partition=rp, internal=True)
|
||||
|
||||
with self.assertRaisesRegexp(TypeError,
|
||||
'row_partition must be a RowPartition'):
|
||||
RaggedTensor(values=values, row_partition=[0, 2, 2, 5, 6, 7],
|
||||
internal=True)
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
'row_partition must be a RowPartition'):
|
||||
RaggedTensor(
|
||||
values=values, row_partition=[0, 2, 2, 5, 6, 7], internal=True)
|
||||
|
||||
#=============================================================================
|
||||
# RaggedTensor Factory Ops
|
||||
@ -308,7 +308,7 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
|
||||
def testFromRowSplitsWithEmptySplits(self):
|
||||
err_msg = 'row_splits tensor may not be empty'
|
||||
with self.assertRaisesRegexp(ValueError, err_msg):
|
||||
with self.assertRaisesRegex(ValueError, err_msg):
|
||||
RaggedTensor.from_row_splits([], [])
|
||||
|
||||
def testFromRowStarts(self):
|
||||
@ -511,18 +511,18 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
|
||||
]
|
||||
nrows = [constant_op.constant(6, dtypes.int64)]
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, 'nested_nrows must have the same '
|
||||
'length as nested_value_rowids'):
|
||||
RaggedTensor.from_nested_value_rowids(values, nested_value_rowids, nrows)
|
||||
|
||||
def testFromNestedValueRowIdsWithNonListInput(self):
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, 'nested_value_rowids must be a list of Tensors'):
|
||||
RaggedTensor.from_nested_value_rowids(
|
||||
[1, 2, 3], constant_op.constant([[0, 1, 2], [0, 1, 2]], dtypes.int64))
|
||||
with self.assertRaisesRegexp(TypeError,
|
||||
'nested_nrows must be a list of Tensors'):
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
'nested_nrows must be a list of Tensors'):
|
||||
RaggedTensor.from_nested_value_rowids([1, 2, 3], [[0, 1, 2], [0, 1, 2]],
|
||||
constant_op.constant([3, 3]))
|
||||
|
||||
@ -578,8 +578,8 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
rt, [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]])
|
||||
|
||||
def testFromNestedRowSplitsWithNonListInput(self):
|
||||
with self.assertRaisesRegexp(TypeError,
|
||||
'nested_row_splits must be a list of Tensors'):
|
||||
with self.assertRaisesRegex(TypeError,
|
||||
'nested_row_splits must be a list of Tensors'):
|
||||
RaggedTensor.from_nested_row_splits(
|
||||
[1, 2], constant_op.constant([[0, 1, 2], [0, 1, 2]], dtypes.int64))
|
||||
|
||||
@ -588,32 +588,31 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
|
||||
nrows = constant_op.constant(5, dtypes.int64)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, r'Expected nrows >= 0; got -2'):
|
||||
with self.assertRaisesRegex(ValueError, r'Expected nrows >= 0; got -2'):
|
||||
RaggedTensor.from_value_rowids(
|
||||
values=values,
|
||||
value_rowids=array_ops.placeholder_with_default(value_rowids, None),
|
||||
nrows=-2)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'Expected nrows >= value_rowids\[-1\] \+ 1; got nrows=2, '
|
||||
r'value_rowids\[-1\]=4'):
|
||||
RaggedTensor.from_value_rowids(
|
||||
values=values, value_rowids=value_rowids, nrows=2)
|
||||
|
||||
with self.assertRaisesRegexp(
|
||||
with self.assertRaisesRegex(
|
||||
ValueError, r'Expected nrows >= value_rowids\[-1\] \+ 1; got nrows=4, '
|
||||
r'value_rowids\[-1\]=4'):
|
||||
RaggedTensor.from_value_rowids(
|
||||
values=values, value_rowids=value_rowids, nrows=4)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
r'Shape \(7, 1\) must have rank 1'):
|
||||
with self.assertRaisesRegex(ValueError, r'Shape \(7, 1\) must have rank 1'):
|
||||
RaggedTensor.from_value_rowids(
|
||||
values=values,
|
||||
value_rowids=array_ops.expand_dims(value_rowids, 1),
|
||||
nrows=nrows)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, r'Shape \(1,\) must have rank 0'):
|
||||
with self.assertRaisesRegex(ValueError, r'Shape \(1,\) must have rank 0'):
|
||||
RaggedTensor.from_value_rowids(
|
||||
values=values,
|
||||
value_rowids=value_rowids,
|
||||
@ -632,9 +631,9 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
values = constant_op.constant([1, 2, 3], dtypes.int64)
|
||||
with ops.Graph().as_default():
|
||||
splits = constant_op.constant([0, 2, 3], dtypes.int64)
|
||||
self.assertRaisesRegexp(ValueError,
|
||||
'.* must be from the same graph as .*',
|
||||
RaggedTensor.from_row_splits, values, splits)
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
'.* must be from the same graph as .*'):
|
||||
RaggedTensor.from_row_splits(values, splits)
|
||||
|
||||
#=============================================================================
|
||||
# Ragged Value & Row-Partitioning Tensor Accessors
|
||||
@ -755,7 +754,7 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
if not context.executing_eagerly():
|
||||
rt5 = RaggedTensor.from_row_splits(
|
||||
array_ops.placeholder(dtype=dtypes.string), [0, 2, 3, 5])
|
||||
self.assertEqual(rt5.shape.ndims, None)
|
||||
self.assertIsNone(rt5.shape.ndims)
|
||||
|
||||
rt6 = RaggedTensor.from_row_splits(
|
||||
[1, 2, 3], array_ops.placeholder(dtype=dtypes.int64))
|
||||
@ -1419,8 +1418,8 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
rt = ragged_factory_ops.constant([[0], [1], [2], [3]])
|
||||
batched_variant = rt._to_variant(batched_input=True)
|
||||
nested_batched_variant = array_ops.reshape(batched_variant, [2, 2])
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
'output_ragged_rank must be equal to'):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
'output_ragged_rank must be equal to'):
|
||||
RaggedTensor._from_variant(
|
||||
nested_batched_variant,
|
||||
dtype=dtypes.int32,
|
||||
@ -1481,7 +1480,7 @@ class RaggedTensorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
|
||||
self.assertNumpyObjectTensorsRecursivelyEqual(
|
||||
expected, actual, 'Expected %r, got %r' % (expected, actual))
|
||||
else:
|
||||
with self.assertRaisesRegexp(ValueError, 'only supported in eager mode'):
|
||||
with self.assertRaisesRegex(ValueError, 'only supported in eager mode'):
|
||||
rt.numpy()
|
||||
|
||||
@parameterized.parameters([
|
||||
@ -1563,23 +1562,28 @@ class RaggedTensorSpecTest(test_util.TensorFlowTestCase,
|
||||
|
||||
def testConstruction(self):
|
||||
spec1 = RaggedTensorSpec(ragged_rank=1)
|
||||
self.assertEqual(spec1._shape.rank, None)
|
||||
self.assertIsNone(spec1._shape.rank)
|
||||
self.assertEqual(spec1._dtype, dtypes.float32)
|
||||
self.assertEqual(spec1._row_splits_dtype, dtypes.int64)
|
||||
self.assertEqual(spec1._ragged_rank, 1)
|
||||
|
||||
self.assertIsNone(spec1.shape.rank)
|
||||
self.assertEqual(spec1.dtype, dtypes.float32)
|
||||
self.assertEqual(spec1.row_splits_dtype, dtypes.int64)
|
||||
self.assertEqual(spec1.ragged_rank, 1)
|
||||
|
||||
spec2 = RaggedTensorSpec(shape=[None, None, None])
|
||||
self.assertEqual(spec2._shape.as_list(), [None, None, None])
|
||||
self.assertEqual(spec2._dtype, dtypes.float32)
|
||||
self.assertEqual(spec2._row_splits_dtype, dtypes.int64)
|
||||
self.assertEqual(spec2._ragged_rank, 2)
|
||||
|
||||
with self.assertRaisesRegexp(ValueError, 'Must specify ragged_rank'):
|
||||
with self.assertRaisesRegex(ValueError, 'Must specify ragged_rank'):
|
||||
RaggedTensorSpec()
|
||||
with self.assertRaisesRegexp(TypeError, 'ragged_rank must be an int'):
|
||||
with self.assertRaisesRegex(TypeError, 'ragged_rank must be an int'):
|
||||
RaggedTensorSpec(ragged_rank=constant_op.constant(1))
|
||||
with self.assertRaisesRegexp(ValueError,
|
||||
'ragged_rank must be less than rank'):
|
||||
with self.assertRaisesRegex(ValueError,
|
||||
'ragged_rank must be less than rank'):
|
||||
RaggedTensorSpec(ragged_rank=2, shape=[None, None])
|
||||
|
||||
def testValueType(self):
|
||||
|
@ -4,6 +4,22 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.framework.type_spec.BatchableTypeSpec\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "dtype"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "ragged_rank"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "row_splits_dtype"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "shape"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "value_type"
|
||||
mtype: "<type \'property\'>"
|
||||
|
@ -4,6 +4,22 @@ tf_class {
|
||||
is_instance: "<class \'tensorflow.python.framework.type_spec.BatchableTypeSpec\'>"
|
||||
is_instance: "<class \'tensorflow.python.framework.type_spec.TypeSpec\'>"
|
||||
is_instance: "<type \'object\'>"
|
||||
member {
|
||||
name: "dtype"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "ragged_rank"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "row_splits_dtype"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "shape"
|
||||
mtype: "<type \'property\'>"
|
||||
}
|
||||
member {
|
||||
name: "value_type"
|
||||
mtype: "<type \'property\'>"
|
||||
|
Loading…
Reference in New Issue
Block a user