Fix bug in tf.strings.ngrams when data
is a ragged tensor with no values. (The row splits vector was not getting zeroed out.)
PiperOrigin-RevId: 291180958 Change-Id: Id983ca3ccb99e82d7e487a16989fe31b638dbbf3
This commit is contained in:
parent
88fb36a9c3
commit
e6983d538f
@ -60,22 +60,23 @@ class StringNGramsOp : public tensorflow::OpKernel {
|
||||
OP_REQUIRES_OK(context, context->input("data_splits", &splits));
|
||||
const auto& splits_vec = splits->flat<SPLITS_TYPE>();
|
||||
|
||||
// If there is no data or size, return an empty RT.
|
||||
if (data->flat<tstring>().size() == 0 || splits_vec.size() == 0) {
|
||||
tensorflow::Tensor* empty;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output(0, data->shape(), &empty));
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output(1, splits->shape(), &empty));
|
||||
return;
|
||||
}
|
||||
|
||||
int num_batch_items = splits_vec.size() - 1;
|
||||
tensorflow::Tensor* ngrams_splits;
|
||||
OP_REQUIRES_OK(
|
||||
context, context->allocate_output(1, splits->shape(), &ngrams_splits));
|
||||
auto ngrams_splits_data = ngrams_splits->flat<SPLITS_TYPE>().data();
|
||||
|
||||
// If there is no data or size, return an empty RT.
|
||||
if (data->flat<tstring>().size() == 0 || splits_vec.size() == 0) {
|
||||
tensorflow::Tensor* empty;
|
||||
OP_REQUIRES_OK(context,
|
||||
context->allocate_output(0, data->shape(), &empty));
|
||||
for (int i = 0; i <= num_batch_items; ++i) {
|
||||
ngrams_splits_data[i] = 0;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
ngrams_splits_data[0] = 0;
|
||||
for (int i = 1; i <= num_batch_items; ++i) {
|
||||
int length = splits_vec(i) - splits_vec(i - 1);
|
||||
|
@ -19,6 +19,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
@ -273,6 +274,14 @@ class StringNgramsTest(test_util.TensorFlowTestCase):
|
||||
], [b"e", b"f", b"g", b"h", b"e|f", b"f|g", b"g|h", b"e|f|g", b"f|g|h"]]
|
||||
self.assertAllEqual(expected_ngrams, result)
|
||||
|
||||
def test_input_with_no_values(self):
|
||||
data = ragged_factory_ops.constant([[], [], []], dtype=dtypes.string)
|
||||
ngram_op = ragged_string_ops.ngrams(data, (1, 2))
|
||||
result = self.evaluate(ngram_op)
|
||||
self.assertAllEqual([0, 0, 0, 0], result.row_splits)
|
||||
self.assertAllEqual(constant_op.constant([], dtype=dtypes.string),
|
||||
result.values)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user