Fix bug in tf.strings.ngrams when data is a ragged tensor with no values. (The row splits vector was not getting zeroed out.)

PiperOrigin-RevId: 291180958
Change-Id: Id983ca3ccb99e82d7e487a16989fe31b638dbbf3
This commit is contained in:
Edward Loper 2020-01-23 09:20:29 -08:00 committed by TensorFlower Gardener
parent 88fb36a9c3
commit e6983d538f
2 changed files with 20 additions and 10 deletions

View File

@ -60,22 +60,23 @@ class StringNGramsOp : public tensorflow::OpKernel {
OP_REQUIRES_OK(context, context->input("data_splits", &splits));
const auto& splits_vec = splits->flat<SPLITS_TYPE>();
// If there is no data or size, return an empty RT.
if (data->flat<tstring>().size() == 0 || splits_vec.size() == 0) {
tensorflow::Tensor* empty;
OP_REQUIRES_OK(context,
context->allocate_output(0, data->shape(), &empty));
OP_REQUIRES_OK(context,
context->allocate_output(1, splits->shape(), &empty));
return;
}
int num_batch_items = splits_vec.size() - 1;
tensorflow::Tensor* ngrams_splits;
OP_REQUIRES_OK(
context, context->allocate_output(1, splits->shape(), &ngrams_splits));
auto ngrams_splits_data = ngrams_splits->flat<SPLITS_TYPE>().data();
// If there are no data values or no row splits, return an empty ragged
// tensor whose row splits are all zero.
if (data->flat<tstring>().size() == 0 || splits_vec.size() == 0) {
tensorflow::Tensor* empty;
OP_REQUIRES_OK(context,
context->allocate_output(0, data->shape(), &empty));
for (int i = 0; i <= num_batch_items; ++i) {
ngrams_splits_data[i] = 0;
}
return;
}
ngrams_splits_data[0] = 0;
for (int i = 1; i <= num_batch_items; ++i) {
int length = splits_vec(i) - splits_vec(i - 1);

View File

@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops.ragged import ragged_factory_ops
@ -273,6 +274,14 @@ class StringNgramsTest(test_util.TensorFlowTestCase):
], [b"e", b"f", b"g", b"h", b"e|f", b"f|g", b"g|h", b"e|f|g", b"f|g|h"]]
self.assertAllEqual(expected_ngrams, result)
def test_input_with_no_values(self):
  """Checks ngrams on a ragged tensor whose three rows are all empty."""
  # Three rows, none of which holds any strings.
  empty_rows = ragged_factory_ops.constant([[], [], []], dtype=dtypes.string)
  ngrams = self.evaluate(ragged_string_ops.ngrams(empty_rows, (1, 2)))

  # With no values anywhere, every row boundary must be zero.
  self.assertAllEqual([0, 0, 0, 0], ngrams.row_splits)

  # The flat values tensor must be an empty string tensor.
  expected_values = constant_op.constant([], dtype=dtypes.string)
  self.assertAllEqual(expected_values, ngrams.values)
if __name__ == "__main__":
test.main()