Fix multiple vulnerabilities in tf.raw_ops.*CountSparseOutput
.
Also add tests for these API points, both for the happy paths and for the vulnerable ones. PiperOrigin-RevId: 332563222 Change-Id: Ib3b52116a83a134c2e742a7c66e5e956db8fba05
This commit is contained in:
parent
4eab87c67e
commit
3cbb917b47
tensorflow
@ -178,10 +178,30 @@ class SparseCount : public OpKernel {
|
||||
const Tensor& weights = context->input(3);
|
||||
bool use_weights = weights.NumElements() > 0;
|
||||
|
||||
OP_REQUIRES(context, TensorShapeUtils::IsMatrix(indices.shape()),
|
||||
errors::InvalidArgument(
|
||||
"Input indices must be a 2-dimensional tensor. Got: ",
|
||||
indices.shape().DebugString()));
|
||||
|
||||
if (use_weights) {
|
||||
OP_REQUIRES(
|
||||
context, weights.shape() == values.shape(),
|
||||
errors::InvalidArgument(
|
||||
"Weights and values must have the same shape. Weight shape: ",
|
||||
weights.shape().DebugString(),
|
||||
"; values shape: ", values.shape().DebugString()));
|
||||
}
|
||||
|
||||
bool is_1d = shape.NumElements() == 1;
|
||||
int num_batches = is_1d ? 1 : shape.flat<int64>()(0);
|
||||
int num_values = values.NumElements();
|
||||
|
||||
OP_REQUIRES(context, num_values == indices.shape().dim_size(0),
|
||||
errors::InvalidArgument(
|
||||
"Number of values must match first dimension of indices.",
|
||||
"Got ", num_values,
|
||||
" values, indices shape: ", indices.shape().DebugString()));
|
||||
|
||||
const auto indices_values = indices.matrix<int64>();
|
||||
const auto values_values = values.flat<T>();
|
||||
const auto weight_values = weights.flat<W>();
|
||||
@ -235,12 +255,33 @@ class RaggedCount : public OpKernel {
|
||||
bool use_weights = weights.NumElements() > 0;
|
||||
bool is_1d = false;
|
||||
|
||||
if (use_weights) {
|
||||
OP_REQUIRES(
|
||||
context, weights.shape() == values.shape(),
|
||||
errors::InvalidArgument(
|
||||
"Weights and values must have the same shape. Weight shape: ",
|
||||
weights.shape().DebugString(),
|
||||
"; values shape: ", values.shape().DebugString()));
|
||||
}
|
||||
|
||||
const auto splits_values = splits.flat<int64>();
|
||||
const auto values_values = values.flat<T>();
|
||||
const auto weight_values = weights.flat<W>();
|
||||
int num_batches = splits.NumElements() - 1;
|
||||
int num_values = values.NumElements();
|
||||
|
||||
OP_REQUIRES(
|
||||
context, num_batches > 0,
|
||||
errors::InvalidArgument(
|
||||
"Must provide at least 2 elements for the splits argument"));
|
||||
OP_REQUIRES(context, splits_values(0) == 0,
|
||||
errors::InvalidArgument("Splits must start with 0, not with ",
|
||||
splits_values(0)));
|
||||
OP_REQUIRES(context, splits_values(num_batches) == num_values,
|
||||
errors::InvalidArgument(
|
||||
"Splits must end with the number of values, got ",
|
||||
splits_values(num_batches), " instead of ", num_values));
|
||||
|
||||
auto per_batch_counts = BatchedMap<W>(num_batches);
|
||||
T max_value = 0;
|
||||
int batch_idx = 0;
|
||||
|
@ -25,7 +25,9 @@ from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import bincount_ops
|
||||
from tensorflow.python.ops import gen_count_ops
|
||||
from tensorflow.python.ops import sparse_ops
|
||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
@ -834,5 +836,121 @@ class TestSparseCountFailureModes(test.TestCase):
|
||||
self.evaluate(bincount_ops.sparse_bincount(x, weights=weights, axis=-1))
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
@test_util.disable_tfrt
|
||||
class RawOpsTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
def testSparseCountSparseOutputBadIndicesShape(self):
|
||||
indices = [[[0], [0]], [[0], [1]], [[1], [0]], [[1], [2]]]
|
||||
values = [1, 1, 1, 10]
|
||||
weights = [1, 2, 4, 6]
|
||||
dense_shape = [2, 3]
|
||||
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||
"Input indices must be a 2-dimensional tensor"):
|
||||
self.evaluate(
|
||||
gen_count_ops.SparseCountSparseOutput(
|
||||
indices=indices,
|
||||
values=values,
|
||||
dense_shape=dense_shape,
|
||||
weights=weights,
|
||||
binary_output=False))
|
||||
|
||||
def testSparseCountSparseOutputBadWeightsShape(self):
|
||||
indices = [[0, 0], [0, 1], [1, 0], [1, 2]]
|
||||
values = [1, 1, 1, 10]
|
||||
weights = [1, 2, 4]
|
||||
dense_shape = [2, 3]
|
||||
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||
"Weights and values must have the same shape"):
|
||||
self.evaluate(
|
||||
gen_count_ops.SparseCountSparseOutput(
|
||||
indices=indices,
|
||||
values=values,
|
||||
dense_shape=dense_shape,
|
||||
weights=weights,
|
||||
binary_output=False))
|
||||
|
||||
def testSparseCountSparseOutputBadNumberOfValues(self):
|
||||
indices = [[0, 0], [0, 1], [1, 0]]
|
||||
values = [1, 1, 1, 10]
|
||||
weights = [1, 2, 4, 6]
|
||||
dense_shape = [2, 3]
|
||||
with self.assertRaisesRegex(
|
||||
errors.InvalidArgumentError,
|
||||
"Number of values must match first dimension of indices"):
|
||||
self.evaluate(
|
||||
gen_count_ops.SparseCountSparseOutput(
|
||||
indices=indices,
|
||||
values=values,
|
||||
dense_shape=dense_shape,
|
||||
weights=weights,
|
||||
binary_output=False))
|
||||
|
||||
def testRaggedCountSparseOutput(self):
|
||||
splits = [0, 4, 7]
|
||||
values = [1, 1, 2, 1, 2, 10, 5]
|
||||
weights = [1, 2, 3, 4, 5, 6, 7]
|
||||
output_indices, output_values, output_shape = self.evaluate(
|
||||
gen_count_ops.RaggedCountSparseOutput(
|
||||
splits=splits, values=values, weights=weights, binary_output=False))
|
||||
self.assertAllEqual([[0, 1], [0, 2], [1, 2], [1, 5], [1, 10]],
|
||||
output_indices)
|
||||
self.assertAllEqual([7, 3, 5, 7, 6], output_values)
|
||||
self.assertAllEqual([2, 11], output_shape)
|
||||
|
||||
def testRaggedCountSparseOutputBadWeightsShape(self):
|
||||
splits = [0, 4, 7]
|
||||
values = [1, 1, 2, 1, 2, 10, 5]
|
||||
weights = [1, 2, 3, 4, 5, 6]
|
||||
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||
"Weights and values must have the same shape"):
|
||||
self.evaluate(
|
||||
gen_count_ops.RaggedCountSparseOutput(
|
||||
splits=splits,
|
||||
values=values,
|
||||
weights=weights,
|
||||
binary_output=False))
|
||||
|
||||
def testRaggedCountSparseOutputEmptySplits(self):
|
||||
splits = []
|
||||
values = [1, 1, 2, 1, 2, 10, 5]
|
||||
weights = [1, 2, 3, 4, 5, 6, 7]
|
||||
with self.assertRaisesRegex(
|
||||
errors.InvalidArgumentError,
|
||||
"Must provide at least 2 elements for the splits argument"):
|
||||
self.evaluate(
|
||||
gen_count_ops.RaggedCountSparseOutput(
|
||||
splits=splits,
|
||||
values=values,
|
||||
weights=weights,
|
||||
binary_output=False))
|
||||
|
||||
def testRaggedCountSparseOutputBadSplitsStart(self):
|
||||
splits = [1, 7]
|
||||
values = [1, 1, 2, 1, 2, 10, 5]
|
||||
weights = [1, 2, 3, 4, 5, 6, 7]
|
||||
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||
"Splits must start with 0"):
|
||||
self.evaluate(
|
||||
gen_count_ops.RaggedCountSparseOutput(
|
||||
splits=splits,
|
||||
values=values,
|
||||
weights=weights,
|
||||
binary_output=False))
|
||||
|
||||
def testRaggedCountSparseOutputBadSplitsEnd(self):
|
||||
splits = [0, 5]
|
||||
values = [1, 1, 2, 1, 2, 10, 5]
|
||||
weights = [1, 2, 3, 4, 5, 6, 7]
|
||||
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||
"Splits must end with the number of values"):
|
||||
self.evaluate(
|
||||
gen_count_ops.RaggedCountSparseOutput(
|
||||
splits=splits,
|
||||
values=values,
|
||||
weights=weights,
|
||||
binary_output=False))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user