Improve handling of variant-encoded ragged tensors with ragged_rank=0.
This fixes a bug that was preventing tf.data.Dataset.from_tensor_slices from being used with ragged tensors that had ragged_rank=1.

PiperOrigin-RevId: 272502146
parent 811e7b67c8
commit 261fef924c
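For context, the usage this commit unblocks, as a minimal sketch (eager mode; the values mirror the new `testUnbatchVariantInDataset` below):

```python
import tensorflow as tf

# Each slice of a ragged_rank=1 RaggedTensor is a ragged_rank=0 row, which
# tf.data must round-trip through the variant encoding fixed here.
rt = tf.ragged.constant([[1, 2, 3], [4, 5], [], [6, 7, 8, 9]])
ds = tf.data.Dataset.from_tensor_slices(rt)
for row in ds:
    print(row)  # tf.Tensor([1 2 3] ...), tf.Tensor([4 5] ...), ...
```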
@@ -152,7 +152,6 @@ class RaggedTensorToVariantOp : public OpKernel {
     OP_REQUIRES_OK(context, context->input_list("rt_nested_splits",
                                                  &ragged_nested_splits_in));
     const int ragged_nested_splits_len = ragged_nested_splits_in.size();
-    DCHECK_GT(ragged_nested_splits_len, 0);  // Enforced by REGISTER_OP.
     RaggedTensor batched_ragged_input;
     // Read ragged_values input.
     batched_ragged_input.values = context->input(ragged_nested_splits_len);
@@ -525,7 +525,10 @@ TEST_F(RaggedTensorToVariantKernelTest, ShapeFnTestBatched) {
 
   // Tests with len(ragged_splits)==0.
   (*op.node_def.mutable_attr())["RAGGED_RANK"].set_i(0);
-  INFER_ERROR("Shape inference should have returned error", op, "?");
+  INFER_ERROR(
+      "ragged_rank=0 is not currently supported "
+      "when batched_input=true.",
+      op, "?");
 
   // Tests with len(ragged_splits)==1.
   (*op.node_def.mutable_attr())["RAGGED_RANK"].set_i(1);
@@ -563,7 +566,7 @@ TEST_F(RaggedTensorToVariantKernelTest, ShapeFnTestNotBatched) {
 
   // Tests with len(ragged_splits)==0.
   (*op.node_def.mutable_attr())["RAGGED_RANK"].set_i(0);
-  INFER_ERROR("Shape inference should have returned error", op, "?");
+  INFER_OK(op, "?", "[]");
 
   // Tests with len(ragged_splits)==1.
   (*op.node_def.mutable_attr())["RAGGED_RANK"].set_i(1);
@@ -592,19 +595,20 @@ TEST_F(RaggedTensorToVariantKernelTest, ShapeFnTestNotBatched) {
   INFER_OK(op, "?;?;?;[5]", "[]");
 }
 
-TEST_F(RaggedTensorToVariantKernelTest, NoSplits) {
-  const auto dtype = DataTypeToEnum<int>::v();
-  TF_ASSERT_OK(NodeDefBuilder("tested_op", "RaggedTensorToVariant")
-                   .Input(FakeInput(0))
-                   .Input(FakeInput(dtype))
-                   .Attr("RAGGED_RANK", 0)
-                   .Attr("Tvalues", dtype)
-                   .Attr("Tsplits", DT_INT64)
-                   .Attr("batched_input", true)
-                   .Finalize(node_def()));
-  EXPECT_TRUE(absl::StartsWith(
-      InitOp().error_message(),
-      "Value for attr 'RAGGED_RANK' of 0 must be at least minimum 1"));
-}
+TEST_F(RaggedTensorToVariantKernelTest, NonRaggedInput) {
+  const std::vector<int> values = {1, 2, 3, 4, 5, 6};
+  Tensor expected_values(DT_INT32, TensorShape({6}));
+  test::FillValues<int>(&expected_values, values);
+
+  BuildEncodeRaggedTensorGraph<int, int64>({}, TensorShape({6}), values, false);
+  TF_ASSERT_OK(RunOpKernel());
+
+  const auto& encoded_scalar = GetOutput(0)->scalar<Variant>()();
+  const Variant& encoded_values =
+      encoded_scalar.get<Tensor>()->vec<Variant>()(0);
+
+  test::ExpectTensorEqual<int>(*encoded_values.get<Tensor>(), expected_values);
+}
+
 }  // namespace
 }  // namespace tensorflow
@@ -113,9 +113,9 @@ REGISTER_OP("RaggedTensorToVariant")
     .Input("rt_nested_splits: RAGGED_RANK * Tsplits")
     .Input("rt_dense_values: Tvalues")
     .Output("encoded_ragged: variant")
-    .Attr("RAGGED_RANK: int >= 1")
+    .Attr("RAGGED_RANK: int >= 0")
     .Attr("Tvalues: type")
-    .Attr("Tsplits: {int32, int64}")
+    .Attr("Tsplits: {int32, int64} = DT_INT64")
     .Attr("batched_input: bool")
     .SetShapeFn(RaggedTensorToVariantShapeFn);
@@ -124,9 +124,9 @@ REGISTER_OP("RaggedTensorFromVariant")
     .Output("output_nested_splits: output_ragged_rank * Tsplits")
     .Output("output_dense_values: Tvalues")
     .Attr("input_ragged_rank: int >= -1")
-    .Attr("output_ragged_rank: int >= 1")
+    .Attr("output_ragged_rank: int >= 0")
     .Attr("Tvalues: type")
-    .Attr("Tsplits: {int32, int64}")
+    .Attr("Tsplits: {int32, int64} = DT_INT64")
     .SetShapeFn(RaggedTensorFromVariantShapeFn);
 
 REGISTER_OP("RaggedTensorToTensor")
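Both registrations relax the minimum ragged rank from 1 to 0 and give `Tsplits` a `DT_INT64` default. The default matters: with `RAGGED_RANK=0` the `rt_nested_splits` input list is empty, so there is no splits tensor left for `Tsplits` to be inferred from. A hedged sketch of an encode call the relaxed op should now accept (raw-op endpoint; argument names follow the registration above):

```python
import tensorflow as tf

values = tf.constant([1, 2, 3, 4, 5, 6])
# ragged_rank=0: no splits inputs, so Tsplits falls back to the new default.
encoded = tf.raw_ops.RaggedTensorToVariant(
    rt_nested_splits=[], rt_dense_values=values, batched_input=False)
print(encoded.dtype)  # tf.variant; a scalar wrapping just the values tensor
```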
@@ -194,6 +194,10 @@ Status RaggedTensorToVariantShapeFn(InferenceContext* c) {
   } else {
     c->set_output(0, c->Scalar());
   }
+  if (batched && num_splits == 0) {
+    return errors::InvalidArgument(
+        "ragged_rank=0 is not currently supported when batched_input=true.");
+  }
   return Status::OK();
 }
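The shape function now rejects the one combination that remains unsupported, so the failure surfaces at op-construction time rather than inside the kernel. A hedged sketch of that failure mode (hypothetical eager call; the message is the `InvalidArgument` text above):

```python
import tensorflow as tf

values = tf.constant([1, 2, 3])
try:
    tf.raw_ops.RaggedTensorToVariant(
        rt_nested_splits=[], rt_dense_values=values, batched_input=True)
except tf.errors.InvalidArgumentError as e:
    # "ragged_rank=0 is not currently supported when batched_input=true."
    print(e.message)
```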
@@ -2201,15 +2201,32 @@ class RaggedTensorSpec(type_spec.BatchableTypeSpec):
     return [tensor_spec.TensorSpec(None, dtypes.variant)]
 
   def _to_tensor_list(self, value):
+    ragged_rank = value.ragged_rank if isinstance(value, RaggedTensor) else 0
+    if ragged_rank != self._ragged_rank:
+      raise ValueError("Ragged rank of value (%d) does not match ragged "
+                       "rank of type (%d)" % (ragged_rank, self._ragged_rank))
+    if ragged_rank == 0:
+      return [
+          gen_ragged_conversion_ops.ragged_tensor_to_variant(
+              (), value, batched_input=False)
+      ]
     # pylint: disable=protected-access
     return [value._to_variant(batched_input=False)]
 
   def _to_batched_tensor_list(self, value):
+    ragged_rank = value.ragged_rank if isinstance(value, RaggedTensor) else 0
+    if ragged_rank != self._ragged_rank:
+      raise ValueError("Ragged rank of value (%d) does not match ragged "
+                       "rank of type (%d)" % (ragged_rank, self._ragged_rank))
+    if ragged_rank == 0:
+      # TODO(b/141789000) Update this to handle ragged_rank=0.
+      raise ValueError(
+          "_to_batched_tensor_list doesn't support ragged_rank=0 yet")
     # pylint: disable=protected-access
     return [value._to_variant(batched_input=True)]
 
   def _from_compatible_tensor_list(self, tensor_list):
-    if self._ragged_rank <= 0:
+    if self._ragged_rank < 0:
       raise ValueError(
           "ragged_rank must be non-negative; got %s." % self._ragged_rank)
     result = RaggedTensor._from_variant(  # pylint: disable=protected-access
@@ -2217,12 +2234,15 @@ class RaggedTensorSpec(type_spec.BatchableTypeSpec):
         row_splits_dtype=self._row_splits_dtype,
         output_ragged_rank=self._ragged_rank)
     if self._shape.ndims is not None:
-      outer_dim = tensor_shape.dimension_value(self._shape[0])
-      if outer_dim is not None:
-        result.row_splits.set_shape([outer_dim + 1])
-      result.flat_values.set_shape(
-          tensor_shape.TensorShape([None]).concatenate(
-              self._shape[1 + self._ragged_rank:]))
+      if isinstance(result, RaggedTensor):
+        outer_dim = tensor_shape.dimension_value(self._shape[0])
+        if outer_dim is not None:
+          result.row_splits.set_shape([outer_dim + 1])
+        result.flat_values.set_shape(
+            tensor_shape.TensorShape([None]).concatenate(
+                self._shape[1 + self._ragged_rank:]))
+      else:
+        result.set_shape(self._shape)
     return result
 
   def _batch(self, batch_size):
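Taken together, these changes let `RaggedTensorSpec` round-trip a plain `Tensor` as a ragged_rank=0 value. A sketch using the private helpers that the tests below exercise (internal API, mirroring the new `RaggedRank0` test case; the batched path still raises, per the TODO above):

```python
import tensorflow as tf
from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensorSpec

spec = RaggedTensorSpec(ragged_rank=0, dtype=tf.float32)
value = tf.constant([1.0, 2.0, 3.0])

# Encodes through RaggedTensorToVariant with an empty splits tuple, then
# decodes back to a plain Tensor (ragged_rank=0 produces no splits outputs).
tensor_list = spec._to_tensor_list(value)
reconstructed = spec._from_tensor_list(tensor_list)

# spec._to_batched_tensor_list(value) would raise ValueError (b/141789000).
```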
@@ -23,6 +23,7 @@ import re
 from absl.testing import parameterized
 import numpy as np
 
+from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.eager import context
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
@@ -1700,6 +1701,27 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
                                            [[6], [7]], [[8], [9]]])
     self.assertAllEqual(decoded_rt, expected_rt)
 
+  def testUnbatchVariant(self):  # b/141789000
+    rt = ragged_factory_ops.constant([[1, 2, 3], [4, 5], [], [6, 7, 8, 9]])
+    batched = rt._to_variant(batched_input=True)
+    for i in range(4):
+      row = RaggedTensor._from_variant(
+          batched[i], dtype=dtypes.int32, output_ragged_rank=0)
+      self.assertAllEqual(rt[i], row)
+
+  def testUnbatchVariantInDataset(self):
+    rt = ragged_factory_ops.constant([[1, 2, 3], [4, 5], [], [6, 7, 8, 9]])
+    ds = dataset_ops.Dataset.from_tensor_slices(rt)
+    if context.executing_eagerly():
+      for i, value in enumerate(ds):
+        self.assertAllEqual(rt[i], value)
+    else:
+      it = dataset_ops.make_one_shot_iterator(ds)
+      out = it.get_next()
+      with self.cached_session() as sess:
+        for i in range(3):
+          self.assertAllEqual(sess.run(rt[i]), out)
+
   def testFromVariantInvalidParams(self):
     rt = ragged_factory_ops.constant([[0], [1], [2], [3]])
     batched_variant = rt._to_variant(batched_input=True)
@@ -1847,12 +1869,19 @@ class RaggedTensorSpecTest(test_util.TensorFlowTestCase,
     self.assertEqual(rt_spec._flat_tensor_specs,
                      [tensor_spec.TensorSpec(None, dtypes.variant)])
 
-  @parameterized.parameters([
+  @parameterized.named_parameters([
+      {
+          'testcase_name': 'RaggedRank0',
+          'rt_spec': RaggedTensorSpec(ragged_rank=0),
+          'rt': [1.0, 2.0, 3.0],
+      },
       {
+          'testcase_name': 'RaggedRank1',
           'rt_spec': RaggedTensorSpec(ragged_rank=1),
           'rt': [[1.0, 2.0], [3.0]]
       },
       {
+          'testcase_name': 'RaggedRank2',
           'rt_spec': RaggedTensorSpec(shape=[2, None, None]),
           'rt': [[[1.0, 2.0], [3.0]], [[], [4.0]]]
       },
@@ -1863,6 +1892,28 @@ class RaggedTensorSpecTest(test_util.TensorFlowTestCase,
     rt_reconstructed = rt_spec._from_tensor_list(tensor_list)
     self.assertAllEqual(rt, rt_reconstructed)
 
+  @parameterized.named_parameters([
+      # TODO(b/141789000) Test ragged_rank=0 when support is added.
+      {
+          'testcase_name': 'RaggedRank1',
+          'rt_spec': RaggedTensorSpec(ragged_rank=1),
+          'rt': [[1.0, 2.0], [3.0]]
+      },
+      {
+          'testcase_name': 'RaggedRank2',
+          'rt_spec': RaggedTensorSpec(shape=[2, None, None]),
+          'rt': [[[1.0, 2.0], [3.0]], [[], [4.0]]]
+      },
+  ])
+  def testToFromBatchedTensorList(self, rt_spec, rt):
+    rt = ragged_factory_ops.constant(rt)
+    tensor_list = rt_spec._to_batched_tensor_list(rt)
+    rt_reconstructed = rt_spec._from_tensor_list(tensor_list)
+    self.assertAllEqual(rt, rt_reconstructed)
+    first_row = rt_spec._unbatch()._from_tensor_list(
+        [t[0] for t in tensor_list])
+    self.assertAllEqual(rt[0], first_row)
+
   @parameterized.parameters([
       (RaggedTensorSpec([2, None], dtypes.float32, 1), 32,
        RaggedTensorSpec([32, 2, None], dtypes.float32, 2)),
@@ -2882,7 +2882,7 @@ tf_module {
   }
   member_method {
     name: "RaggedTensorFromVariant"
-    argspec: "args=[\'encoded_ragged\', \'input_ragged_rank\', \'output_ragged_rank\', \'Tvalues\', \'Tsplits\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'encoded_ragged\', \'input_ragged_rank\', \'output_ragged_rank\', \'Tvalues\', \'Tsplits\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int64\'>\", \'None\'], "
   }
   member_method {
     name: "RaggedTensorToSparse"
@@ -2882,7 +2882,7 @@ tf_module {
   }
   member_method {
     name: "RaggedTensorFromVariant"
-    argspec: "args=[\'encoded_ragged\', \'input_ragged_rank\', \'output_ragged_rank\', \'Tvalues\', \'Tsplits\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+    argspec: "args=[\'encoded_ragged\', \'input_ragged_rank\', \'output_ragged_rank\', \'Tvalues\', \'Tsplits\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int64\'>\", \'None\'], "
  }
   member_method {
     name: "RaggedTensorToSparse"