Merge pull request #45506 from tensorflow/mm-cherrypick-ebc70b7a592420d3d2f359e4b1694c236b82c7ae-on-r2.3
Validate that `DataFormat*` attributes form a permutation.
This commit is contained in:
commit
6eda33cc3d
tensorflow
@ -18,16 +18,52 @@ limitations under the License.
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#include "tensorflow/core/kernels/data_format_ops.h"
|
||||
|
||||
#include <map>
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
|
||||
// Ensure that `src` and `dst` define a valid permutation.
|
||||
// Ops defined in this file assume that user specifies a permutation via two
|
||||
// string attributes. This check validates that these attributes properly define
|
||||
// it to prevent security vulnerabilities.
|
||||
static bool IsValidPermutation(const std::string& src, const std::string& dst) {
|
||||
if (src.size() != dst.size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::map<char, bool> characters;
|
||||
|
||||
// Every character in `src` must be present only once
|
||||
for (const auto c : src) {
|
||||
if (characters[c]) {
|
||||
return false;
|
||||
}
|
||||
characters[c] = true;
|
||||
}
|
||||
|
||||
// Every character in `dst` must show up in `src` exactly once
|
||||
for (const auto c : dst) {
|
||||
if (!characters[c]) {
|
||||
return false;
|
||||
}
|
||||
characters[c] = false;
|
||||
}
|
||||
|
||||
// At this point, characters[] has been switched to true and false exactly
|
||||
// once for all character in `src` (and `dst`) so we have a valid permutation
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename Device, typename T>
|
||||
class DataFormatDimMapOp : public OpKernel {
|
||||
public:
|
||||
@ -37,15 +73,20 @@ class DataFormatDimMapOp : public OpKernel {
|
||||
OP_REQUIRES_OK(context, context->GetAttr("src_format", &src_format));
|
||||
string dst_format;
|
||||
OP_REQUIRES_OK(context, context->GetAttr("dst_format", &dst_format));
|
||||
OP_REQUIRES(context, src_format.size() == 4,
|
||||
errors::InvalidArgument(strings::StrCat(
|
||||
"Source format must of length 4, received src_format = ",
|
||||
src_format)));
|
||||
OP_REQUIRES(context, src_format.size() == 4 || src_format.size() == 5,
|
||||
errors::InvalidArgument(
|
||||
"Source format must be of length 4 or 5, received "
|
||||
"src_format = ",
|
||||
src_format));
|
||||
OP_REQUIRES(context, dst_format.size() == 4 || dst_format.size() == 5,
|
||||
errors::InvalidArgument("Destination format must be of length "
|
||||
"4 or 5, received dst_format = ",
|
||||
dst_format));
|
||||
OP_REQUIRES(
|
||||
context, dst_format.size() == 4,
|
||||
errors::InvalidArgument(strings::StrCat(
|
||||
"Destination format must of length 4, received dst_format = ",
|
||||
dst_format)));
|
||||
context, IsValidPermutation(src_format, dst_format),
|
||||
errors::InvalidArgument(
|
||||
"Destination and source format must determine a permutation, got ",
|
||||
src_format, " and ", dst_format));
|
||||
dst_idx_ = Tensor(DT_INT32, {static_cast<int64>(src_format.size())});
|
||||
for (int i = 0; i < src_format.size(); ++i) {
|
||||
for (int j = 0; j < dst_format.size(); ++j) {
|
||||
@ -77,8 +118,22 @@ class DataFormatVecPermuteOp : public OpKernel {
|
||||
: OpKernel(context) {
|
||||
string src_format;
|
||||
OP_REQUIRES_OK(context, context->GetAttr("src_format", &src_format));
|
||||
OP_REQUIRES(context, src_format.size() == 4 || src_format.size() == 5,
|
||||
errors::InvalidArgument(
|
||||
"Source format must be of length 4 or 5, received "
|
||||
"src_format = ",
|
||||
src_format));
|
||||
string dst_format;
|
||||
OP_REQUIRES_OK(context, context->GetAttr("dst_format", &dst_format));
|
||||
OP_REQUIRES(context, dst_format.size() == 4 || dst_format.size() == 5,
|
||||
errors::InvalidArgument("Destination format must be of length "
|
||||
"4 or 5, received dst_format = ",
|
||||
dst_format));
|
||||
OP_REQUIRES(
|
||||
context, IsValidPermutation(src_format, dst_format),
|
||||
errors::InvalidArgument(
|
||||
"Destination and source format must determine a permutation, got ",
|
||||
src_format, " and ", dst_format));
|
||||
src_format_ = src_format;
|
||||
dst_format_ = dst_format;
|
||||
}
|
||||
@ -124,6 +179,10 @@ class DataFormatVecPermuteOp : public OpKernel {
|
||||
};
|
||||
keep_only_spatial_dimensions(&src_format_str);
|
||||
keep_only_spatial_dimensions(&dst_format_str);
|
||||
OP_REQUIRES(context,
|
||||
src_format_str.size() == 2 && dst_format_str.size() == 2,
|
||||
errors::InvalidArgument(
|
||||
"Format specifier must contain H and W for 2D case"));
|
||||
}
|
||||
ComputeDstIndex(src_format_str, dst_format_str, input.dims(), &dst_idx);
|
||||
|
||||
|
@ -27,6 +27,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin
|
||||
from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_spec
|
||||
from tensorflow.python.framework import test_util
|
||||
@ -1216,6 +1217,46 @@ class DataFormatDimMapTest(test_lib.TestCase):
|
||||
y_val = self.evaluate(y)
|
||||
self.assertAllEqual(y_val, y_val_expected)
|
||||
|
||||
@test_util.disable_xla("XLA catches the error and rethrows as different one")
|
||||
def testInvalidLength(self):
|
||||
x = [-4, -3, -2, -1, 0, 1, 2, 3]
|
||||
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||
"Source format must be of length 4 or 5"):
|
||||
op = nn_ops.data_format_dim_map(
|
||||
x, src_format="12345678", dst_format="87654321")
|
||||
with test_util.use_gpu():
|
||||
self.evaluate(op)
|
||||
|
||||
@test_util.disable_xla("XLA catches the error and rethrows as different one")
|
||||
def testDuplicateSrc(self):
|
||||
x = [-4, -3, -2, -1, 0, 1, 2, 3]
|
||||
with self.assertRaisesRegex(
|
||||
errors.InvalidArgumentError,
|
||||
"Destination and source format must determine a permutation"):
|
||||
op = nn_ops.data_format_dim_map(x, src_format="1233", dst_format="4321")
|
||||
with test_util.use_gpu():
|
||||
self.evaluate(op)
|
||||
|
||||
@test_util.disable_xla("XLA catches the error and rethrows as different one")
|
||||
def testDuplicateDst(self):
|
||||
x = [-4, -3, -2, -1, 0, 1, 2, 3]
|
||||
with self.assertRaisesRegex(
|
||||
errors.InvalidArgumentError,
|
||||
"Destination and source format must determine a permutation"):
|
||||
op = nn_ops.data_format_dim_map(x, src_format="1234", dst_format="3321")
|
||||
with test_util.use_gpu():
|
||||
self.evaluate(op)
|
||||
|
||||
@test_util.disable_xla("XLA catches the error and rethrows as different one")
|
||||
def testExtraSpecifiers(self):
|
||||
x = [-4, -3, -2, -1, 0, 1, 2, 3]
|
||||
with self.assertRaisesRegex(
|
||||
errors.InvalidArgumentError,
|
||||
"Destination and source format must determine a permutation"):
|
||||
op = nn_ops.data_format_dim_map(x, src_format="1234", dst_format="5321")
|
||||
with test_util.use_gpu():
|
||||
self.evaluate(op)
|
||||
|
||||
|
||||
class DataFormatVectorPermuteTest(test_lib.TestCase):
|
||||
|
||||
@ -1317,6 +1358,60 @@ class DataFormatVectorPermuteTest(test_lib.TestCase):
|
||||
y_val = self.evaluate(y)
|
||||
self.assertAllEqual(y_val, [[7, 4], [4, 5], [5, 1], [9, 3]])
|
||||
|
||||
@test_util.disable_xla("XLA catches the error and rethrows as different one")
|
||||
def testInvalidLength(self):
|
||||
x = [0, 1, 2, 3]
|
||||
with self.assertRaisesRegex(errors.InvalidArgumentError,
|
||||
"Source format must be of length 4 or 5"):
|
||||
op = nn_ops.data_format_vec_permute(
|
||||
x, src_format="12345678", dst_format="87654321")
|
||||
with test_util.use_gpu():
|
||||
self.evaluate(op)
|
||||
|
||||
@test_util.disable_xla("XLA catches the error and rethrows as different one")
|
||||
def testDuplicateSrc(self):
|
||||
x = [0, 1, 2, 3]
|
||||
with self.assertRaisesRegex(
|
||||
errors.InvalidArgumentError,
|
||||
"Destination and source format must determine a permutation"):
|
||||
op = nn_ops.data_format_vec_permute(
|
||||
x, src_format="1233", dst_format="4321")
|
||||
with test_util.use_gpu():
|
||||
self.evaluate(op)
|
||||
|
||||
@test_util.disable_xla("XLA catches the error and rethrows as different one")
|
||||
def testDuplicateDst(self):
|
||||
x = [0, 1, 2, 3]
|
||||
with self.assertRaisesRegex(
|
||||
errors.InvalidArgumentError,
|
||||
"Destination and source format must determine a permutation"):
|
||||
op = nn_ops.data_format_vec_permute(
|
||||
x, src_format="1234", dst_format="3321")
|
||||
with test_util.use_gpu():
|
||||
self.evaluate(op)
|
||||
|
||||
@test_util.disable_xla("XLA catches the error and rethrows as different one")
|
||||
def testExtraSpecifiers(self):
|
||||
x = [0, 1, 2, 3]
|
||||
with self.assertRaisesRegex(
|
||||
errors.InvalidArgumentError,
|
||||
"Destination and source format must determine a permutation"):
|
||||
op = nn_ops.data_format_vec_permute(
|
||||
x, src_format="1234", dst_format="5321")
|
||||
with test_util.use_gpu():
|
||||
self.evaluate(op)
|
||||
|
||||
@test_util.disable_xla("XLA catches the error and rethrows as different one")
|
||||
def test2DNoWH(self):
|
||||
x = [[0, 1], [2, 3]]
|
||||
with self.assertRaisesRegex(
|
||||
errors.InvalidArgumentError,
|
||||
"Format specifier must contain H and W for 2D case"):
|
||||
op = nn_ops.data_format_vec_permute(
|
||||
x, src_format="1234", dst_format="4321")
|
||||
with test_util.use_gpu():
|
||||
self.evaluate(op)
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class AvgPoolTest(test_lib.TestCase):
|
||||
|
Loading…
Reference in New Issue
Block a user