Validate data_splits
for tf.StringNGrams
.
Without validation, we can cause a heap buffer overflow which results in data leakage and/or segfaults. PiperOrigin-RevId: 332543478 Change-Id: Iee5bda24497a195d09d122355502480830b1b317
This commit is contained in:
parent
57f0a5e0b6
commit
0462de5b54
@ -19,6 +19,7 @@ limitations under the License.
|
|||||||
#include "absl/strings/ascii.h"
|
#include "absl/strings/ascii.h"
|
||||||
#include "absl/strings/str_cat.h"
|
#include "absl/strings/str_cat.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/platform/errors.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace text {
|
namespace text {
|
||||||
@ -60,6 +61,18 @@ class StringNGramsOp : public tensorflow::OpKernel {
|
|||||||
OP_REQUIRES_OK(context, context->input("data_splits", &splits));
|
OP_REQUIRES_OK(context, context->input("data_splits", &splits));
|
||||||
const auto& splits_vec = splits->flat<SPLITS_TYPE>();
|
const auto& splits_vec = splits->flat<SPLITS_TYPE>();
|
||||||
|
|
||||||
|
// Validate that the splits are valid indices into data
|
||||||
|
const int input_data_size = data->flat<tstring>().size();
|
||||||
|
const int splits_vec_size = splits_vec.size();
|
||||||
|
for (int i = 0; i < splits_vec_size; ++i) {
|
||||||
|
bool valid_splits = splits_vec(i) >= 0;
|
||||||
|
valid_splits = valid_splits && (splits_vec(i) <= input_data_size);
|
||||||
|
OP_REQUIRES(
|
||||||
|
context, valid_splits,
|
||||||
|
errors::InvalidArgument("Invalid split value ", splits_vec(i),
|
||||||
|
", must be in [0,", input_data_size, "]"));
|
||||||
|
}
|
||||||
|
|
||||||
int num_batch_items = splits_vec.size() - 1;
|
int num_batch_items = splits_vec.size() - 1;
|
||||||
tensorflow::Tensor* ngrams_splits;
|
tensorflow::Tensor* ngrams_splits;
|
||||||
OP_REQUIRES_OK(
|
OP_REQUIRES_OK(
|
||||||
|
@ -18,16 +18,21 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
|
from absl.testing import parameterized
|
||||||
|
|
||||||
from tensorflow.python.eager import context
|
from tensorflow.python.eager import context
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import gen_math_ops
|
from tensorflow.python.ops import gen_math_ops
|
||||||
|
from tensorflow.python.ops import gen_string_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
@test_util.run_all_in_graph_and_eager_modes
|
@test_util.run_all_in_graph_and_eager_modes
|
||||||
class RawOpsTest(test.TestCase):
|
@test_util.disable_tfrt
|
||||||
|
class RawOpsTest(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
def testSimple(self):
|
def testSimple(self):
|
||||||
x = constant_op.constant(1)
|
x = constant_op.constant(1)
|
||||||
@ -58,6 +63,22 @@ class RawOpsTest(test.TestCase):
|
|||||||
gen_math_ops.Any(input=x, axis=0),
|
gen_math_ops.Any(input=x, axis=0),
|
||||||
gen_math_ops.Any(input=x, axis=0, keep_dims=False))
|
gen_math_ops.Any(input=x, axis=0, keep_dims=False))
|
||||||
|
|
||||||
|
@parameterized.parameters([[0, 8]], [[-1, 6]])
|
||||||
|
def testStringNGramsBadDataSplits(self, splits):
|
||||||
|
data = ["aa", "bb", "cc", "dd", "ee", "ff"]
|
||||||
|
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||||
|
"Invalid split value"):
|
||||||
|
self.evaluate(
|
||||||
|
gen_string_ops.string_n_grams(
|
||||||
|
data=data,
|
||||||
|
data_splits=splits,
|
||||||
|
separator="",
|
||||||
|
ngram_widths=[2],
|
||||||
|
left_pad="",
|
||||||
|
right_pad="",
|
||||||
|
pad_width=0,
|
||||||
|
preserve_short_sequences=False))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
ops.enable_eager_execution()
|
ops.enable_eager_execution()
|
||||||
|
Loading…
Reference in New Issue
Block a user