Improve handling of variant-encoded ragged tensors with ragged_rank=0.
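This fixes a bug that was preventing tf.data.Dataset.from_tensor_slices from being used with ragged tensors that had ragged_rank=1 (slicing such a tensor yields rows with ragged_rank=0, which previously could not be variant-encoded or decoded).
PiperOrigin-RevId: 272502146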
This commit is contained in:
parent
811e7b67c8
commit
261fef924c
@ -152,7 +152,6 @@ class RaggedTensorToVariantOp : public OpKernel {
|
|||||||
OP_REQUIRES_OK(context, context->input_list("rt_nested_splits",
|
OP_REQUIRES_OK(context, context->input_list("rt_nested_splits",
|
||||||
&ragged_nested_splits_in));
|
&ragged_nested_splits_in));
|
||||||
const int ragged_nested_splits_len = ragged_nested_splits_in.size();
|
const int ragged_nested_splits_len = ragged_nested_splits_in.size();
|
||||||
DCHECK_GT(ragged_nested_splits_len, 0); // Enforced by REGISTER_OP.
|
|
||||||
RaggedTensor batched_ragged_input;
|
RaggedTensor batched_ragged_input;
|
||||||
// Read ragged_values input.
|
// Read ragged_values input.
|
||||||
batched_ragged_input.values = context->input(ragged_nested_splits_len);
|
batched_ragged_input.values = context->input(ragged_nested_splits_len);
|
||||||
|
@ -525,7 +525,10 @@ TEST_F(RaggedTensorToVariantKernelTest, ShapeFnTestBatched) {
|
|||||||
|
|
||||||
// Tests with len(ragged_splits)==0.
|
// Tests with len(ragged_splits)==0.
|
||||||
(*op.node_def.mutable_attr())["RAGGED_RANK"].set_i(0);
|
(*op.node_def.mutable_attr())["RAGGED_RANK"].set_i(0);
|
||||||
INFER_ERROR("Shape inference should have returned error", op, "?");
|
INFER_ERROR(
|
||||||
|
"ragged_rank=0 is not currently supported "
|
||||||
|
"when batched_input=true.",
|
||||||
|
op, "?");
|
||||||
|
|
||||||
// Tests with len(ragged_splits)==1.
|
// Tests with len(ragged_splits)==1.
|
||||||
(*op.node_def.mutable_attr())["RAGGED_RANK"].set_i(1);
|
(*op.node_def.mutable_attr())["RAGGED_RANK"].set_i(1);
|
||||||
@ -563,7 +566,7 @@ TEST_F(RaggedTensorToVariantKernelTest, ShapeFnTestNotBatched) {
|
|||||||
|
|
||||||
// Tests with len(ragged_splits)==0.
|
// Tests with len(ragged_splits)==0.
|
||||||
(*op.node_def.mutable_attr())["RAGGED_RANK"].set_i(0);
|
(*op.node_def.mutable_attr())["RAGGED_RANK"].set_i(0);
|
||||||
INFER_ERROR("Shape inference should have returned error", op, "?");
|
INFER_OK(op, "?", "[]");
|
||||||
|
|
||||||
// Tests with len(ragged_splits)==1.
|
// Tests with len(ragged_splits)==1.
|
||||||
(*op.node_def.mutable_attr())["RAGGED_RANK"].set_i(1);
|
(*op.node_def.mutable_attr())["RAGGED_RANK"].set_i(1);
|
||||||
@ -592,19 +595,20 @@ TEST_F(RaggedTensorToVariantKernelTest, ShapeFnTestNotBatched) {
|
|||||||
INFER_OK(op, "?;?;?;[5]", "[]");
|
INFER_OK(op, "?;?;?;[5]", "[]");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(RaggedTensorToVariantKernelTest, NoSplits) {
|
TEST_F(RaggedTensorToVariantKernelTest, NonRaggedInput) {
|
||||||
const auto dtype = DataTypeToEnum<int>::v();
|
const std::vector<int> values = {1, 2, 3, 4, 5, 6};
|
||||||
TF_ASSERT_OK(NodeDefBuilder("tested_op", "RaggedTensorToVariant")
|
Tensor expected_values(DT_INT32, TensorShape({6}));
|
||||||
.Input(FakeInput(0))
|
test::FillValues<int>(&expected_values, values);
|
||||||
.Input(FakeInput(dtype))
|
|
||||||
.Attr("RAGGED_RANK", 0)
|
BuildEncodeRaggedTensorGraph<int, int64>({}, TensorShape({6}), values, false);
|
||||||
.Attr("Tvalues", dtype)
|
TF_ASSERT_OK(RunOpKernel());
|
||||||
.Attr("Tsplits", DT_INT64)
|
|
||||||
.Attr("batched_input", true)
|
const auto& encoded_scalar = GetOutput(0)->scalar<Variant>()();
|
||||||
.Finalize(node_def()));
|
const Variant& encoded_values =
|
||||||
EXPECT_TRUE(absl::StartsWith(
|
encoded_scalar.get<Tensor>()->vec<Variant>()(0);
|
||||||
InitOp().error_message(),
|
|
||||||
"Value for attr 'RAGGED_RANK' of 0 must be at least minimum 1"));
|
test::ExpectTensorEqual<int>(*encoded_values.get<Tensor>(), expected_values);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -113,9 +113,9 @@ REGISTER_OP("RaggedTensorToVariant")
|
|||||||
.Input("rt_nested_splits: RAGGED_RANK * Tsplits")
|
.Input("rt_nested_splits: RAGGED_RANK * Tsplits")
|
||||||
.Input("rt_dense_values: Tvalues")
|
.Input("rt_dense_values: Tvalues")
|
||||||
.Output("encoded_ragged: variant")
|
.Output("encoded_ragged: variant")
|
||||||
.Attr("RAGGED_RANK: int >= 1")
|
.Attr("RAGGED_RANK: int >= 0")
|
||||||
.Attr("Tvalues: type")
|
.Attr("Tvalues: type")
|
||||||
.Attr("Tsplits: {int32, int64}")
|
.Attr("Tsplits: {int32, int64} = DT_INT64")
|
||||||
.Attr("batched_input: bool")
|
.Attr("batched_input: bool")
|
||||||
.SetShapeFn(RaggedTensorToVariantShapeFn);
|
.SetShapeFn(RaggedTensorToVariantShapeFn);
|
||||||
|
|
||||||
@ -124,9 +124,9 @@ REGISTER_OP("RaggedTensorFromVariant")
|
|||||||
.Output("output_nested_splits: output_ragged_rank * Tsplits")
|
.Output("output_nested_splits: output_ragged_rank * Tsplits")
|
||||||
.Output("output_dense_values: Tvalues")
|
.Output("output_dense_values: Tvalues")
|
||||||
.Attr("input_ragged_rank: int >= -1")
|
.Attr("input_ragged_rank: int >= -1")
|
||||||
.Attr("output_ragged_rank: int >= 1")
|
.Attr("output_ragged_rank: int >= 0")
|
||||||
.Attr("Tvalues: type")
|
.Attr("Tvalues: type")
|
||||||
.Attr("Tsplits: {int32, int64}")
|
.Attr("Tsplits: {int32, int64} = DT_INT64")
|
||||||
.SetShapeFn(RaggedTensorFromVariantShapeFn);
|
.SetShapeFn(RaggedTensorFromVariantShapeFn);
|
||||||
|
|
||||||
REGISTER_OP("RaggedTensorToTensor")
|
REGISTER_OP("RaggedTensorToTensor")
|
||||||
@ -194,6 +194,10 @@ Status RaggedTensorToVariantShapeFn(InferenceContext* c) {
|
|||||||
} else {
|
} else {
|
||||||
c->set_output(0, c->Scalar());
|
c->set_output(0, c->Scalar());
|
||||||
}
|
}
|
||||||
|
if (batched && num_splits == 0) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"ragged_rank=0 is not currently supported when batched_input=true.");
|
||||||
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2201,15 +2201,32 @@ class RaggedTensorSpec(type_spec.BatchableTypeSpec):
|
|||||||
return [tensor_spec.TensorSpec(None, dtypes.variant)]
|
return [tensor_spec.TensorSpec(None, dtypes.variant)]
|
||||||
|
|
||||||
def _to_tensor_list(self, value):
|
def _to_tensor_list(self, value):
|
||||||
|
ragged_rank = value.ragged_rank if isinstance(value, RaggedTensor) else 0
|
||||||
|
if ragged_rank != self._ragged_rank:
|
||||||
|
raise ValueError("Ragged rank of value (%d) does not match ragged "
|
||||||
|
"rank of type (%d)" % (ragged_rank, self._ragged_rank))
|
||||||
|
if ragged_rank == 0:
|
||||||
|
return [
|
||||||
|
gen_ragged_conversion_ops.ragged_tensor_to_variant(
|
||||||
|
(), value, batched_input=False)
|
||||||
|
]
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
return [value._to_variant(batched_input=False)]
|
return [value._to_variant(batched_input=False)]
|
||||||
|
|
||||||
def _to_batched_tensor_list(self, value):
|
def _to_batched_tensor_list(self, value):
|
||||||
|
ragged_rank = value.ragged_rank if isinstance(value, RaggedTensor) else 0
|
||||||
|
if ragged_rank != self._ragged_rank:
|
||||||
|
raise ValueError("Ragged rank of value (%d) does not match ragged "
|
||||||
|
"rank of type (%d)" % (ragged_rank, self._ragged_rank))
|
||||||
|
if ragged_rank == 0:
|
||||||
|
# TODO(b/141789000) Update this to handle ragged_rank=0.
|
||||||
|
raise ValueError(
|
||||||
|
"_to_batched_tensor_list doesn't support ragged_rank=0 yet")
|
||||||
# pylint: disable=protected-access
|
# pylint: disable=protected-access
|
||||||
return [value._to_variant(batched_input=True)]
|
return [value._to_variant(batched_input=True)]
|
||||||
|
|
||||||
def _from_compatible_tensor_list(self, tensor_list):
|
def _from_compatible_tensor_list(self, tensor_list):
|
||||||
if self._ragged_rank <= 0:
|
if self._ragged_rank < 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"ragged_rank must be non-negative; got %s." % self._ragged_rank)
|
"ragged_rank must be non-negative; got %s." % self._ragged_rank)
|
||||||
result = RaggedTensor._from_variant( # pylint: disable=protected-access
|
result = RaggedTensor._from_variant( # pylint: disable=protected-access
|
||||||
@ -2217,12 +2234,15 @@ class RaggedTensorSpec(type_spec.BatchableTypeSpec):
|
|||||||
row_splits_dtype=self._row_splits_dtype,
|
row_splits_dtype=self._row_splits_dtype,
|
||||||
output_ragged_rank=self._ragged_rank)
|
output_ragged_rank=self._ragged_rank)
|
||||||
if self._shape.ndims is not None:
|
if self._shape.ndims is not None:
|
||||||
outer_dim = tensor_shape.dimension_value(self._shape[0])
|
if isinstance(result, RaggedTensor):
|
||||||
if outer_dim is not None:
|
outer_dim = tensor_shape.dimension_value(self._shape[0])
|
||||||
result.row_splits.set_shape([outer_dim + 1])
|
if outer_dim is not None:
|
||||||
result.flat_values.set_shape(
|
result.row_splits.set_shape([outer_dim + 1])
|
||||||
tensor_shape.TensorShape([None]).concatenate(
|
result.flat_values.set_shape(
|
||||||
self._shape[1 + self._ragged_rank:]))
|
tensor_shape.TensorShape([None]).concatenate(
|
||||||
|
self._shape[1 + self._ragged_rank:]))
|
||||||
|
else:
|
||||||
|
result.set_shape(self._shape)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _batch(self, batch_size):
|
def _batch(self, batch_size):
|
||||||
|
@ -23,6 +23,7 @@ import re
|
|||||||
from absl.testing import parameterized
|
from absl.testing import parameterized
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
@ -1700,6 +1701,27 @@ class RaggedTensorTest(test_util.TensorFlowTestCase,
|
|||||||
[[6], [7]], [[8], [9]]])
|
[[6], [7]], [[8], [9]]])
|
||||||
self.assertAllEqual(decoded_rt, expected_rt)
|
self.assertAllEqual(decoded_rt, expected_rt)
|
||||||
|
|
||||||
|
def testUnbatchVariant(self): # b/141789000
|
||||||
|
rt = ragged_factory_ops.constant([[1, 2, 3], [4, 5], [], [6, 7, 8, 9]])
|
||||||
|
batched = rt._to_variant(batched_input=True)
|
||||||
|
for i in range(4):
|
||||||
|
row = RaggedTensor._from_variant(
|
||||||
|
batched[i], dtype=dtypes.int32, output_ragged_rank=0)
|
||||||
|
self.assertAllEqual(rt[i], row)
|
||||||
|
|
||||||
|
def testUnbatchVariantInDataset(self):
|
||||||
|
rt = ragged_factory_ops.constant([[1, 2, 3], [4, 5], [], [6, 7, 8, 9]])
|
||||||
|
ds = dataset_ops.Dataset.from_tensor_slices(rt)
|
||||||
|
if context.executing_eagerly():
|
||||||
|
for i, value in enumerate(ds):
|
||||||
|
self.assertAllEqual(rt[i], value)
|
||||||
|
else:
|
||||||
|
it = dataset_ops.make_one_shot_iterator(ds)
|
||||||
|
out = it.get_next()
|
||||||
|
with self.cached_session() as sess:
|
||||||
|
for i in range(3):
|
||||||
|
self.assertAllEqual(sess.run(rt[i]), out)
|
||||||
|
|
||||||
def testFromVariantInvalidParams(self):
|
def testFromVariantInvalidParams(self):
|
||||||
rt = ragged_factory_ops.constant([[0], [1], [2], [3]])
|
rt = ragged_factory_ops.constant([[0], [1], [2], [3]])
|
||||||
batched_variant = rt._to_variant(batched_input=True)
|
batched_variant = rt._to_variant(batched_input=True)
|
||||||
@ -1847,12 +1869,19 @@ class RaggedTensorSpecTest(test_util.TensorFlowTestCase,
|
|||||||
self.assertEqual(rt_spec._flat_tensor_specs,
|
self.assertEqual(rt_spec._flat_tensor_specs,
|
||||||
[tensor_spec.TensorSpec(None, dtypes.variant)])
|
[tensor_spec.TensorSpec(None, dtypes.variant)])
|
||||||
|
|
||||||
@parameterized.parameters([
|
@parameterized.named_parameters([
|
||||||
{
|
{
|
||||||
|
'testcase_name': 'RaggedRank0',
|
||||||
|
'rt_spec': RaggedTensorSpec(ragged_rank=0),
|
||||||
|
'rt': [1.0, 2.0, 3.0],
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'testcase_name': 'RaggedRank1',
|
||||||
'rt_spec': RaggedTensorSpec(ragged_rank=1),
|
'rt_spec': RaggedTensorSpec(ragged_rank=1),
|
||||||
'rt': [[1.0, 2.0], [3.0]]
|
'rt': [[1.0, 2.0], [3.0]]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
'testcase_name': 'RaggedRank2',
|
||||||
'rt_spec': RaggedTensorSpec(shape=[2, None, None]),
|
'rt_spec': RaggedTensorSpec(shape=[2, None, None]),
|
||||||
'rt': [[[1.0, 2.0], [3.0]], [[], [4.0]]]
|
'rt': [[[1.0, 2.0], [3.0]], [[], [4.0]]]
|
||||||
},
|
},
|
||||||
@ -1863,6 +1892,28 @@ class RaggedTensorSpecTest(test_util.TensorFlowTestCase,
|
|||||||
rt_reconstructed = rt_spec._from_tensor_list(tensor_list)
|
rt_reconstructed = rt_spec._from_tensor_list(tensor_list)
|
||||||
self.assertAllEqual(rt, rt_reconstructed)
|
self.assertAllEqual(rt, rt_reconstructed)
|
||||||
|
|
||||||
|
@parameterized.named_parameters([
|
||||||
|
# TODO(b/141789000) Test ragged_rank=0 when support is added.
|
||||||
|
{
|
||||||
|
'testcase_name': 'RaggedRank1',
|
||||||
|
'rt_spec': RaggedTensorSpec(ragged_rank=1),
|
||||||
|
'rt': [[1.0, 2.0], [3.0]]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
'testcase_name': 'RaggedRank2',
|
||||||
|
'rt_spec': RaggedTensorSpec(shape=[2, None, None]),
|
||||||
|
'rt': [[[1.0, 2.0], [3.0]], [[], [4.0]]]
|
||||||
|
},
|
||||||
|
])
|
||||||
|
def testToFromBatchedTensorList(self, rt_spec, rt):
|
||||||
|
rt = ragged_factory_ops.constant(rt)
|
||||||
|
tensor_list = rt_spec._to_batched_tensor_list(rt)
|
||||||
|
rt_reconstructed = rt_spec._from_tensor_list(tensor_list)
|
||||||
|
self.assertAllEqual(rt, rt_reconstructed)
|
||||||
|
first_row = rt_spec._unbatch()._from_tensor_list(
|
||||||
|
[t[0] for t in tensor_list])
|
||||||
|
self.assertAllEqual(rt[0], first_row)
|
||||||
|
|
||||||
@parameterized.parameters([
|
@parameterized.parameters([
|
||||||
(RaggedTensorSpec([2, None], dtypes.float32, 1), 32,
|
(RaggedTensorSpec([2, None], dtypes.float32, 1), 32,
|
||||||
RaggedTensorSpec([32, 2, None], dtypes.float32, 2)),
|
RaggedTensorSpec([32, 2, None], dtypes.float32, 2)),
|
||||||
|
@ -2882,7 +2882,7 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "RaggedTensorFromVariant"
|
name: "RaggedTensorFromVariant"
|
||||||
argspec: "args=[\'encoded_ragged\', \'input_ragged_rank\', \'output_ragged_rank\', \'Tvalues\', \'Tsplits\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'encoded_ragged\', \'input_ragged_rank\', \'output_ragged_rank\', \'Tvalues\', \'Tsplits\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int64\'>\", \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "RaggedTensorToSparse"
|
name: "RaggedTensorToSparse"
|
||||||
|
@ -2882,7 +2882,7 @@ tf_module {
|
|||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "RaggedTensorFromVariant"
|
name: "RaggedTensorFromVariant"
|
||||||
argspec: "args=[\'encoded_ragged\', \'input_ragged_rank\', \'output_ragged_rank\', \'Tvalues\', \'Tsplits\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
argspec: "args=[\'encoded_ragged\', \'input_ragged_rank\', \'output_ragged_rank\', \'Tvalues\', \'Tsplits\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int64\'>\", \'None\'], "
|
||||||
}
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "RaggedTensorToSparse"
|
name: "RaggedTensorToSparse"
|
||||||
|
Loading…
Reference in New Issue
Block a user