diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 7e6e1e9bc96..d9b4d67ae0b 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -1680,6 +1680,7 @@ tf_cc_tests( "colorspace_op_test.cc", "crop_and_resize_op_test.cc", "non_max_suppression_op_test.cc", + "resize_area_op_test.cc", "resize_bicubic_op_test.cc", "resize_bilinear_op_test.cc", "resize_nearest_neighbor_op_test.cc", diff --git a/tensorflow/core/kernels/resize_area_op.cc b/tensorflow/core/kernels/resize_area_op.cc index cb653a05717..ad94de89dba 100644 --- a/tensorflow/core/kernels/resize_area_op.cc +++ b/tensorflow/core/kernels/resize_area_op.cc @@ -32,6 +32,17 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; +namespace { + +struct CachedInterpolation { + int64 start; + int64 end; + float start_scale; + float end_minus_one_scale; + bool needs_bounding; +}; +}; + template class ResizeAreaOp : public OpKernel { public: @@ -39,6 +50,99 @@ class ResizeAreaOp : public OpKernel { OP_REQUIRES_OK(context, context->GetAttr("align_corners", &align_corners_)); } + // Computes the sum of all x values defined by taken across + // the y offsets and scales defined by y_ptrs and y_scales, for channel c. + // + // Note that is a template parameter to avoid a performance + // penalty from dynamically checking it. + template + static void ComputePatchSumOf3Channels(float scale, + const ImageResizerState& st, + const std::vector& y_ptrs, + const std::vector& y_scales, + const CachedInterpolation& x_interp, + float* output_ptr) { +#define BOUND_IF_NEEDED(x, y) (NeedsXBounding ? Bound(x, y) : (x)) + + float sum_0 = 0; + float sum_1 = 0; + float sum_2 = 0; + for (int i = 0; i < y_ptrs.size(); ++i) { + const T* ptr = y_ptrs[i]; + float scale_x = x_interp.start_scale; + int64 offset = 3 * BOUND_IF_NEEDED(x_interp.start, st.in_width); + float sum_y_0 = static_cast(ptr[offset + 0]) * scale_x; + float sum_y_1 = static_cast(ptr[offset + 1]) * scale_x; + float sum_y_2 = static_cast(ptr[offset + 2]) * scale_x; + + if (x_interp.start + 1 != x_interp.end) { + for (int64 x = x_interp.start + 1; x < x_interp.end - 1; ++x) { + int64 offset = 3 * BOUND_IF_NEEDED(x, st.in_width); + sum_y_0 += static_cast(ptr[offset + 0]); + sum_y_1 += static_cast(ptr[offset + 1]); + sum_y_2 += static_cast(ptr[offset + 2]); + } + scale_x = x_interp.end_minus_one_scale; + offset = 3 * BOUND_IF_NEEDED(x_interp.end - 1, st.in_width); + sum_y_0 += static_cast(ptr[offset + 0]) * scale_x; + sum_y_1 += static_cast(ptr[offset + 1]) * scale_x; + sum_y_2 += static_cast(ptr[offset + 2]) * scale_x; + } + sum_0 += sum_y_0 * y_scales[i]; + sum_1 += sum_y_1 * y_scales[i]; + sum_2 += sum_y_2 * y_scales[i]; + } + + output_ptr[0] = sum_0 * scale; + output_ptr[1] = sum_1 * scale; + output_ptr[2] = sum_2 * scale; + +#undef BOUND_IF_NEEDED + } + + // Computes the sum of all x values defined by taken across + // the y offsets and scales defined by y_ptrs and y_scales, for channel c. + // + // Note that is a template parameter to avoid a performance + // penalty from dynamically checking it. + template + static void ComputePatchSum(float scale, const ImageResizerState& st, + const std::vector& y_ptrs, + const std::vector& y_scales, + const CachedInterpolation& x_interp, + float* output_ptr) { +#define BOUND_IF_NEEDED(x, y) (NeedsXBounding ? Bound(x, y) : (x)) + + const auto num_channels = st.channels; + for (int64 c = 0; c < num_channels; ++c) { + float sum = 0; + for (int i = 0; i < y_ptrs.size(); ++i) { + const T* ptr = y_ptrs[i]; + float scale_x = x_interp.start_scale; + float sum_y = static_cast( + ptr[num_channels * + BOUND_IF_NEEDED(x_interp.start, st.in_width) + + c]) * + scale_x; + if (x_interp.start + 1 != x_interp.end) { + for (int64 x = x_interp.start + 1; x < x_interp.end - 1; ++x) { + sum_y += static_cast( + ptr[num_channels * BOUND_IF_NEEDED(x, st.in_width) + c]); + } + scale_x = x_interp.end_minus_one_scale; + sum_y += static_cast( + ptr[num_channels * + BOUND_IF_NEEDED(x_interp.end - 1, st.in_width) + + c]) * + scale_x; + } + sum += sum_y * y_scales[i]; + } + output_ptr[c] = sum * scale; + } +#undef BOUND_IF_NEEDED + } + void Compute(OpKernelContext* context) override { const Tensor& input = context->input(0); ImageResizerState st(align_corners_); @@ -47,16 +151,49 @@ class ResizeAreaOp : public OpKernel { if (!context->status().ok()) return; typename TTypes::ConstTensor input_data = input.tensor(); + + // Precompute values used when iterating over x coordinates within a row. + // Note that it may be useful to cache x_interps for a given + // ImageResizerState. + std::vector x_interps(st.out_width); + for (int64 x = 0; x < st.out_width; ++x) { + auto& x_interp = x_interps[x]; + const float in_x = x * st.width_scale; + const float in_x1 = (x + 1) * st.width_scale; + // The start and end width indices of all the cells that could + // contribute to the target cell. + int64 v = floor(in_x); + x_interp.start = v; + // TODO(cwhipkey): simplify this logic. + x_interp.start_scale = + v < in_x ? (v + 1 > in_x1 ? st.width_scale : v + 1 - in_x) + : (v + 1 > in_x1 ? in_x1 - v : 1.0); + + v = ceil(in_x1); + x_interp.end = ceil(in_x1); + v = x_interp.end - 1; + x_interp.end_minus_one_scale = + v < in_x ? (v + 1 > in_x1 ? st.width_scale : v + 1 - in_x) + : (v + 1 > in_x1 ? in_x1 - v : 1.0); + x_interp.needs_bounding = + Bound(x_interp.start, st.in_width) != x_interp.start || + Bound(x_interp.end - 1, st.in_width) != (x_interp.end - 1); + } + + if (st.channels == 3) { + ComputeLoop<3>(st, x_interps, input_data); + } else { + ComputeLoop<-1>(st, x_interps, input_data); + } + } + + template + void ComputeLoop(const ImageResizerState& st, + const std::vector& x_interps, + typename TTypes::ConstTensor input_data) { typename TTypes::Tensor output_data = st.output->tensor(); - // A temporary tensor for computing the sum. - Tensor sum_tensor; - OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum::value, - TensorShape({st.channels}), - &sum_tensor)); - typename TTypes::Tensor sum_data = sum_tensor.vec(); - // When using this algorithm for downsizing, the target pixel value is the // weighted average of all the source pixels. The weight is determined by // the contribution percentage of the source pixel. @@ -76,44 +213,58 @@ class ResizeAreaOp : public OpKernel { // out[0] = (in[0] * 1.0 + in[1] * 1/3) * scale // out[1] = (in[1] * 2/3 + in[2] * 2/3 * scale // out[2] = (in[3] * 1/3 + in[3] * 1.0) * scale + const T* const input_ptr = input_data.data(); + std::vector y_scales; + std::vector y_ptrs; float scale = 1.0 / (st.height_scale * st.width_scale); + float* output_ptr = output_data.data(); for (int64 b = 0; b < st.batch_size; ++b) { for (int64 y = 0; y < st.out_height; ++y) { const float in_y = y * st.height_scale; const float in_y1 = (y + 1) * st.height_scale; // The start and end height indices of all the cells that could // contribute to the target cell. - int64 y_start = floor(in_y); - int64 y_end = ceil(in_y1); - - for (int64 x = 0; x < st.out_width; ++x) { - const float in_x = x * st.width_scale; - const float in_x1 = (x + 1) * st.width_scale; - // The start and end width indices of all the cells that could - // contribute to the target cell. - int64 x_start = floor(in_x); - int64 x_end = ceil(in_x1); - - sum_data.setConstant(0.0); - for (int64 i = y_start; i < y_end; ++i) { - float scale_y = - i < in_y ? (i + 1 > in_y1 ? st.height_scale : i + 1 - in_y) - : (i + 1 > in_y1 ? in_y1 - i : 1.0); - for (int64 j = x_start; j < x_end; ++j) { - float scale_x = - j < in_x ? (j + 1 > in_x1 ? st.width_scale : j + 1 - in_x) - : (j + 1 > in_x1 ? in_x1 - j : 1.0); - for (int64 c = 0; c < st.channels; ++c) { -#define BOUND(val, limit) std::min(((limit)-1ll), (std::max(0ll, (val)))) - sum_data(c) += float(input_data(b, BOUND(i, st.in_height), - BOUND(j, st.in_width), c)) * - scale_y * scale_x * scale; -#undef BOUND - } - } + const int64 y_start = floor(in_y); + const int64 y_end = ceil(in_y1); + y_scales.clear(); + y_ptrs.clear(); + for (int64 i = y_start; i < y_end; ++i) { + float scale_y; + if (i < in_y) { + scale_y = (i + 1 > in_y1 ? st.height_scale : i + 1 - in_y); + } else { + scale_y = (i + 1 > in_y1 ? in_y1 - i : 1.0); } - for (int64 c = 0; c < st.channels; ++c) { - output_data(b, y, x, c) = sum_data(c); + // TODO(cwhipkey): can this data unified with CachedInterpolation? + y_scales.push_back(scale_y); + y_ptrs.push_back( + input_ptr + (b * st.in_height * st.in_width * st.channels + + Bound(i, st.in_height) * st.in_width * st.channels)); + } + + if (kKnownNumChannels == 3) { + for (int64 x = 0; x < st.out_width; ++x) { + const CachedInterpolation& x_interp = x_interps[x]; + if (x_interp.needs_bounding) { + ComputePatchSumOf3Channels(scale, st, y_ptrs, y_scales, + x_interp, output_ptr); + } else { + ComputePatchSumOf3Channels(scale, st, y_ptrs, y_scales, + x_interp, output_ptr); + } + output_ptr += 3; + } + } else { + for (int64 x = 0; x < st.out_width; ++x) { + const CachedInterpolation& x_interp = x_interps[x]; + if (x_interp.needs_bounding) { + ComputePatchSum(scale, st, y_ptrs, y_scales, x_interp, + output_ptr); + } else { + ComputePatchSum(scale, st, y_ptrs, y_scales, x_interp, + output_ptr); + } + output_ptr += st.channels; } } } @@ -121,6 +272,10 @@ class ResizeAreaOp : public OpKernel { } private: + static EIGEN_ALWAYS_INLINE int64 Bound(int64 val, int64 limit) { + return std::min(limit - 1ll, std::max(0ll, val)); + } + bool align_corners_; }; diff --git a/tensorflow/core/kernels/resize_area_op_test.cc b/tensorflow/core/kernels/resize_area_op_test.cc new file mode 100644 index 00000000000..415bce3cce6 --- /dev/null +++ b/tensorflow/core/kernels/resize_area_op_test.cc @@ -0,0 +1,197 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" + +namespace tensorflow { + +class ResizeAreaOpTest : public OpsTestBase { + protected: + ResizeAreaOpTest() { + TF_EXPECT_OK(NodeDefBuilder("resize_area_op", "ResizeArea") + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_INT32)) + .Attr("align_corners", false) + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + } + + const Tensor* SetRandomImageInput(const TensorShape& shape) { + inputs_.clear(); + + CHECK_EQ(shape.dims(), 4) << "All images must have 4 dimensions."; + bool is_ref = IsRefType(input_types_[inputs_.size()]); + Tensor* input = new Tensor(device_->GetAllocator(AllocatorAttributes()), + DataTypeToEnum::v(), shape); + input->flat().setZero(); + tensors_.push_back(input); + if (is_ref) { + CHECK_EQ(RemoveRefType(input_types_[inputs_.size()]), + DataTypeToEnum::v()); + inputs_.push_back({&lock_for_refs_, input}); + } else { + CHECK_EQ(input_types_[inputs_.size()], DataTypeToEnum::v()); + inputs_.push_back({nullptr, input}); + } + return input; + } + + private: + // This is the unoptimized implementation of ResizeArea. + // We use this to confirm that the optimized version is exactly identical. + void ResizeAreaBaseline(TTypes::ConstTensor input_data, + TTypes::Tensor output_data) { + const int batch_size = input_data.dimension(0); + const int64 in_height = input_data.dimension(1); + const int64 in_width = input_data.dimension(2); + const int channels = input_data.dimension(3); + + ASSERT_EQ(batch_size, output_data.dimension(0)); + ASSERT_EQ(channels, output_data.dimension(3)); + + const int64 out_height = output_data.dimension(1); + const int64 out_width = output_data.dimension(2); + + const float height_scale = in_height / static_cast(out_height); + const float width_scale = in_width / static_cast(out_width); + + // A temporary tensor for computing the sum. + Tensor sum_tensor(DT_FLOAT, TensorShape({channels})); + typename TTypes::Tensor sum_data = sum_tensor.vec(); + + // When using this algorithm for downsizing, the target pixel value is the + // weighted average of all the source pixels. The weight is determined by + // the contribution percentage of the source pixel. + // + // Let "scale" be "target_image_size/source_image_size". If 1/n of the + // source pixel contributes to the target pixel, then the weight is (1/n * + // scale); if the complete source pixel contributes to the target pixel, + // then the weight is scale. + // + // To visualize the implementation, use one dimension as an example: + // Resize in[4] to out[3]. + // scale = 3/4 = 0.75 + // out[0]: in[0] and 1/3 of in[1] + // out[1]: 2/3 of in[1] and 2/3 of in[2] + // out[2]: 1/3 of in[2] and in[1] + // Hence, the output pixel values are: + // out[0] = (in[0] * 1.0 + in[1] * 1/3) * scale + // out[1] = (in[1] * 2/3 + in[2] * 2/3 * scale + // out[2] = (in[3] * 1/3 + in[3] * 1.0) * scale + float scale = 1.0 / (height_scale * width_scale); + for (int64 b = 0; b < batch_size; ++b) { + for (int64 y = 0; y < out_height; ++y) { + const float in_y = y * height_scale; + const float in_y1 = (y + 1) * height_scale; + // The start and end height indices of all the cells that could + // contribute to the target cell. + int64 y_start = floor(in_y); + int64 y_end = ceil(in_y1); + + for (int64 x = 0; x < out_width; ++x) { + const float in_x = x * width_scale; + const float in_x1 = (x + 1) * width_scale; + // The start and end width indices of all the cells that could + // contribute to the target cell. + int64 x_start = floor(in_x); + int64 x_end = ceil(in_x1); + + sum_data.setConstant(0.0); + for (int64 i = y_start; i < y_end; ++i) { + float scale_y = i < in_y + ? (i + 1 > in_y1 ? height_scale : i + 1 - in_y) + : (i + 1 > in_y1 ? in_y1 - i : 1.0); + for (int64 j = x_start; j < x_end; ++j) { + float scale_x = j < in_x + ? (j + 1 > in_x1 ? width_scale : j + 1 - in_x) + : (j + 1 > in_x1 ? in_x1 - j : 1.0); + for (int64 c = 0; c < channels; ++c) { +#define BOUND(val, limit) std::min(((limit)-1ll), (std::max(0ll, (val)))) + sum_data(c) += + static_cast(input_data(b, BOUND(i, in_height), + BOUND(j, in_width), c)) * + scale_y * scale_x * scale; +#undef BOUND + } + } + } + for (int64 c = 0; c < channels; ++c) { + output_data(b, y, x, c) = sum_data(c); + } + } + } + } + } + + protected: + void RunRandomTest(int in_height, int in_width, int target_height, + int target_width, int channels) { + const Tensor* input = + SetRandomImageInput(TensorShape({1, in_height, in_width, channels})); + AddInputFromArray(TensorShape({2}), {target_height, target_width}); + + TF_ASSERT_OK(RunOpKernel()); + std::unique_ptr expected( + new Tensor(device_->GetAllocator(AllocatorAttributes()), + DataTypeToEnum::v(), + TensorShape({1, target_height, target_width, channels}))); + ResizeAreaBaseline(input->tensor(), expected->tensor()); + test::ExpectTensorNear(*expected, *GetOutput(0), 0.00001); + } + + void RunManyRandomTests(int channels) { + for (int in_w : {2, 4, 7, 20, 165}) { + for (int in_h : {1, 3, 5, 8, 100, 233}) { + for (int target_height : {1, 2, 3, 50, 113}) { + for (int target_width : {target_height, target_height / 2 + 1}) { + RunRandomTest(in_w, in_h, target_height, target_width, channels); + } + } + } + } + } +}; + +TEST_F(ResizeAreaOpTest, TestAreaRandom141x186) { + RunRandomTest(141, 186, 299, 299, 3 /* channels */); +} + +TEST_F(ResizeAreaOpTest, TestAreaRandom183x229) { + RunRandomTest(183, 229, 299, 299, 3 /* channels */); +} + +TEST_F(ResizeAreaOpTest, TestAreaRandom749x603) { + RunRandomTest(749, 603, 299, 299, 3 /* channels */); +} + +TEST_F(ResizeAreaOpTest, TestAreaRandomDataSeveralInputsSizes1Channel) { + RunManyRandomTests(1); +} + +TEST_F(ResizeAreaOpTest, TestAreaRandomDataSeveralInputsSizes3Channels) { + RunManyRandomTests(3); +} + +TEST_F(ResizeAreaOpTest, TestAreaRandomDataSeveralInputsSizes4Channels) { + RunManyRandomTests(4); +} + +} // namespace tensorflow diff --git a/tensorflow/python/ops/image_ops_test.py b/tensorflow/python/ops/image_ops_test.py index c611271b6c1..8148af7a648 100644 --- a/tensorflow/python/ops/image_ops_test.py +++ b/tensorflow/python/ops/image_ops_test.py @@ -470,9 +470,7 @@ class ResizeBilinearBenchmark(test.Benchmark): print('Variables initalized for resize_bilinear image size: %s.' % (image_size,)) benchmark_values = self.run_op_benchmark( - sess, - benchmark_op, - name=('bilinear_%s_%s' % image_size),) + sess, benchmark_op, name=('bilinear_%s_%s' % image_size)) print('Benchmark values:\n%s' % benchmark_values) def benchmarkSimilar(self): @@ -506,9 +504,7 @@ class ResizeBicubicBenchmark(test.Benchmark): print('Variables initalized for resize_bicubic image size: %s.' % (image_size,)) benchmark_values = self.run_op_benchmark( - sess, - benchmark_op, - name=('bicubic_%s_%s' % image_size),) + sess, benchmark_op, name=('bicubic_%s_%s' % image_size)) print('Benchmark values:\n%s' % benchmark_values) def benchmarkSimilar(self): @@ -521,6 +517,52 @@ class ResizeBicubicBenchmark(test.Benchmark): self._benchmarkResize((749, 603)) +class ResizeAreaBenchmark(test.Benchmark): + + def _benchmarkResize(self, image_size, num_channels): + batch_size = 1 + num_ops = 1000 + img = variables.Variable( + random_ops.random_normal([batch_size, image_size[0], + image_size[1], num_channels]), + name='img') + + deps = [] + for _ in xrange(num_ops): + with ops.control_dependencies(deps): + resize_op = image_ops.resize_area(img, [299, 299], align_corners=False) + deps = [resize_op] + benchmark_op = control_flow_ops.group(*deps) + + with session.Session() as sess: + sess.run(variables.global_variables_initializer()) + results = self.run_op_benchmark( + sess, benchmark_op, + name=('resize_area_%s_%s_%s' % + (image_size[0], image_size[1], num_channels))) + print('%s : %.2f ms/img' % ( + results['name'], + 1000*results['wall_time'] / (batch_size * num_ops))) + + def benchmarkSimilar3Channel(self): + self._benchmarkResize((183, 229), 3) + + def benchmarkScaleUp3Channel(self): + self._benchmarkResize((141, 186), 3) + + def benchmarkScaleDown3Channel(self): + self._benchmarkResize((749, 603), 3) + + def benchmarkSimilar1Channel(self): + self._benchmarkResize((183, 229), 1) + + def benchmarkScaleUp1Channel(self): + self._benchmarkResize((141, 186), 1) + + def benchmarkScaleDown1Channel(self): + self._benchmarkResize((749, 603), 1) + + class AdjustSaturationTest(test_util.TensorFlowTestCase): def testHalfSaturation(self):