diff --git a/tensorflow/core/kernels/string_ngrams_op.cc b/tensorflow/core/kernels/string_ngrams_op.cc index dc757a01fcf..97b32c4242c 100644 --- a/tensorflow/core/kernels/string_ngrams_op.cc +++ b/tensorflow/core/kernels/string_ngrams_op.cc @@ -60,22 +60,23 @@ class StringNGramsOp : public tensorflow::OpKernel { OP_REQUIRES_OK(context, context->input("data_splits", &splits)); const auto& splits_vec = splits->flat<SPLITS_TYPE>(); - // If there is no data or size, return an empty RT. - if (data->flat<tstring>().size() == 0 || splits_vec.size() == 0) { - tensorflow::Tensor* empty; - OP_REQUIRES_OK(context, - context->allocate_output(0, data->shape(), &empty)); - OP_REQUIRES_OK(context, - context->allocate_output(1, splits->shape(), &empty)); - return; - } - int num_batch_items = splits_vec.size() - 1; tensorflow::Tensor* ngrams_splits; OP_REQUIRES_OK( context, context->allocate_output(1, splits->shape(), &ngrams_splits)); auto ngrams_splits_data = ngrams_splits->flat<SPLITS_TYPE>().data(); + // If there is no data or size, return an empty RT. + if (data->flat<tstring>().size() == 0 || splits_vec.size() == 0) { + tensorflow::Tensor* empty; + OP_REQUIRES_OK(context, + context->allocate_output(0, data->shape(), &empty)); + for (int i = 0; i <= num_batch_items; ++i) { + ngrams_splits_data[i] = 0; + } + return; + } + ngrams_splits_data[0] = 0; for (int i = 1; i <= num_batch_items; ++i) { int length = splits_vec(i) - splits_vec(i - 1); diff --git a/tensorflow/python/ops/ragged/string_ngrams_op_test.py b/tensorflow/python/ops/ragged/string_ngrams_op_test.py index 464eb3bb7f5..6b3b3777cb5 100644 --- a/tensorflow/python/ops/ragged/string_ngrams_op_test.py +++ b/tensorflow/python/ops/ragged/string_ngrams_op_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops.ragged import ragged_factory_ops @@ -273,6 +274,14 @@ class StringNgramsTest(test_util.TensorFlowTestCase): ], [b"e", b"f", b"g", b"h", b"e|f", b"f|g", b"g|h", b"e|f|g", b"f|g|h"]] self.assertAllEqual(expected_ngrams, result) + def test_input_with_no_values(self): + data = ragged_factory_ops.constant([[], [], []], dtype=dtypes.string) + ngram_op = ragged_string_ops.ngrams(data, (1, 2)) + result = self.evaluate(ngram_op) + self.assertAllEqual([0, 0, 0, 0], result.row_splits) + self.assertAllEqual(constant_op.constant([], dtype=dtypes.string), + result.values) + if __name__ == "__main__": test.main()