From 0462de5b544ed4731aa2fb23946ac22c01856b80 Mon Sep 17 00:00:00 2001 From: Mihai Maruseac Date: Fri, 18 Sep 2020 15:52:05 -0700 Subject: [PATCH] Validate `data_splits` for `tf.StringNGrams`. Without validation, we can cause a heap buffer overflow which results in data leakage and/or segfaults. PiperOrigin-RevId: 332543478 Change-Id: Iee5bda24497a195d09d122355502480830b1b317 --- tensorflow/core/kernels/string_ngrams_op.cc | 13 ++++++++++++ tensorflow/python/ops/raw_ops_test.py | 23 ++++++++++++++++++++- 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/kernels/string_ngrams_op.cc b/tensorflow/core/kernels/string_ngrams_op.cc index 97b32c4242c..8aed2b3831a 100644 --- a/tensorflow/core/kernels/string_ngrams_op.cc +++ b/tensorflow/core/kernels/string_ngrams_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "absl/strings/ascii.h" #include "absl/strings/str_cat.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/platform/errors.h" namespace tensorflow { namespace text { @@ -60,6 +61,18 @@ class StringNGramsOp : public tensorflow::OpKernel { OP_REQUIRES_OK(context, context->input("data_splits", &splits)); const auto& splits_vec = splits->flat(); + // Validate that the splits are valid indices into data + const int input_data_size = data->flat().size(); + const int splits_vec_size = splits_vec.size(); + for (int i = 0; i < splits_vec_size; ++i) { + bool valid_splits = splits_vec(i) >= 0; + valid_splits = valid_splits && (splits_vec(i) <= input_data_size); + OP_REQUIRES( + context, valid_splits, + errors::InvalidArgument("Invalid split value ", splits_vec(i), + ", must be in [0,", input_data_size, "]")); + } + int num_batch_items = splits_vec.size() - 1; tensorflow::Tensor* ngrams_splits; OP_REQUIRES_OK( diff --git a/tensorflow/python/ops/raw_ops_test.py b/tensorflow/python/ops/raw_ops_test.py index 850e96bb9ed..ee20d58d2f0 100644 --- a/tensorflow/python/ops/raw_ops_test.py +++ b/tensorflow/python/ops/raw_ops_test.py @@ -18,16 +18,21 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized + from tensorflow.python.eager import context from tensorflow.python.framework import constant_op +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import gen_math_ops +from tensorflow.python.ops import gen_string_ops from tensorflow.python.platform import test @test_util.run_all_in_graph_and_eager_modes -class RawOpsTest(test.TestCase): +@test_util.disable_tfrt +class RawOpsTest(test.TestCase, parameterized.TestCase): def testSimple(self): x = constant_op.constant(1) @@ -58,6 +63,22 @@ class RawOpsTest(test.TestCase): gen_math_ops.Any(input=x, axis=0), gen_math_ops.Any(input=x, axis=0, keep_dims=False)) + @parameterized.parameters([[0, 8]], [[-1, 6]]) + def testStringNGramsBadDataSplits(self, splits): + data = ["aa", "bb", "cc", "dd", "ee", "ff"] + with self.assertRaisesRegex(errors.InvalidArgumentError, + "Invalid split value"): + self.evaluate( + gen_string_ops.string_n_grams( + data=data, + data_splits=splits, + separator="", + ngram_widths=[2], + left_pad="", + right_pad="", + pad_width=0, + preserve_short_sequences=False)) + if __name__ == "__main__": ops.enable_eager_execution()