change MirrorPad packet region
This commit is contained in:
parent
4ce6a9b7a4
commit
a1bdc83cc8
@ -3166,6 +3166,7 @@ tf_cc_tests(
|
|||||||
"adjust_contrast_op_test.cc",
|
"adjust_contrast_op_test.cc",
|
||||||
"colorspace_op_test.cc",
|
"colorspace_op_test.cc",
|
||||||
"crop_and_resize_op_test.cc",
|
"crop_and_resize_op_test.cc",
|
||||||
|
"mirror_pad_op_test.cc",
|
||||||
"non_max_suppression_op_test.cc",
|
"non_max_suppression_op_test.cc",
|
||||||
"resize_area_op_test.cc",
|
"resize_area_op_test.cc",
|
||||||
"resize_bicubic_op_test.cc",
|
"resize_bicubic_op_test.cc",
|
||||||
@ -3178,6 +3179,7 @@ tf_cc_tests(
|
|||||||
}),
|
}),
|
||||||
deps = [
|
deps = [
|
||||||
":image",
|
":image",
|
||||||
|
":mirror_pad_op",
|
||||||
":ops_testutil",
|
":ops_testutil",
|
||||||
":ops_util",
|
":ops_util",
|
||||||
":sampling_kernels",
|
":sampling_kernels",
|
||||||
@ -3244,6 +3246,22 @@ tf_cuda_cc_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_cuda_cc_test(
|
||||||
|
name = "mirror_pad_op_benchmark_test",
|
||||||
|
srcs = ["mirror_pad_op_benchmark_test.cc"],
|
||||||
|
deps = [
|
||||||
|
":mirror_pad_op",
|
||||||
|
":ops_testutil",
|
||||||
|
":ops_util",
|
||||||
|
"//tensorflow/core:core_cpu",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:protos_all_cc",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core:test_main",
|
||||||
|
"//tensorflow/core:testlib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
tf_cuda_cc_test(
|
tf_cuda_cc_test(
|
||||||
name = "non_max_suppression_op_gpu_test",
|
name = "non_max_suppression_op_gpu_test",
|
||||||
srcs = ["non_max_suppression_op_gpu_test.cc"],
|
srcs = ["non_max_suppression_op_gpu_test.cc"],
|
||||||
|
@ -16,9 +16,9 @@ limitations under the License.
|
|||||||
#ifndef TENSORFLOW_CORE_KERNELS_MIRROR_PAD_OP_H_
|
#ifndef TENSORFLOW_CORE_KERNELS_MIRROR_PAD_OP_H_
|
||||||
#define TENSORFLOW_CORE_KERNELS_MIRROR_PAD_OP_H_
|
#define TENSORFLOW_CORE_KERNELS_MIRROR_PAD_OP_H_
|
||||||
|
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
|
||||||
#include "tensorflow/core/framework/tensor_types.h"
|
#include "tensorflow/core/framework/tensor_types.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
|
|
||||||
namespace Eigen {
|
namespace Eigen {
|
||||||
template <typename PaddingDimensions, typename XprType>
|
template <typename PaddingDimensions, typename XprType>
|
||||||
@ -223,7 +223,8 @@ struct TensorEvaluator<const TensorMirrorPadOp<PaddingDimensions, ArgType>,
|
|||||||
const Index right =
|
const Index right =
|
||||||
(dimensions_[dim] - padding_[dim].second) * output_strides_[dim];
|
(dimensions_[dim] - padding_[dim].second) * output_strides_[dim];
|
||||||
|
|
||||||
if (left <= index && (index + kPacketSize - 1) < right) {
|
const Index index_mod = index % (dimensions_[dim] * output_strides_[dim]);
|
||||||
|
if (left <= index_mod && (index_mod + kPacketSize - 1) < right) {
|
||||||
return impl_.template packet<Unaligned>(input_index);
|
return impl_.template packet<Unaligned>(input_index);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
59
tensorflow/core/kernels/mirror_pad_op_benchmark_test.cc
Normal file
59
tensorflow/core/kernels/mirror_pad_op_benchmark_test.cc
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/graph/node_builder.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
#include "tensorflow/core/platform/test_benchmark.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
static Graph* BM_MirrorPad(int batches, int height, int width, int depth,
|
||||||
|
int pad, const char* mode) {
|
||||||
|
Graph* g = new Graph(OpRegistry::Global());
|
||||||
|
Tensor in(DT_FLOAT, TensorShape({batches, height, width, depth}));
|
||||||
|
in.flat<float>().setRandom();
|
||||||
|
Tensor padding(DT_INT32, TensorShape({4, 2}));
|
||||||
|
auto boxes_tensor = padding.flat<int>().setZero();
|
||||||
|
for (int i = 2; i < 6; i++) boxes_tensor(i) = pad;
|
||||||
|
|
||||||
|
Node* ret;
|
||||||
|
TF_CHECK_OK(NodeBuilder(g->NewName("n"), "MirrorPad")
|
||||||
|
.Input(test::graph::Constant(g, in))
|
||||||
|
.Input(test::graph::Constant(g, padding))
|
||||||
|
.Attr("mode", mode)
|
||||||
|
.Finalize(g, &ret));
|
||||||
|
return g;
|
||||||
|
}
|
||||||
|
|
||||||
|
#define BM_MirrorPadDev(DEVICE, B, W, H, D, P, MODE) \
|
||||||
|
static void BM_MirrorPad_##DEVICE##_##B##_##W##_##H##_##D##_##P##_##MODE( \
|
||||||
|
int iters) { \
|
||||||
|
testing::ItemsProcessed(iters* B*(W + 2 * P) * (H + 2 * P) * D / 32); \
|
||||||
|
test::Benchmark(#DEVICE, BM_MirrorPad(B, W, H, D, P, #MODE)).Run(iters); \
|
||||||
|
} \
|
||||||
|
BENCHMARK(BM_MirrorPad_##DEVICE##_##B##_##W##_##H##_##D##_##P##_##MODE);
|
||||||
|
|
||||||
|
BM_MirrorPadDev(cpu, 1, 16, 16, 32, 1, REFLECT);
|
||||||
|
BM_MirrorPadDev(cpu, 1, 16, 16, 32, 8, REFLECT);
|
||||||
|
BM_MirrorPadDev(cpu, 1, 512, 512, 16, 1, REFLECT);
|
||||||
|
BM_MirrorPadDev(cpu, 1, 512, 512, 16, 256, REFLECT);
|
||||||
|
BM_MirrorPadDev(cpu, 1, 16, 16, 32, 1, SYMMETRIC);
|
||||||
|
BM_MirrorPadDev(cpu, 1, 16, 16, 32, 8, SYMMETRIC);
|
||||||
|
BM_MirrorPadDev(cpu, 1, 512, 512, 16, 1, SYMMETRIC);
|
||||||
|
BM_MirrorPadDev(cpu, 1, 512, 512, 16, 256, SYMMETRIC);
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
201
tensorflow/core/kernels/mirror_pad_op_test.cc
Normal file
201
tensorflow/core/kernels/mirror_pad_op_test.cc
Normal file
@ -0,0 +1,201 @@
|
|||||||
|
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/allocator.h"
|
||||||
|
#include "tensorflow/core/framework/fake_input.h"
|
||||||
|
#include "tensorflow/core/framework/node_def_builder.h"
|
||||||
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_util.h"
|
||||||
|
#include "tensorflow/core/framework/types.h"
|
||||||
|
#include "tensorflow/core/framework/types.pb.h"
|
||||||
|
#include "tensorflow/core/kernels/ops_testutil.h"
|
||||||
|
#include "tensorflow/core/kernels/ops_util.h"
|
||||||
|
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||||
|
#include "tensorflow/core/lib/strings/str_util.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
class MirrorPadOpTest : public OpsTestBase {
|
||||||
|
protected:
|
||||||
|
template <typename T>
|
||||||
|
void MakeOp(const string& mode) {
|
||||||
|
TF_EXPECT_OK(NodeDefBuilder("mirror_pad_op", "MirrorPad")
|
||||||
|
.Input(FakeInput(DataTypeToEnum<T>::value))
|
||||||
|
.Input(FakeInput(DT_INT32))
|
||||||
|
.Attr("mode", mode)
|
||||||
|
.Finalize(node_def()));
|
||||||
|
TF_EXPECT_OK(InitOp());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
#define REGISTER_TEST(T) \
|
||||||
|
TEST_F(MirrorPadOpTest, TestMirrorPadReflect##T) { \
|
||||||
|
MakeOp<T>("REFLECT"); \
|
||||||
|
AddInputFromArray<T>(TensorShape({1, 2, 3, 1}), {1, 2, 3, 4, 5, 6}); \
|
||||||
|
AddInputFromArray<int32>(TensorShape({4, 2}), {0, 0, 1, 1, 2, 2, 0, 0}); \
|
||||||
|
TF_ASSERT_OK(RunOpKernel()); \
|
||||||
|
\
|
||||||
|
Tensor expected(allocator(), DataTypeToEnum<T>::value, \
|
||||||
|
TensorShape({1, 4, 7, 1})); \
|
||||||
|
test::FillValues<T>(&expected, \
|
||||||
|
{6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, \
|
||||||
|
6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1}); \
|
||||||
|
test::ExpectTensorEqual<T>(expected, *GetOutput(0)); \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
TEST_F(MirrorPadOpTest, TestMirrorPadSymmetric##T) { \
|
||||||
|
MakeOp<T>("SYMMETRIC"); \
|
||||||
|
AddInputFromArray<T>(TensorShape({1, 2, 1, 3}), {1, 2, 3, 4, 5, 6}); \
|
||||||
|
AddInputFromArray<int32>(TensorShape({4, 2}), {1, 1, 0, 0, 0, 0, 2, 2}); \
|
||||||
|
TF_ASSERT_OK(RunOpKernel()); \
|
||||||
|
\
|
||||||
|
Tensor expected(allocator(), DataTypeToEnum<T>::value, \
|
||||||
|
TensorShape({3, 2, 1, 7})); \
|
||||||
|
test::FillValues<T>( \
|
||||||
|
&expected, \
|
||||||
|
{2, 1, 1, 2, 3, 3, 2, 5, 4, 4, 5, 6, 6, 5, 2, 1, 1, 2, 3, 3, 2, \
|
||||||
|
5, 4, 4, 5, 6, 6, 5, 2, 1, 1, 2, 3, 3, 2, 5, 4, 4, 5, 6, 6, 5}); \
|
||||||
|
test::ExpectTensorEqual<T>(expected, *GetOutput(0)); \
|
||||||
|
}
|
||||||
|
|
||||||
|
REGISTER_TEST(float)
|
||||||
|
REGISTER_TEST(double)
|
||||||
|
REGISTER_TEST(uint8)
|
||||||
|
REGISTER_TEST(uint16)
|
||||||
|
REGISTER_TEST(int8)
|
||||||
|
REGISTER_TEST(int16)
|
||||||
|
REGISTER_TEST(int32)
|
||||||
|
REGISTER_TEST(int64)
|
||||||
|
|
||||||
|
#undef REGISTER_TEST
|
||||||
|
|
||||||
|
TEST_F(MirrorPadOpTest, TestMirrorPadReflectLargeInput) {
|
||||||
|
MakeOp<float>("REFLECT");
|
||||||
|
// Generate a relatively large input
|
||||||
|
const int kInput = 1000;
|
||||||
|
const int kPad = 10;
|
||||||
|
const int kOutput = kInput + 2 * kPad;
|
||||||
|
|
||||||
|
// Input:
|
||||||
|
// 0, 1, 2, ..., 999
|
||||||
|
// 0, 1, 2, ..., 999
|
||||||
|
// ... (altogether 1000 lines)
|
||||||
|
// 0, 1, 2, ..., 999
|
||||||
|
AddInput<float>(TensorShape({1, kInput, kInput, 1}),
|
||||||
|
[](int i) -> float { return i % kInput; });
|
||||||
|
AddInputFromArray<int32>(TensorShape({4, 2}),
|
||||||
|
{0, 0, kPad, kPad, kPad, kPad, 0, 0});
|
||||||
|
TF_ASSERT_OK(RunOpKernel());
|
||||||
|
|
||||||
|
Tensor expected(allocator(), DT_FLOAT, TensorShape({1, kOutput, kOutput, 1}));
|
||||||
|
test::FillFn<float>(&expected, [](int i) -> float {
|
||||||
|
i = i % kOutput;
|
||||||
|
if (0 <= i && i < kPad)
|
||||||
|
return kPad - i;
|
||||||
|
else if (kPad <= i && i < kInput + kPad)
|
||||||
|
return i - kPad;
|
||||||
|
else if (kInput + kPad <= i && i < kOutput)
|
||||||
|
return 2 * kInput + kPad - 2 - i;
|
||||||
|
});
|
||||||
|
|
||||||
|
test::ExpectTensorEqual<float>(expected, *GetOutput(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(MirrorPadOpTest, TestMirrorPadSymmetricLargeInput) {
|
||||||
|
MakeOp<float>("SYMMETRIC");
|
||||||
|
// Generate a relatively large input
|
||||||
|
const int kInput = 1000;
|
||||||
|
const int kPad = 10;
|
||||||
|
const int kOutput = kInput + 2 * kPad;
|
||||||
|
|
||||||
|
// Input:
|
||||||
|
// 0, 1, 2, ..., 999
|
||||||
|
// 0, 1, 2, ..., 999
|
||||||
|
// ... (altogether 1000 lines)
|
||||||
|
// 0, 1, 2, ..., 999
|
||||||
|
AddInput<float>(TensorShape({1, kInput, kInput, 1}),
|
||||||
|
[](int i) -> float { return i % kInput; });
|
||||||
|
AddInputFromArray<int32>(TensorShape({4, 2}),
|
||||||
|
{0, 0, kPad, kPad, kPad, kPad, 0, 0});
|
||||||
|
TF_ASSERT_OK(RunOpKernel());
|
||||||
|
|
||||||
|
Tensor expected(allocator(), DT_FLOAT, TensorShape({1, kOutput, kOutput, 1}));
|
||||||
|
test::FillFn<float>(&expected, [](int i) -> float {
|
||||||
|
i = i % kOutput;
|
||||||
|
if (0 <= i && i < kPad)
|
||||||
|
return kPad - i - 1;
|
||||||
|
else if (kPad <= i && i < kInput + kPad)
|
||||||
|
return i - kPad;
|
||||||
|
else if (kInput + kPad <= i && i < kOutput)
|
||||||
|
return 2 * kInput + kPad - 1 - i;
|
||||||
|
});
|
||||||
|
|
||||||
|
test::ExpectTensorEqual<float>(expected, *GetOutput(0));
|
||||||
|
}
|
||||||
|
|
||||||
|
class MirrorPadGradOpTest : public OpsTestBase {
|
||||||
|
protected:
|
||||||
|
template <typename T>
|
||||||
|
void MakeOp(const string& mode) {
|
||||||
|
TF_EXPECT_OK(NodeDefBuilder("mirror_pad_grad_op", "MirrorPadGrad")
|
||||||
|
.Input(FakeInput(DataTypeToEnum<T>::value))
|
||||||
|
.Input(FakeInput(DT_INT32))
|
||||||
|
.Attr("mode", mode)
|
||||||
|
.Finalize(node_def()));
|
||||||
|
TF_EXPECT_OK(InitOp());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
#define REGISTER_TEST(T) \
|
||||||
|
TEST_F(MirrorPadGradOpTest, TestMirrorPadGradReflect##T) { \
|
||||||
|
MakeOp<T>("REFLECT"); \
|
||||||
|
AddInput<T>(TensorShape({1, 4, 7, 1}), [](int i) -> T { return i % 7; }); \
|
||||||
|
AddInputFromArray<int32>(TensorShape({4, 2}), {0, 0, 1, 1, 2, 2, 0, 0}); \
|
||||||
|
TF_ASSERT_OK(RunOpKernel()); \
|
||||||
|
\
|
||||||
|
Tensor expected(allocator(), DataTypeToEnum<T>::value, \
|
||||||
|
TensorShape({1, 2, 3, 1})); \
|
||||||
|
test::FillValues<T>(&expected, {16, 18, 8, 16, 18, 8}); \
|
||||||
|
test::ExpectTensorEqual<T>(expected, *GetOutput(0)); \
|
||||||
|
} \
|
||||||
|
\
|
||||||
|
TEST_F(MirrorPadGradOpTest, TestMirrorPadGradSymmetric##T) { \
|
||||||
|
MakeOp<T>("SYMMETRIC"); \
|
||||||
|
AddInput<T>(TensorShape({3, 2, 1, 7}), [](int i) -> T { return i % 7; }); \
|
||||||
|
AddInputFromArray<int32>(TensorShape({4, 2}), {1, 1, 0, 0, 0, 0, 2, 2}); \
|
||||||
|
TF_ASSERT_OK(RunOpKernel()); \
|
||||||
|
\
|
||||||
|
Tensor expected(allocator(), DataTypeToEnum<T>::value, \
|
||||||
|
TensorShape({1, 2, 1, 3})); \
|
||||||
|
test::FillValues<T>(&expected, {9, 27, 27, 9, 27, 27}); \
|
||||||
|
test::ExpectTensorEqual<T>(expected, *GetOutput(0)); \
|
||||||
|
}
|
||||||
|
|
||||||
|
REGISTER_TEST(float)
|
||||||
|
REGISTER_TEST(double)
|
||||||
|
REGISTER_TEST(uint8)
|
||||||
|
REGISTER_TEST(uint16)
|
||||||
|
REGISTER_TEST(int8)
|
||||||
|
REGISTER_TEST(int16)
|
||||||
|
REGISTER_TEST(int32)
|
||||||
|
REGISTER_TEST(int64)
|
||||||
|
|
||||||
|
#undef REGISTER_TEST
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
Loading…
x
Reference in New Issue
Block a user