Fix multiple vulnerabilities in tf.raw_ops.*CountSparseOutput
.
Also add tests for these API points, both for the happy paths and for the vulnerable ones. PiperOrigin-RevId: 332563222 Change-Id: Ib3b52116a83a134c2e742a7c66e5e956db8fba05
This commit is contained in:
parent
4eab87c67e
commit
3cbb917b47
@ -178,10 +178,30 @@ class SparseCount : public OpKernel {
|
|||||||
const Tensor& weights = context->input(3);
|
const Tensor& weights = context->input(3);
|
||||||
bool use_weights = weights.NumElements() > 0;
|
bool use_weights = weights.NumElements() > 0;
|
||||||
|
|
||||||
|
OP_REQUIRES(context, TensorShapeUtils::IsMatrix(indices.shape()),
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"Input indices must be a 2-dimensional tensor. Got: ",
|
||||||
|
indices.shape().DebugString()));
|
||||||
|
|
||||||
|
if (use_weights) {
|
||||||
|
OP_REQUIRES(
|
||||||
|
context, weights.shape() == values.shape(),
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"Weights and values must have the same shape. Weight shape: ",
|
||||||
|
weights.shape().DebugString(),
|
||||||
|
"; values shape: ", values.shape().DebugString()));
|
||||||
|
}
|
||||||
|
|
||||||
bool is_1d = shape.NumElements() == 1;
|
bool is_1d = shape.NumElements() == 1;
|
||||||
int num_batches = is_1d ? 1 : shape.flat<int64>()(0);
|
int num_batches = is_1d ? 1 : shape.flat<int64>()(0);
|
||||||
int num_values = values.NumElements();
|
int num_values = values.NumElements();
|
||||||
|
|
||||||
|
OP_REQUIRES(context, num_values == indices.shape().dim_size(0),
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"Number of values must match first dimension of indices.",
|
||||||
|
"Got ", num_values,
|
||||||
|
" values, indices shape: ", indices.shape().DebugString()));
|
||||||
|
|
||||||
const auto indices_values = indices.matrix<int64>();
|
const auto indices_values = indices.matrix<int64>();
|
||||||
const auto values_values = values.flat<T>();
|
const auto values_values = values.flat<T>();
|
||||||
const auto weight_values = weights.flat<W>();
|
const auto weight_values = weights.flat<W>();
|
||||||
@ -235,12 +255,33 @@ class RaggedCount : public OpKernel {
|
|||||||
bool use_weights = weights.NumElements() > 0;
|
bool use_weights = weights.NumElements() > 0;
|
||||||
bool is_1d = false;
|
bool is_1d = false;
|
||||||
|
|
||||||
|
if (use_weights) {
|
||||||
|
OP_REQUIRES(
|
||||||
|
context, weights.shape() == values.shape(),
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"Weights and values must have the same shape. Weight shape: ",
|
||||||
|
weights.shape().DebugString(),
|
||||||
|
"; values shape: ", values.shape().DebugString()));
|
||||||
|
}
|
||||||
|
|
||||||
const auto splits_values = splits.flat<int64>();
|
const auto splits_values = splits.flat<int64>();
|
||||||
const auto values_values = values.flat<T>();
|
const auto values_values = values.flat<T>();
|
||||||
const auto weight_values = weights.flat<W>();
|
const auto weight_values = weights.flat<W>();
|
||||||
int num_batches = splits.NumElements() - 1;
|
int num_batches = splits.NumElements() - 1;
|
||||||
int num_values = values.NumElements();
|
int num_values = values.NumElements();
|
||||||
|
|
||||||
|
OP_REQUIRES(
|
||||||
|
context, num_batches > 0,
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"Must provide at least 2 elements for the splits argument"));
|
||||||
|
OP_REQUIRES(context, splits_values(0) == 0,
|
||||||
|
errors::InvalidArgument("Splits must start with 0, not with ",
|
||||||
|
splits_values(0)));
|
||||||
|
OP_REQUIRES(context, splits_values(num_batches) == num_values,
|
||||||
|
errors::InvalidArgument(
|
||||||
|
"Splits must end with the number of values, got ",
|
||||||
|
splits_values(num_batches), " instead of ", num_values));
|
||||||
|
|
||||||
auto per_batch_counts = BatchedMap<W>(num_batches);
|
auto per_batch_counts = BatchedMap<W>(num_batches);
|
||||||
T max_value = 0;
|
T max_value = 0;
|
||||||
int batch_idx = 0;
|
int batch_idx = 0;
|
||||||
|
@ -25,7 +25,9 @@ from tensorflow.python.eager import context
|
|||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.framework import sparse_tensor
|
from tensorflow.python.framework import sparse_tensor
|
||||||
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import bincount_ops
|
from tensorflow.python.ops import bincount_ops
|
||||||
|
from tensorflow.python.ops import gen_count_ops
|
||||||
from tensorflow.python.ops import sparse_ops
|
from tensorflow.python.ops import sparse_ops
|
||||||
from tensorflow.python.ops.ragged import ragged_factory_ops
|
from tensorflow.python.ops.ragged import ragged_factory_ops
|
||||||
from tensorflow.python.ops.ragged import ragged_tensor
|
from tensorflow.python.ops.ragged import ragged_tensor
|
||||||
@ -834,5 +836,121 @@ class TestSparseCountFailureModes(test.TestCase):
|
|||||||
self.evaluate(bincount_ops.sparse_bincount(x, weights=weights, axis=-1))
|
self.evaluate(bincount_ops.sparse_bincount(x, weights=weights, axis=-1))
|
||||||
|
|
||||||
|
|
||||||
|
@test_util.run_all_in_graph_and_eager_modes
|
||||||
|
@test_util.disable_tfrt
|
||||||
|
class RawOpsTest(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
|
def testSparseCountSparseOutputBadIndicesShape(self):
|
||||||
|
indices = [[[0], [0]], [[0], [1]], [[1], [0]], [[1], [2]]]
|
||||||
|
values = [1, 1, 1, 10]
|
||||||
|
weights = [1, 2, 4, 6]
|
||||||
|
dense_shape = [2, 3]
|
||||||
|
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||||
|
"Input indices must be a 2-dimensional tensor"):
|
||||||
|
self.evaluate(
|
||||||
|
gen_count_ops.SparseCountSparseOutput(
|
||||||
|
indices=indices,
|
||||||
|
values=values,
|
||||||
|
dense_shape=dense_shape,
|
||||||
|
weights=weights,
|
||||||
|
binary_output=False))
|
||||||
|
|
||||||
|
def testSparseCountSparseOutputBadWeightsShape(self):
|
||||||
|
indices = [[0, 0], [0, 1], [1, 0], [1, 2]]
|
||||||
|
values = [1, 1, 1, 10]
|
||||||
|
weights = [1, 2, 4]
|
||||||
|
dense_shape = [2, 3]
|
||||||
|
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||||
|
"Weights and values must have the same shape"):
|
||||||
|
self.evaluate(
|
||||||
|
gen_count_ops.SparseCountSparseOutput(
|
||||||
|
indices=indices,
|
||||||
|
values=values,
|
||||||
|
dense_shape=dense_shape,
|
||||||
|
weights=weights,
|
||||||
|
binary_output=False))
|
||||||
|
|
||||||
|
def testSparseCountSparseOutputBadNumberOfValues(self):
|
||||||
|
indices = [[0, 0], [0, 1], [1, 0]]
|
||||||
|
values = [1, 1, 1, 10]
|
||||||
|
weights = [1, 2, 4, 6]
|
||||||
|
dense_shape = [2, 3]
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
errors.InvalidArgumentError,
|
||||||
|
"Number of values must match first dimension of indices"):
|
||||||
|
self.evaluate(
|
||||||
|
gen_count_ops.SparseCountSparseOutput(
|
||||||
|
indices=indices,
|
||||||
|
values=values,
|
||||||
|
dense_shape=dense_shape,
|
||||||
|
weights=weights,
|
||||||
|
binary_output=False))
|
||||||
|
|
||||||
|
def testRaggedCountSparseOutput(self):
|
||||||
|
splits = [0, 4, 7]
|
||||||
|
values = [1, 1, 2, 1, 2, 10, 5]
|
||||||
|
weights = [1, 2, 3, 4, 5, 6, 7]
|
||||||
|
output_indices, output_values, output_shape = self.evaluate(
|
||||||
|
gen_count_ops.RaggedCountSparseOutput(
|
||||||
|
splits=splits, values=values, weights=weights, binary_output=False))
|
||||||
|
self.assertAllEqual([[0, 1], [0, 2], [1, 2], [1, 5], [1, 10]],
|
||||||
|
output_indices)
|
||||||
|
self.assertAllEqual([7, 3, 5, 7, 6], output_values)
|
||||||
|
self.assertAllEqual([2, 11], output_shape)
|
||||||
|
|
||||||
|
def testRaggedCountSparseOutputBadWeightsShape(self):
|
||||||
|
splits = [0, 4, 7]
|
||||||
|
values = [1, 1, 2, 1, 2, 10, 5]
|
||||||
|
weights = [1, 2, 3, 4, 5, 6]
|
||||||
|
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||||
|
"Weights and values must have the same shape"):
|
||||||
|
self.evaluate(
|
||||||
|
gen_count_ops.RaggedCountSparseOutput(
|
||||||
|
splits=splits,
|
||||||
|
values=values,
|
||||||
|
weights=weights,
|
||||||
|
binary_output=False))
|
||||||
|
|
||||||
|
def testRaggedCountSparseOutputEmptySplits(self):
|
||||||
|
splits = []
|
||||||
|
values = [1, 1, 2, 1, 2, 10, 5]
|
||||||
|
weights = [1, 2, 3, 4, 5, 6, 7]
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
errors.InvalidArgumentError,
|
||||||
|
"Must provide at least 2 elements for the splits argument"):
|
||||||
|
self.evaluate(
|
||||||
|
gen_count_ops.RaggedCountSparseOutput(
|
||||||
|
splits=splits,
|
||||||
|
values=values,
|
||||||
|
weights=weights,
|
||||||
|
binary_output=False))
|
||||||
|
|
||||||
|
def testRaggedCountSparseOutputBadSplitsStart(self):
|
||||||
|
splits = [1, 7]
|
||||||
|
values = [1, 1, 2, 1, 2, 10, 5]
|
||||||
|
weights = [1, 2, 3, 4, 5, 6, 7]
|
||||||
|
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||||
|
"Splits must start with 0"):
|
||||||
|
self.evaluate(
|
||||||
|
gen_count_ops.RaggedCountSparseOutput(
|
||||||
|
splits=splits,
|
||||||
|
values=values,
|
||||||
|
weights=weights,
|
||||||
|
binary_output=False))
|
||||||
|
|
||||||
|
def testRaggedCountSparseOutputBadSplitsEnd(self):
|
||||||
|
splits = [0, 5]
|
||||||
|
values = [1, 1, 2, 1, 2, 10, 5]
|
||||||
|
weights = [1, 2, 3, 4, 5, 6, 7]
|
||||||
|
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||||
|
"Splits must end with the number of values"):
|
||||||
|
self.evaluate(
|
||||||
|
gen_count_ops.RaggedCountSparseOutput(
|
||||||
|
splits=splits,
|
||||||
|
values=values,
|
||||||
|
weights=weights,
|
||||||
|
binary_output=False))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user