Merge pull request #41050 from fsx950223:resize_nearest

PiperOrigin-RevId: 333722178
Change-Id: I75aa25344afd37b3649cf36ad3fecfe069805479
This commit is contained in:
TensorFlower Gardener 2020-09-25 06:43:56 -07:00
commit 4b6afcd5d7
4 changed files with 110 additions and 65 deletions

View File

@ -7464,7 +7464,7 @@ test_suite(
":variable_ops_test", ":variable_ops_test",
"//tensorflow/core/kernels/image:crop_and_resize_op_test", "//tensorflow/core/kernels/image:crop_and_resize_op_test",
"//tensorflow/core/kernels/image:non_max_suppression_op_test", "//tensorflow/core/kernels/image:non_max_suppression_op_test",
"//tensorflow/core/kernels/image:resize_bilinear_op_test", "//tensorflow/core/kernels/image:resize_ops_test",
], ],
) )

View File

@ -326,7 +326,6 @@ tf_cc_tests(
"non_max_suppression_op_test.cc", "non_max_suppression_op_test.cc",
"resize_area_op_test.cc", "resize_area_op_test.cc",
"resize_bicubic_op_test.cc", "resize_bicubic_op_test.cc",
"resize_nearest_neighbor_op_test.cc",
"scale_and_translate_op_test.cc", "scale_and_translate_op_test.cc",
], ],
linkopts = select({ linkopts = select({
@ -349,8 +348,11 @@ tf_cc_test(
) )
tf_cuda_cc_test( tf_cuda_cc_test(
name = "resize_bilinear_op_test", name = "resize_ops_test",
srcs = ["resize_bilinear_op_test.cc"], srcs = [
"resize_bilinear_op_test.cc",
"resize_nearest_neighbor_op_test.cc",
],
tags = ["no_cuda_on_cpu_tap"], tags = ["no_cuda_on_cpu_tap"],
deps = [ deps = [
":image", ":image",

View File

@ -173,20 +173,18 @@ struct ResizeNearestNeighbor<GPUDevice, T, half_pixel_centers, align_corners> {
if (output_size == 0) return true; if (output_size == 0) return true;
GpuLaunchConfig config = GetGpuLaunchConfig(output_size, d); GpuLaunchConfig config = GetGpuLaunchConfig(output_size, d);
if (half_pixel_centers) { void (*kernel)(const int nthreads, const T* __restrict__ bottom_data,
TF_CHECK_OK(GpuLaunchKernel( const int in_height, const int in_width, const int channels,
ResizeNearestNeighborNHWC<T>, config.block_count, const int out_height, const int out_width,
config.thread_per_block, 0, d.stream(), output_size, input.data(), const float height_scale, const float width_scale,
in_height, in_width, channels, out_height, out_width, height_scale, T* top_data) =
width_scale, output.data())); half_pixel_centers ? ResizeNearestNeighborNHWC<T>
return d.ok(); : LegacyResizeNearestNeighborNHWC<T, align_corners>;
} else { TF_CHECK_OK(
TF_CHECK_OK(GpuLaunchKernel( GpuLaunchKernel(kernel, config.block_count, config.thread_per_block, 0,
LegacyResizeNearestNeighborNHWC<T, align_corners>, config.block_count, d.stream(), config.virtual_thread_count, input.data(),
config.thread_per_block, 0, d.stream(), output_size, input.data(), in_height, in_width, channels, out_height, out_width,
in_height, in_width, channels, out_height, out_width, height_scale, height_scale, width_scale, output.data()));
width_scale, output.data()));
}
return d.ok(); return d.ok();
} }
}; };
@ -228,23 +226,20 @@ struct ResizeNearestNeighborGrad<GPUDevice, T, half_pixel_centers,
if (input_size == 0) return true; if (input_size == 0) return true;
GpuLaunchConfig input_config = GetGpuLaunchConfig(input_size, d); GpuLaunchConfig input_config = GetGpuLaunchConfig(input_size, d);
if (half_pixel_centers) { void (*kernel)(const int nthreads, const T* __restrict__ top_diff,
TF_CHECK_OK(GpuLaunchKernel( const int in_height, const int in_width, const int channels,
ResizeNearestNeighborBackwardNHWC<T>, input_config.block_count, const int out_height, const int out_width,
input_config.thread_per_block, 0, d.stream(), const float height_scale, const float width_scale,
input_config.virtual_thread_count, input.data(), in_height, in_width, T* __restrict__ bottom_diff) =
channels, out_height, out_width, height_scale, width_scale, half_pixel_centers
output.data())); ? ResizeNearestNeighborBackwardNHWC<T>
return d.ok(); : LegacyResizeNearestNeighborBackwardNHWC<T, align_corners>;
} else { TF_CHECK_OK(GpuLaunchKernel(
TF_CHECK_OK(GpuLaunchKernel( kernel, input_config.block_count, input_config.thread_per_block, 0,
LegacyResizeNearestNeighborBackwardNHWC<T, align_corners>, d.stream(), input_config.virtual_thread_count, input.data(), in_height,
input_config.block_count, input_config.thread_per_block, 0, in_width, channels, out_height, out_width, height_scale, width_scale,
d.stream(), input_config.virtual_thread_count, input.data(), output.data()));
in_height, in_width, channels, out_height, out_width, height_scale, return d.ok();
width_scale, output.data()));
return d.ok();
}
} }
}; };

