Fix multiple vulnerabilities in tf.raw_ops.*CountSparseOutput.

Also add tests for these API points, both for the happy paths and for the vulnerable ones.

PiperOrigin-RevId: 332563222
Change-Id: Ib3b52116a83a134c2e742a7c66e5e956db8fba05
This commit is contained in:
Mihai Maruseac 2020-09-18 18:15:03 -07:00 committed by TensorFlower Gardener
parent 4eab87c67e
commit 3cbb917b47
2 changed files with 159 additions and 0 deletions
tensorflow
core/kernels
python/ops

View File

@ -178,10 +178,30 @@ class SparseCount : public OpKernel {
const Tensor& weights = context->input(3);
bool use_weights = weights.NumElements() > 0;
OP_REQUIRES(context, TensorShapeUtils::IsMatrix(indices.shape()),
errors::InvalidArgument(
"Input indices must be a 2-dimensional tensor. Got: ",
indices.shape().DebugString()));
if (use_weights) {
OP_REQUIRES(
context, weights.shape() == values.shape(),
errors::InvalidArgument(
"Weights and values must have the same shape. Weight shape: ",
weights.shape().DebugString(),
"; values shape: ", values.shape().DebugString()));
}
bool is_1d = shape.NumElements() == 1;
int num_batches = is_1d ? 1 : shape.flat<int64>()(0);
int num_values = values.NumElements();
OP_REQUIRES(context, num_values == indices.shape().dim_size(0),
errors::InvalidArgument(
"Number of values must match first dimension of indices.",
"Got ", num_values,
" values, indices shape: ", indices.shape().DebugString()));
const auto indices_values = indices.matrix<int64>();
const auto values_values = values.flat<T>();
const auto weight_values = weights.flat<W>();
@ -235,12 +255,33 @@ class RaggedCount : public OpKernel {
bool use_weights = weights.NumElements() > 0;
bool is_1d = false;
if (use_weights) {
OP_REQUIRES(
context, weights.shape() == values.shape(),
errors::InvalidArgument(
"Weights and values must have the same shape. Weight shape: ",
weights.shape().DebugString(),
"; values shape: ", values.shape().DebugString()));
}
const auto splits_values = splits.flat<int64>();
const auto values_values = values.flat<T>();
const auto weight_values = weights.flat<W>();
int num_batches = splits.NumElements() - 1;
int num_values = values.NumElements();
OP_REQUIRES(
context, num_batches > 0,
errors::InvalidArgument(
"Must provide at least 2 elements for the splits argument"));
OP_REQUIRES(context, splits_values(0) == 0,
errors::InvalidArgument("Splits must start with 0, not with ",
splits_values(0)));
OP_REQUIRES(context, splits_values(num_batches) == num_values,
errors::InvalidArgument(
"Splits must end with the number of values, got ",
splits_values(num_batches), " instead of ", num_values));
auto per_batch_counts = BatchedMap<W>(num_batches);
T max_value = 0;
int batch_idx = 0;

View File

@ -25,7 +25,9 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import bincount_ops
from tensorflow.python.ops import gen_count_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops.ragged import ragged_factory_ops
from tensorflow.python.ops.ragged import ragged_tensor
@ -834,5 +836,121 @@ class TestSparseCountFailureModes(test.TestCase):
self.evaluate(bincount_ops.sparse_bincount(x, weights=weights, axis=-1))
@test_util.run_all_in_graph_and_eager_modes
@test_util.disable_tfrt
class RawOpsTest(test.TestCase, parameterized.TestCase):
def testSparseCountSparseOutputBadIndicesShape(self):
indices = [[[0], [0]], [[0], [1]], [[1], [0]], [[1], [2]]]
values = [1, 1, 1, 10]
weights = [1, 2, 4, 6]
dense_shape = [2, 3]
with self.assertRaisesRegex(errors.InvalidArgumentError,
"Input indices must be a 2-dimensional tensor"):
self.evaluate(
gen_count_ops.SparseCountSparseOutput(
indices=indices,
values=values,
dense_shape=dense_shape,
weights=weights,
binary_output=False))
def testSparseCountSparseOutputBadWeightsShape(self):
indices = [[0, 0], [0, 1], [1, 0], [1, 2]]
values = [1, 1, 1, 10]
weights = [1, 2, 4]
dense_shape = [2, 3]
with self.assertRaisesRegex(errors.InvalidArgumentError,
"Weights and values must have the same shape"):
self.evaluate(
gen_count_ops.SparseCountSparseOutput(
indices=indices,
values=values,
dense_shape=dense_shape,
weights=weights,
binary_output=False))
def testSparseCountSparseOutputBadNumberOfValues(self):
indices = [[0, 0], [0, 1], [1, 0]]
values = [1, 1, 1, 10]
weights = [1, 2, 4, 6]
dense_shape = [2, 3]
with self.assertRaisesRegex(
errors.InvalidArgumentError,
"Number of values must match first dimension of indices"):
self.evaluate(
gen_count_ops.SparseCountSparseOutput(
indices=indices,
values=values,
dense_shape=dense_shape,
weights=weights,
binary_output=False))
def testRaggedCountSparseOutput(self):
splits = [0, 4, 7]
values = [1, 1, 2, 1, 2, 10, 5]
weights = [1, 2, 3, 4, 5, 6, 7]
output_indices, output_values, output_shape = self.evaluate(
gen_count_ops.RaggedCountSparseOutput(
splits=splits, values=values, weights=weights, binary_output=False))
self.assertAllEqual([[0, 1], [0, 2], [1, 2], [1, 5], [1, 10]],
output_indices)
self.assertAllEqual([7, 3, 5, 7, 6], output_values)
self.assertAllEqual([2, 11], output_shape)
def testRaggedCountSparseOutputBadWeightsShape(self):
splits = [0, 4, 7]
values = [1, 1, 2, 1, 2, 10, 5]
weights = [1, 2, 3, 4, 5, 6]
with self.assertRaisesRegex(errors.InvalidArgumentError,
"Weights and values must have the same shape"):
self.evaluate(
gen_count_ops.RaggedCountSparseOutput(
splits=splits,
values=values,
weights=weights,
binary_output=False))
def testRaggedCountSparseOutputEmptySplits(self):
splits = []
values = [1, 1, 2, 1, 2, 10, 5]
weights = [1, 2, 3, 4, 5, 6, 7]
with self.assertRaisesRegex(
errors.InvalidArgumentError,
"Must provide at least 2 elements for the splits argument"):
self.evaluate(
gen_count_ops.RaggedCountSparseOutput(
splits=splits,
values=values,
weights=weights,
binary_output=False))
def testRaggedCountSparseOutputBadSplitsStart(self):
splits = [1, 7]
values = [1, 1, 2, 1, 2, 10, 5]
weights = [1, 2, 3, 4, 5, 6, 7]
with self.assertRaisesRegex(errors.InvalidArgumentError,
"Splits must start with 0"):
self.evaluate(
gen_count_ops.RaggedCountSparseOutput(
splits=splits,
values=values,
weights=weights,
binary_output=False))
def testRaggedCountSparseOutputBadSplitsEnd(self):
splits = [0, 5]
values = [1, 1, 2, 1, 2, 10, 5]
weights = [1, 2, 3, 4, 5, 6, 7]
with self.assertRaisesRegex(errors.InvalidArgumentError,
"Splits must end with the number of values"):
self.evaluate(
gen_count_ops.RaggedCountSparseOutput(
splits=splits,
values=values,
weights=weights,
binary_output=False))
if __name__ == "__main__":
test.main()