Improve handling of variant-encoded ragged tensors with ragged_rank=0.

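This fixes a bug that was preventing tf.data.Dataset.from_tensor_slices from being used with ragged tensors that had ragged_rank=1. (Slicing such a tensor along its outer dimension yields rows with ragged_rank=0, which the variant conversion ops previously rejected.)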

PiperOrigin-RevId: 272502146
This commit is contained in:
Edward Loper 2019-10-02 13:06:18 -07:00 committed by TensorFlower Gardener
parent 811e7b67c8
commit 261fef924c
7 changed files with 108 additions and 30 deletions

View File

@ -152,7 +152,6 @@ class RaggedTensorToVariantOp : public OpKernel {
OP_REQUIRES_OK(context, context->input_list("rt_nested_splits", OP_REQUIRES_OK(context, context->input_list("rt_nested_splits",
&ragged_nested_splits_in)); &ragged_nested_splits_in));
const int ragged_nested_splits_len = ragged_nested_splits_in.size(); const int ragged_nested_splits_len = ragged_nested_splits_in.size();
DCHECK_GT(ragged_nested_splits_len, 0); // Enforced by REGISTER_OP.
RaggedTensor batched_ragged_input; RaggedTensor batched_ragged_input;
// Read ragged_values input. // Read ragged_values input.
batched_ragged_input.values = context->input(ragged_nested_splits_len); batched_ragged_input.values = context->input(ragged_nested_splits_len);

View File

@ -525,7 +525,10 @@ TEST_F(RaggedTensorToVariantKernelTest, ShapeFnTestBatched) {
// Tests with len(ragged_splits)==0. // Tests with len(ragged_splits)==0.
(*op.node_def.mutable_attr())["RAGGED_RANK"].set_i(0); (*op.node_def.mutable_attr())["RAGGED_RANK"].set_i(0);
INFER_ERROR("Shape inference should have returned error", op, "?"); INFER_ERROR(
"ragged_rank=0 is not currently supported "
"when batched_input=true.",
op, "?");
// Tests with len(ragged_splits)==1. // Tests with len(ragged_splits)==1.
(*op.node_def.mutable_attr())["RAGGED_RANK"].set_i(1); (*op.node_def.mutable_attr())["RAGGED_RANK"].set_i(1);
@ -563,7 +566,7 @@ TEST_F(RaggedTensorToVariantKernelTest, ShapeFnTestNotBatched) {
// Tests with len(ragged_splits)==0. // Tests with len(ragged_splits)==0.
(*op.node_def.mutable_attr())["RAGGED_RANK"].set_i(0); (*op.node_def.mutable_attr())["RAGGED_RANK"].set_i(0);
INFER_ERROR("Shape inference should have returned error", op, "?"); INFER_OK(op, "?", "[]");
// Tests with len(ragged_splits)==1. // Tests with len(ragged_splits)==1.
(*op.node_def.mutable_attr())["RAGGED_RANK"].set_i(1); (*op.node_def.mutable_attr())["RAGGED_RANK"].set_i(1);
@ -592,19 +595,20 @@ TEST_F(RaggedTensorToVariantKernelTest, ShapeFnTestNotBatched) {
INFER_OK(op, "?;?;?;[5]", "[]"); INFER_OK(op, "?;?;?;[5]", "[]");
} }
TEST_F(RaggedTensorToVariantKernelTest, NoSplits) { TEST_F(RaggedTensorToVariantKernelTest, NonRaggedInput) {
const auto dtype = DataTypeToEnum<int>::v(); const std::vector<int> values = {1, 2, 3, 4, 5, 6};
TF_ASSERT_OK(NodeDefBuilder("tested_op", "RaggedTensorToVariant") Tensor expected_values(DT_INT32, TensorShape({6}));
.Input(FakeInput(0)) test::FillValues<int>(&expected_values, values);
.Input(FakeInput(dtype))
.Attr("RAGGED_RANK", 0) BuildEncodeRaggedTensorGraph<int, int64>({}, TensorShape({6}), values, false);
.Attr("Tvalues", dtype) TF_ASSERT_OK(RunOpKernel());
.Attr("Tsplits", DT_INT64)
.Attr("batched_input", true) const auto& encoded_scalar = GetOutput(0)->scalar<Variant>()();
.Finalize(node_def())); const Variant& encoded_values =
EXPECT_TRUE(absl::StartsWith( encoded_scalar.get<Tensor>()->vec<Variant>()(0);
InitOp().error_message(),
"Value for attr 'RAGGED_RANK' of 0 must be at least minimum 1")); test::ExpectTensorEqual<int>(*encoded_values.get<Tensor>(), expected_values);
} }
} // namespace } // namespace
} // namespace tensorflow } // namespace tensorflow

View File

@ -113,9 +113,9 @@ REGISTER_OP("RaggedTensorToVariant")
.Input("rt_nested_splits: RAGGED_RANK * Tsplits") .Input("rt_nested_splits: RAGGED_RANK * Tsplits")
.Input("rt_dense_values: Tvalues") .Input("rt_dense_values: Tvalues")
.Output("encoded_ragged: variant") .Output("encoded_ragged: variant")
.Attr("RAGGED_RANK: int >= 1") .Attr("RAGGED_RANK: int >= 0")
.Attr("Tvalues: type") .Attr("Tvalues: type")
.Attr("Tsplits: {int32, int64}") .Attr("Tsplits: {int32, int64} = DT_INT64")
.Attr("batched_input: bool") .Attr("batched_input: bool")
.SetShapeFn(RaggedTensorToVariantShapeFn); .SetShapeFn(RaggedTensorToVariantShapeFn);
@ -124,9 +124,9 @@ REGISTER_OP("RaggedTensorFromVariant")
.Output("output_nested_splits: output_ragged_rank * Tsplits") .Output("output_nested_splits: output_ragged_rank * Tsplits")
.Output("output_dense_values: Tvalues") .Output("output_dense_values: Tvalues")
.Attr("input_ragged_rank: int >= -1") .Attr("input_ragged_rank: int >= -1")
.Attr("output_ragged_rank: int >= 1") .Attr("output_ragged_rank: int >= 0")
.Attr("Tvalues: type") .Attr("Tvalues: type")
.Attr("Tsplits: {int32, int64}") .Attr("Tsplits: {int32, int64} = DT_INT64")
.SetShapeFn(RaggedTensorFromVariantShapeFn); .SetShapeFn(RaggedTensorFromVariantShapeFn);
REGISTER_OP("RaggedTensorToTensor") REGISTER_OP("RaggedTensorToTensor")
@ -194,6 +194,10 @@ Status RaggedTensorToVariantShapeFn(InferenceContext* c) {
} else { } else {
c->set_output(0, c->Scalar()); c->set_output(0, c->Scalar());
} }
if (batched && num_splits == 0) {
return errors::InvalidArgument(
"ragged_rank=0 is not currently supported when batched_input=true.");
}
return Status::OK(); return Status::OK();
} }

View File

@ -2201,15 +2201,32 @@ class RaggedTensorSpec(type_spec.BatchableTypeSpec):
return [tensor_spec.TensorSpec(None, dtypes.variant)] return [tensor_spec.TensorSpec(None, dtypes.variant)]
def _to_tensor_list(self, value): def _to_tensor_list(self, value):
ragged_rank = value.ragged_rank if isinstance(value, RaggedTensor) else 0
if ragged_rank != self._ragged_rank:
raise ValueError("Ragged rank of value (%d) does not match ragged "
"rank of type (%d)" % (ragged_rank, self._ragged_rank))
if ragged_rank == 0:
return [
gen_ragged_conversion_ops.ragged_tensor_to_variant(
(), value, batched_input=False)
]
# pylint: disable=protected-access # pylint: disable=protected-access
return [value._to_variant(batched_input=False)] return [value._to_variant(batched_input=False)]
def _to_batched_tensor_list(self, value): def _to_batched_tensor_list(self, value):
ragged_rank = value.ragged_rank if isinstance(value, RaggedTensor) else 0
if ragged_rank != self._ragged_rank:
raise ValueError("Ragged rank of value (%d) does not match ragged "
"rank of type (%d)" % (ragged_rank, self._ragged_rank))
if ragged_rank == 0:
# TODO(b/141789000) Update this to handle ragged_rank=0.
raise ValueError(
"_to_batched_tensor_list doesn't support ragged_rank=0 yet")
# pylint: disable=protected-access # pylint: disable=protected-access
return [value._to_variant(batched_input=True)] return [value._to_variant(batched_input=True)]
def _from_compatible_tensor_list(self, tensor_list): def _from_compatible_tensor_list(self, tensor_list):
if self._ragged_rank <= 0: if self._ragged_rank < 0:
raise ValueError( raise ValueError(
"ragged_rank must be non-negative; got %s." % self._ragged_rank) "ragged_rank must be non-negative; got %s." % self._ragged_rank)
result = RaggedTensor._from_variant( # pylint: disable=protected-access result = RaggedTensor._from_variant( # pylint: disable=protected-access
@ -2217,12 +2234,15 @@ class RaggedTensorSpec(type_spec.BatchableTypeSpec):
row_splits_dtype=self._row_splits_dtype, row_splits_dtype=self._row_splits_dtype,
output_ragged_rank=self._ragged_rank) output_ragged_rank=self._ragged_rank)
if self._shape.ndims is not None: if self._shape.ndims is not None:
outer_dim = tensor_shape.dimension_value(self._shape[0]) if isinstance(result, RaggedTensor):
if outer_dim is not None: outer_dim = tensor_shape.dimension_value(self._shape[0])
result.row_splits.set_shape([outer_dim + 1]) if outer_dim is not None:
result.flat_values.set_shape( result.row_splits.set_shape([outer_dim + 1])
tensor_shape.TensorShape([None]).concatenate( result.flat_values.set_shape(
self._shape[1 + self._ragged_rank:])) tensor_shape.TensorShape([None]).concatenate(
self._shape[1 + self._ragged_rank:]))
else:
result.set_shape(self._shape)
return result return result
def _batch(self, batch_size): def _batch(self, batch_size):

View File

@ -23,6 +23,7 @@ import re
from absl.testing import parameterized from absl.testing import parameterized
import numpy as np import numpy as np
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
@ -1700,6 +1701,27 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
[[6], [7]], [[8], [9]]]) [[6], [7]], [[8], [9]]])
self.assertAllEqual(decoded_rt, expected_rt) self.assertAllEqual(decoded_rt, expected_rt)
def testUnbatchVariant(self): # b/141789000
rt = ragged_factory_ops.constant([[1, 2, 3], [4, 5], [], [6, 7, 8, 9]])
batched = rt._to_variant(batched_input=True)
for i in range(4):
row = RaggedTensor._from_variant(
batched[i], dtype=dtypes.int32, output_ragged_rank=0)
self.assertAllEqual(rt[i], row)
def testUnbatchVariantInDataset(self):
rt = ragged_factory_ops.constant([[1, 2, 3], [4, 5], [], [6, 7, 8, 9]])
ds = dataset_ops.Dataset.from_tensor_slices(rt)
if context.executing_eagerly():
for i, value in enumerate(ds):
self.assertAllEqual(rt[i], value)
else:
it = dataset_ops.make_one_shot_iterator(ds)
out = it.get_next()
with self.cached_session() as sess:
for i in range(3):
self.assertAllEqual(sess.run(rt[i]), out)
def testFromVariantInvalidParams(self): def testFromVariantInvalidParams(self):
rt = ragged_factory_ops.constant([[0], [1], [2], [3]]) rt = ragged_factory_ops.constant([[0], [1], [2], [3]])
batched_variant = rt._to_variant(batched_input=True) batched_variant = rt._to_variant(batched_input=True)
@ -1847,12 +1869,19 @@ class RaggedTensorSpecTest(test_util.TensorFlowTestCase,
self.assertEqual(rt_spec._flat_tensor_specs, self.assertEqual(rt_spec._flat_tensor_specs,
[tensor_spec.TensorSpec(None, dtypes.variant)]) [tensor_spec.TensorSpec(None, dtypes.variant)])
@parameterized.parameters([ @parameterized.named_parameters([
{ {
'testcase_name': 'RaggedRank0',
'rt_spec': RaggedTensorSpec(ragged_rank=0),
'rt': [1.0, 2.0, 3.0],
},
{
'testcase_name': 'RaggedRank1',
'rt_spec': RaggedTensorSpec(ragged_rank=1), 'rt_spec': RaggedTensorSpec(ragged_rank=1),
'rt': [[1.0, 2.0], [3.0]] 'rt': [[1.0, 2.0], [3.0]]
}, },
{ {
'testcase_name': 'RaggedRank2',
'rt_spec': RaggedTensorSpec(shape=[2, None, None]), 'rt_spec': RaggedTensorSpec(shape=[2, None, None]),
'rt': [[[1.0, 2.0], [3.0]], [[], [4.0]]] 'rt': [[[1.0, 2.0], [3.0]], [[], [4.0]]]
}, },
@ -1863,6 +1892,28 @@ class RaggedTensorSpecTest(test_util.TensorFlowTestCase,
rt_reconstructed = rt_spec._from_tensor_list(tensor_list) rt_reconstructed = rt_spec._from_tensor_list(tensor_list)
self.assertAllEqual(rt, rt_reconstructed) self.assertAllEqual(rt, rt_reconstructed)
@parameterized.named_parameters([
# TODO(b/141789000) Test ragged_rank=0 when support is added.
{
'testcase_name': 'RaggedRank1',
'rt_spec': RaggedTensorSpec(ragged_rank=1),
'rt': [[1.0, 2.0], [3.0]]
},
{
'testcase_name': 'RaggedRank2',
'rt_spec': RaggedTensorSpec(shape=[2, None, None]),
'rt': [[[1.0, 2.0], [3.0]], [[], [4.0]]]
},
])
def testToFromBatchedTensorList(self, rt_spec, rt):
rt = ragged_factory_ops.constant(rt)
tensor_list = rt_spec._to_batched_tensor_list(rt)
rt_reconstructed = rt_spec._from_tensor_list(tensor_list)
self.assertAllEqual(rt, rt_reconstructed)
first_row = rt_spec._unbatch()._from_tensor_list(
[t[0] for t in tensor_list])
self.assertAllEqual(rt[0], first_row)
@parameterized.parameters([ @parameterized.parameters([
(RaggedTensorSpec([2, None], dtypes.float32, 1), 32, (RaggedTensorSpec([2, None], dtypes.float32, 1), 32,
RaggedTensorSpec([32, 2, None], dtypes.float32, 2)), RaggedTensorSpec([32, 2, None], dtypes.float32, 2)),

View File

@ -2882,7 +2882,7 @@ tf_module {
} }
member_method { member_method {
name: "RaggedTensorFromVariant" name: "RaggedTensorFromVariant"
argspec: "args=[\'encoded_ragged\', \'input_ragged_rank\', \'output_ragged_rank\', \'Tvalues\', \'Tsplits\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'encoded_ragged\', \'input_ragged_rank\', \'output_ragged_rank\', \'Tvalues\', \'Tsplits\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int64\'>\", \'None\'], "
} }
member_method { member_method {
name: "RaggedTensorToSparse" name: "RaggedTensorToSparse"

View File

@ -2882,7 +2882,7 @@ tf_module {
} }
member_method { member_method {
name: "RaggedTensorFromVariant" name: "RaggedTensorFromVariant"
argspec: "args=[\'encoded_ragged\', \'input_ragged_rank\', \'output_ragged_rank\', \'Tvalues\', \'Tsplits\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'encoded_ragged\', \'input_ragged_rank\', \'output_ragged_rank\', \'Tvalues\', \'Tsplits\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int64\'>\", \'None\'], "
} }
member_method { member_method {
name: "RaggedTensorToSparse" name: "RaggedTensorToSparse"