View File

@ -16,6 +16,7 @@ limitations under the License.
// TODO(shlens, sherrym): Consider adding additional tests in image_ops.py in // TODO(shlens, sherrym): Consider adding additional tests in image_ops.py in
// order to compare the reference implementation for image resizing in Python // order to compare the reference implementation for image resizing in Python
// Image Library. // Image Library.
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/fake_input.h" #include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/node_def_builder.h"
@ -30,18 +31,32 @@ limitations under the License.
#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test.h"
namespace tensorflow { namespace tensorflow {
enum class TestDevice { kCPU, kGPU };
class ResizeNearestNeighborOpTestBase : public OpsTestBase { class ResizeNearestNeighborOpTestBase
: public OpsTestBase,
public ::testing::WithParamInterface<TestDevice> {
protected: protected:
explicit ResizeNearestNeighborOpTestBase(bool half_pixel_centers) { explicit ResizeNearestNeighborOpTestBase(bool half_pixel_centers)
TF_EXPECT_OK(NodeDefBuilder("resize_nn", "ResizeNearestNeighbor") : align_corners_(false), half_pixel_centers_(half_pixel_centers) {}
void SetUp() override {
if (GetParam() == TestDevice::kGPU) {
std::unique_ptr<Device> device_gpu(
DeviceFactory::NewDevice(/*type=*/"GPU", /*options=*/{},
/*name_prefix=*/"/job:a/replica:0/task:0"));
SetDevice(DEVICE_GPU, std::move(device_gpu));
}
TF_EXPECT_OK(NodeDefBuilder("resize_nn_op", "ResizeNearestNeighbor")
.Input(FakeInput(DT_FLOAT)) .Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_INT32)) .Input(FakeInput(DT_INT32))
.Attr("align_corners", false) .Attr("align_corners", align_corners_)
.Attr("half_pixel_centers", half_pixel_centers) .Attr("half_pixel_centers", half_pixel_centers_)
.Finalize(node_def())); .Finalize(node_def()));
TF_EXPECT_OK(InitOp()); TF_EXPECT_OK(InitOp());
} }
bool align_corners_;
bool half_pixel_centers_;
}; };
class ResizeNearestNeighborOpTest : public ResizeNearestNeighborOpTestBase { class ResizeNearestNeighborOpTest : public ResizeNearestNeighborOpTestBase {
@ -58,19 +73,30 @@ class ResizeNearestNeighborHalfPixelCentersOpTest
// TODO(jflynn): Add some actual tests for the half pixel centers case. // TODO(jflynn): Add some actual tests for the half pixel centers case.
class ResizeNearestNeighborOpAlignCornersTest : public OpsTestBase { class ResizeNearestNeighborOpAlignCornersTest
: public OpsTestBase,
public ::testing::WithParamInterface<TestDevice> {
protected: protected:
ResizeNearestNeighborOpAlignCornersTest() { ResizeNearestNeighborOpAlignCornersTest() : align_corners_(true) {}
TF_EXPECT_OK(NodeDefBuilder("resize_nn", "ResizeNearestNeighbor") void SetUp() override {
if (GetParam() == TestDevice::kGPU) {
std::unique_ptr<Device> device_gpu(
DeviceFactory::NewDevice(/*type=*/"GPU", /*options=*/{},
/*name_prefix=*/"/job:a/replica:0/task:0"));
SetDevice(DEVICE_GPU, std::move(device_gpu));
}
TF_EXPECT_OK(NodeDefBuilder("resize_nn_op", "ResizeNearestNeighbor")
.Input(FakeInput(DT_FLOAT)) .Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_INT32)) .Input(FakeInput(DT_INT32))
.Attr("align_corners", true) .Attr("align_corners", align_corners_)
.Finalize(node_def())); .Finalize(node_def()));
TF_EXPECT_OK(InitOp()); TF_EXPECT_OK(InitOp());
} }
bool align_corners_;
}; };
TEST_F(ResizeNearestNeighborOpTest, TestNearest2x2To1x1) { TEST_P(ResizeNearestNeighborOpTest, TestNearest2x2To1x1) {
// Input: // Input:
// 1, 2 // 1, 2
// 3, 4 // 3, 4
@ -87,7 +113,7 @@ TEST_F(ResizeNearestNeighborOpTest, TestNearest2x2To1x1) {
test::ExpectTensorEqual<float>(expected, *GetOutput(0)); test::ExpectTensorEqual<float>(expected, *GetOutput(0));
} }
TEST_F(ResizeNearestNeighborOpAlignCornersTest, TEST_P(ResizeNearestNeighborOpAlignCornersTest,
TestNearest2x2AlignCornersTo1x1) { TestNearest2x2AlignCornersTo1x1) {
// Input: // Input:
// 1, 2 // 1, 2
@ -105,7 +131,7 @@ TEST_F(ResizeNearestNeighborOpAlignCornersTest,
test::ExpectTensorEqual<float>(expected, *GetOutput(0)); test::ExpectTensorEqual<float>(expected, *GetOutput(0));
} }
TEST_F(ResizeNearestNeighborOpTest, TestNearest2x2To3x3) { TEST_P(ResizeNearestNeighborOpTest, TestNearest2x2To3x3) {
// Input: // Input:
// 1, 2 // 1, 2
// 3, 4 // 3, 4
@ -125,7 +151,7 @@ TEST_F(ResizeNearestNeighborOpTest, TestNearest2x2To3x3) {
test::ExpectTensorEqual<float>(expected, *GetOutput(0)); test::ExpectTensorEqual<float>(expected, *GetOutput(0));
} }
TEST_F(ResizeNearestNeighborOpAlignCornersTest, TEST_P(ResizeNearestNeighborOpAlignCornersTest,
TestNearestAlignCorners2x2To3x3) { TestNearestAlignCorners2x2To3x3) {
// Input: // Input:
// 1, 2 // 1, 2
@ -146,7 +172,7 @@ TEST_F(ResizeNearestNeighborOpAlignCornersTest,
test::ExpectTensorEqual<float>(expected, *GetOutput(0)); test::ExpectTensorEqual<float>(expected, *GetOutput(0));
} }
TEST_F(ResizeNearestNeighborOpTest, TestNearest3x3To2x2) { TEST_P(ResizeNearestNeighborOpTest, TestNearest3x3To2x2) {
// Input: // Input:
// 1, 2, 3 // 1, 2, 3
// 4, 5, 6 // 4, 5, 6
@ -167,7 +193,7 @@ TEST_F(ResizeNearestNeighborOpTest, TestNearest3x3To2x2) {
test::ExpectTensorEqual<float>(expected, *GetOutput(0)); test::ExpectTensorEqual<float>(expected, *GetOutput(0));
} }
TEST_F(ResizeNearestNeighborOpAlignCornersTest, TEST_P(ResizeNearestNeighborOpAlignCornersTest,
TestNearestAlignCorners3x3To2x2) { TestNearestAlignCorners3x3To2x2) {
// Input: // Input:
// 1, 2, 3 // 1, 2, 3
@ -189,7 +215,7 @@ TEST_F(ResizeNearestNeighborOpAlignCornersTest,
test::ExpectTensorEqual<float>(expected, *GetOutput(0)); test::ExpectTensorEqual<float>(expected, *GetOutput(0));
} }
TEST_F(ResizeNearestNeighborOpTest, TestNearest2x2To2x5) { TEST_P(ResizeNearestNeighborOpTest, TestNearest2x2To2x5) {
// Input: // Input:
// 1, 2 // 1, 2
// 3, 4 // 3, 4
@ -208,7 +234,7 @@ TEST_F(ResizeNearestNeighborOpTest, TestNearest2x2To2x5) {
test::ExpectTensorEqual<float>(expected, *GetOutput(0)); test::ExpectTensorEqual<float>(expected, *GetOutput(0));
} }
TEST_F(ResizeNearestNeighborOpTest, TestNearestNeighbor4x4To3x3) { TEST_P(ResizeNearestNeighborOpTest, TestNearestNeighbor4x4To3x3) {
// Input: // Input:
// 1, 2, 3, 4 // 1, 2, 3, 4
// 5, 6, 7, 8 // 5, 6, 7, 8
@ -232,7 +258,7 @@ TEST_F(ResizeNearestNeighborOpTest, TestNearestNeighbor4x4To3x3) {
test::ExpectTensorEqual<float>(expected, *GetOutput(0)); test::ExpectTensorEqual<float>(expected, *GetOutput(0));
} }
TEST_F(ResizeNearestNeighborOpAlignCornersTest, TEST_P(ResizeNearestNeighborOpAlignCornersTest,
TestNearestNeighborAlignCorners4x4To3x3) { TestNearestNeighborAlignCorners4x4To3x3) {
// Input: // Input:
// 1, 2, 3, 4 // 1, 2, 3, 4
@ -257,7 +283,7 @@ TEST_F(ResizeNearestNeighborOpAlignCornersTest,
test::ExpectTensorEqual<float>(expected, *GetOutput(0)); test::ExpectTensorEqual<float>(expected, *GetOutput(0));
} }
TEST_F(ResizeNearestNeighborOpTest, TestNearest2x2To5x2) { TEST_P(ResizeNearestNeighborOpTest, TestNearest2x2To5x2) {
// Input: // Input:
// 1, 2 // 1, 2
// 3, 4 // 3, 4
@ -279,7 +305,7 @@ TEST_F(ResizeNearestNeighborOpTest, TestNearest2x2To5x2) {
test::ExpectTensorEqual<float>(expected, *GetOutput(0)); test::ExpectTensorEqual<float>(expected, *GetOutput(0));
} }
TEST_F(ResizeNearestNeighborOpTest, TestNearest2x2To4x4) { TEST_P(ResizeNearestNeighborOpTest, TestNearest2x2To4x4) {
// Input: // Input:
// 1, 2 // 1, 2
// 3, 4 // 3, 4
@ -300,7 +326,7 @@ TEST_F(ResizeNearestNeighborOpTest, TestNearest2x2To4x4) {
test::ExpectTensorEqual<float>(expected, *GetOutput(0)); test::ExpectTensorEqual<float>(expected, *GetOutput(0));
} }
TEST_F(ResizeNearestNeighborOpTest, TestNearest2x2x2x2To2x3x3x2) { TEST_P(ResizeNearestNeighborOpTest, TestNearest2x2x2x2To2x3x3x2) {
// Input: // Input:
// [ [ 1, 1 ], [ 2, 2], // [ [ 1, 1 ], [ 2, 2],
// [ 3, 3 ], [ 4, 4] ], // [ 3, 3 ], [ 4, 4] ],
@ -332,7 +358,7 @@ TEST_F(ResizeNearestNeighborOpTest, TestNearest2x2x2x2To2x3x3x2) {
test::ExpectTensorEqual<float>(expected, *GetOutput(0)); test::ExpectTensorEqual<float>(expected, *GetOutput(0));
} }
TEST_F(ResizeNearestNeighborHalfPixelCentersOpTest, TestNearest5x2To2x2) { TEST_P(ResizeNearestNeighborHalfPixelCentersOpTest, TestNearest5x2To2x2) {
// Input: // Input:
// 1, 2 // 1, 2
// 3, 4 // 3, 4
@ -350,7 +376,7 @@ TEST_F(ResizeNearestNeighborHalfPixelCentersOpTest, TestNearest5x2To2x2) {
test::ExpectTensorEqual<float>(expected, *GetOutput(0)); test::ExpectTensorEqual<float>(expected, *GetOutput(0));
} }
TEST_F(ResizeNearestNeighborHalfPixelCentersOpTest, TestNearest2x2To1x1) { TEST_P(ResizeNearestNeighborHalfPixelCentersOpTest, TestNearest2x2To1x1) {
// Input: // Input:
// 1, 2 // 1, 2
// 3, 4 // 3, 4
@ -367,7 +393,7 @@ TEST_F(ResizeNearestNeighborHalfPixelCentersOpTest, TestNearest2x2To1x1) {
test::ExpectTensorEqual<float>(expected, *GetOutput(0)); test::ExpectTensorEqual<float>(expected, *GetOutput(0));
} }
TEST_F(ResizeNearestNeighborHalfPixelCentersOpTest, TestNearest2x2To3x3) { TEST_P(ResizeNearestNeighborHalfPixelCentersOpTest, TestNearest2x2To3x3) {
// Input: // Input:
// 1, 2 // 1, 2
// 3, 4 // 3, 4
@ -387,7 +413,7 @@ TEST_F(ResizeNearestNeighborHalfPixelCentersOpTest, TestNearest2x2To3x3) {
test::ExpectTensorEqual<float>(expected, *GetOutput(0)); test::ExpectTensorEqual<float>(expected, *GetOutput(0));
} }
TEST_F(ResizeNearestNeighborHalfPixelCentersOpTest, TestNearest3x3To2x2) { TEST_P(ResizeNearestNeighborHalfPixelCentersOpTest, TestNearest3x3To2x2) {
// Input: // Input:
// 1, 2, 3 // 1, 2, 3
// 4, 5, 6 // 4, 5, 6
@ -408,7 +434,7 @@ TEST_F(ResizeNearestNeighborHalfPixelCentersOpTest, TestNearest3x3To2x2) {
test::ExpectTensorEqual<float>(expected, *GetOutput(0)); test::ExpectTensorEqual<float>(expected, *GetOutput(0));
} }
TEST_F(ResizeNearestNeighborHalfPixelCentersOpTest, TestNearest2x2To2x5) { TEST_P(ResizeNearestNeighborHalfPixelCentersOpTest, TestNearest2x2To2x5) {
// Input: // Input:
// 1, 2 // 1, 2
// 3, 4 // 3, 4
@ -427,7 +453,7 @@ TEST_F(ResizeNearestNeighborHalfPixelCentersOpTest, TestNearest2x2To2x5) {
test::ExpectTensorEqual<float>(expected, *GetOutput(0)); test::ExpectTensorEqual<float>(expected, *GetOutput(0));
} }
TEST_F(ResizeNearestNeighborHalfPixelCentersOpTest, TEST_P(ResizeNearestNeighborHalfPixelCentersOpTest,
TestNearestNeighbor4x4To3x3) { TestNearestNeighbor4x4To3x3) {
// Input: // Input:
// 1, 2, 3, 4 // 1, 2, 3, 4
@ -452,7 +478,7 @@ TEST_F(ResizeNearestNeighborHalfPixelCentersOpTest,
test::ExpectTensorEqual<float>(expected, *GetOutput(0)); test::ExpectTensorEqual<float>(expected, *GetOutput(0));
} }
TEST_F(ResizeNearestNeighborHalfPixelCentersOpTest, TestNearest2x2To5x2) { TEST_P(ResizeNearestNeighborHalfPixelCentersOpTest, TestNearest2x2To5x2) {
// Input: // Input:
// 1, 2 // 1, 2
// 3, 4 // 3, 4
@ -474,7 +500,7 @@ TEST_F(ResizeNearestNeighborHalfPixelCentersOpTest, TestNearest2x2To5x2) {
test::ExpectTensorEqual<float>(expected, *GetOutput(0)); test::ExpectTensorEqual<float>(expected, *GetOutput(0));
} }
TEST_F(ResizeNearestNeighborHalfPixelCentersOpTest, TestNearest2x2To4x4) { TEST_P(ResizeNearestNeighborHalfPixelCentersOpTest, TestNearest2x2To4x4) {
// Input: // Input:
// 1, 2 // 1, 2
// 3, 4 // 3, 4
@ -495,7 +521,7 @@ TEST_F(ResizeNearestNeighborHalfPixelCentersOpTest, TestNearest2x2To4x4) {
test::ExpectTensorEqual<float>(expected, *GetOutput(0)); test::ExpectTensorEqual<float>(expected, *GetOutput(0));
} }
TEST_F(ResizeNearestNeighborHalfPixelCentersOpTest, TEST_P(ResizeNearestNeighborHalfPixelCentersOpTest,
TestNearest2x2x2x2To2x3x3x2) { TestNearest2x2x2x2To2x3x3x2) {
// Input: // Input:
// [ [ 1, 1 ], [ 2, 2], // [ [ 1, 1 ], [ 2, 2],
@ -521,4 +547,26 @@ TEST_F(ResizeNearestNeighborHalfPixelCentersOpTest,
// clang-format on // clang-format on
test::ExpectTensorEqual<float>(expected, *GetOutput(0)); test::ExpectTensorEqual<float>(expected, *GetOutput(0));
} }
INSTANTIATE_TEST_SUITE_P(ResizeNearestNeighborOpTestCpu,
ResizeNearestNeighborOpTest,
::testing::Values(TestDevice::kCPU));
INSTANTIATE_TEST_SUITE_P(ResizeNearestNeighborHalfPixelCentersOpTestCpu,
ResizeNearestNeighborHalfPixelCentersOpTest,
::testing::Values(TestDevice::kCPU));
INSTANTIATE_TEST_SUITE_P(ResizeNearestNeighborOpAlignCornersTestCpu,
ResizeNearestNeighborOpAlignCornersTest,
::testing::Values(TestDevice::kCPU));
#if GOOGLE_CUDA
// Instantiate tests for kGPU.
INSTANTIATE_TEST_SUITE_P(ResizeNearestNeighborOpTestGpu,
ResizeNearestNeighborOpTest,
::testing::Values(TestDevice::kGPU));
INSTANTIATE_TEST_SUITE_P(ResizeNearestNeighborHalfPixelCentersOpTestGpu,
ResizeNearestNeighborHalfPixelCentersOpTest,
::testing::Values(TestDevice::kGPU));
INSTANTIATE_TEST_SUITE_P(ResizeNearestNeighborOpAlignCornersTestGpu,
ResizeNearestNeighborOpAlignCornersTest,
::testing::Values(TestDevice::kGPU));
#endif // GOOGLE_CUDA
} // namespace tensorflow } // namespace tensorflow