Validate `data_splits` for `tf.StringNGrams`.
Without validation, out-of-range values in `data_splits` can cause a heap buffer overflow, which results in data leakage and/or segfaults.

PiperOrigin-RevId: 332543478
Change-Id: Iee5bda24497a195d09d122355502480830b1b317
This commit is contained in:
parent
57f0a5e0b6
commit
0462de5b54
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include "absl/strings/ascii.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace text {
|
||||
@ -60,6 +61,18 @@ class StringNGramsOp : public tensorflow::OpKernel {
|
||||
OP_REQUIRES_OK(context, context->input("data_splits", &splits));
|
||||
const auto& splits_vec = splits->flat<SPLITS_TYPE>();
|
||||
|
||||
// Validate that the splits are valid indices into data
|
||||
const int input_data_size = data->flat<tstring>().size();
|
||||
const int splits_vec_size = splits_vec.size();
|
||||
for (int i = 0; i < splits_vec_size; ++i) {
|
||||
bool valid_splits = splits_vec(i) >= 0;
|
||||
valid_splits = valid_splits && (splits_vec(i) <= input_data_size);
|
||||
OP_REQUIRES(
|
||||
context, valid_splits,
|
||||
errors::InvalidArgument("Invalid split value ", splits_vec(i),
|
||||
", must be in [0,", input_data_size, "]"));
|
||||
}
|
||||
|
||||
int num_batch_items = splits_vec.size() - 1;
|
||||
tensorflow::Tensor* ngrams_splits;
|
||||
OP_REQUIRES_OK(
|
||||
|
@ -18,16 +18,21 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import gen_math_ops
|
||||
from tensorflow.python.ops import gen_string_ops
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class RawOpsTest(test.TestCase):
|
||||
@test_util.disable_tfrt
|
||||
class RawOpsTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def testSimple(self):
|
||||
x = constant_op.constant(1)
|
||||
@ -58,6 +63,22 @@ class RawOpsTest(test.TestCase):
|
||||
gen_math_ops.Any(input=x, axis=0),
|
||||
gen_math_ops.Any(input=x, axis=0, keep_dims=False))
|
||||
|
||||
@parameterized.parameters([[0, 8]], [[-1, 6]])
|
||||
def testStringNGramsBadDataSplits(self, splits):
|
||||
data = ["aa", "bb", "cc", "dd", "ee", "ff"]
|
||||
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||
"Invalid split value"):
|
||||
self.evaluate(
|
||||
gen_string_ops.string_n_grams(
|
||||
data=data,
|
||||
data_splits=splits,
|
||||
separator="",
|
||||
ngram_widths=[2],
|
||||
left_pad="",
|
||||
right_pad="",
|
||||
pad_width=0,
|
||||
preserve_short_sequences=False))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ops.enable_eager_execution()
|
||||
|
Loading…
Reference in New Issue
Block a user