Add RaggedTensor.numpy() -- Returns a ragged tensor as a numpy array.

This makes it easier to use eager Tensors and eager RaggedTensors interchangeably.

The numpy array returned by RaggedTensor encodes ragged dimensions by using a 1D array with dtype=object, where separate objects (nested values or nested arrays) are used for each row.

PiperOrigin-RevId: 293506445
Change-Id: Ieda7484ad05d5cb0e21be7ba2cd934f5f02e6176
This commit is contained in:
Edward Loper 2020-02-05 19:40:19 -08:00 committed by TensorFlower Gardener
parent 0bf1a402b2
commit 66779177f6
4 changed files with 110 additions and 0 deletions

View File

@ -1997,6 +1997,46 @@ class RaggedTensor(composite_tensor.CompositeTensor):
# Eager Execution Mode
#=============================================================================
def numpy(self):
"""Returns a numpy `array` with the values for this `RaggedTensor`.
Requires that this `RaggedTensor` was constructed in eager execution mode.
Ragged dimensions are encoded using numpy `arrays` with `dtype=object` and
`rank=1`, where each element is a single row.
#### Examples
In the following example, the value returned by `RaggedTensor.numpy()`
contains three numpy `array` objects: one for each row (with `rank=1` and
`dtype=int64`), and one to combine them (with `rank=1` and `dtype=object`):
>>> tf.ragged.constant([[1, 2, 3], [4, 5]], dtype=tf.int64).numpy()
array([array([1, 2, 3]), array([4, 5])], dtype=object)
Uniform dimensions are encoded using multidimensional numpy `array`s. In
the following example, the value returned by `RaggedTensor.numpy()` contains
a single numpy `array` object, with `rank=2` and `dtype=int64`:
>>> tf.ragged.constant([[1, 2, 3], [4, 5, 6]], dtype=tf.int64).numpy()
array([[1, 2, 3], [4, 5, 6]])
Returns:
A numpy `array`.
"""
if not self._is_eager():
raise ValueError("RaggedTensor.numpy() is only supported in eager mode.")
values = self._values.numpy()
splits = self._row_splits.numpy()
rows = [values[splits[i]:splits[i + 1]] for i in range(len(splits) - 1)]
if not rows:
return np.zeros((0, 0) + values.shape[1:], dtype=values.dtype)
# Note: if `rows` have ragged lengths, then they will be stored in a
# np.ndarray with dtype=object and rank=1. If they have uniform lengths,
# they will be combined into a single np.ndarray with dtype=row.dtype and
# rank=row.rank+1.
return np.array(rows)
def to_list(self):
"""Returns a nested Python `list` with the values for this `RaggedTensor`.

View File

@ -36,6 +36,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_math_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_tensor_value
from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensor
from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensorSpec
@ -116,6 +117,10 @@ EXAMPLE_RAGGED_TENSOR_4D_VALUES = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10],
[19, 20]]
def int32array(values):
return np.array(values, dtype=np.int32)
@test_util.run_all_in_graph_and_eager_modes
class RaggedTensorTest(test_util.TensorFlowTestCase,
parameterized.TestCase):
@ -1734,6 +1739,63 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
output_ragged_rank=1,
input_ragged_rank=1)
def assertNumpyObjectTensorsRecursivelyEqual(self, a, b, msg):
"""Check that two numpy arrays are equal.
For arrays with dtype=object, check values recursively to see if a and b
are equal. (c.f. `np.array_equal`, which checks dtype=object values using
object identity.)
Args:
a: A numpy array.
b: A numpy array.
msg: Message to display if a != b.
"""
if isinstance(a, np.ndarray) and a.dtype == object:
self.assertEqual(a.dtype, b.dtype, msg)
self.assertEqual(a.shape, b.shape, msg)
self.assertLen(a, len(b), msg)
for a_val, b_val in zip(a, b):
self.assertNumpyObjectTensorsRecursivelyEqual(a_val, b_val, msg)
else:
self.assertAllEqual(a, b, msg)
@parameterized.named_parameters([
('Shape_2_R',
[[1, 2], [3, 4, 5]],
np.array([int32array([1, 2]), int32array([3, 4, 5])])),
('Shape_2_2',
[[1, 2], [3, 4]],
np.array([[1, 2], [3, 4]])),
('Shape_2_R_2',
[[[1, 2], [3, 4]], [[5, 6]]],
np.array([int32array([[1, 2], [3, 4]]), int32array([[5, 6]])])),
('Shape_3_2_R',
[[[1], []], [[2, 3], [4]], [[], [5, 6, 7]]],
np.array([[int32array([1]), int32array([])],
[int32array([2, 3]), int32array([4])],
[int32array([]), int32array([5, 6, 7])]])),
('Shape_0_R',
ragged_factory_ops.constant_value([], ragged_rank=1, dtype=np.int32),
np.zeros([0, 0], dtype=np.int32)),
('Shape_0_R_2',
ragged_factory_ops.constant_value([], ragged_rank=1,
inner_shape=(2,), dtype=np.int32),
np.zeros([0, 0, 2], dtype=np.int32)),
]) # pyformat: disable
def testRaggedTensorNumpy(self, rt, expected):
if isinstance(rt, list):
rt = ragged_factory_ops.constant(rt, dtype=dtypes.int32)
else:
rt = ragged_tensor.convert_to_tensor_or_ragged_tensor(rt)
if context.executing_eagerly():
actual = rt.numpy()
self.assertNumpyObjectTensorsRecursivelyEqual(
expected, actual, 'Expected %r, got %r' % (expected, actual))
else:
with self.assertRaisesRegexp(ValueError, 'only supported in eager mode'):
rt.numpy()
@test_util.run_all_in_graph_and_eager_modes
class RaggedTensorSpecTest(test_util.TensorFlowTestCase,

View File

@ -103,6 +103,10 @@ tf_class {
name: "nrows"
argspec: "args=[\'self\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "numpy"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "row_lengths"
argspec: "args=[\'self\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "

View File

@ -103,6 +103,10 @@ tf_class {
name: "nrows"
argspec: "args=[\'self\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
}
member_method {
name: "numpy"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "row_lengths"
argspec: "args=[\'self\', \'axis\', \'name\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "