Improve performance of ResizeArea op:

- cache computed values, hoist computations out of loops
- avoid bounds checks in many cases.
- access input data with pointer offsets instead of through 4-d eigen
  tensor (which requires more complicated index computations).
- add custom accumulation fn for 3-channel images

Added tests to resize_area_op_test.cc, and benchmarks to image_ops_test.py.
Change: 146177262
This commit is contained in:
A. Unique TensorFlower 2017-01-31 15:53:38 -08:00 committed by TensorFlower Gardener
parent fd7d78ddf1
commit e2127701a5
4 changed files with 438 additions and 43 deletions

View File

@ -1680,6 +1680,7 @@ tf_cc_tests(
"colorspace_op_test.cc",
"crop_and_resize_op_test.cc",
"non_max_suppression_op_test.cc",
"resize_area_op_test.cc",
"resize_bicubic_op_test.cc",
"resize_bilinear_op_test.cc",
"resize_nearest_neighbor_op_test.cc",

View File

@ -32,6 +32,17 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
namespace {
struct CachedInterpolation {
int64 start;
int64 end;
float start_scale;
float end_minus_one_scale;
bool needs_bounding;
};
};
template <typename Device, typename T>
class ResizeAreaOp : public OpKernel {
public:
@ -39,6 +50,99 @@ class ResizeAreaOp : public OpKernel {
OP_REQUIRES_OK(context, context->GetAttr("align_corners", &align_corners_));
}
// Computes the sum of all x values defined by <x_interp> taken across
// the y offsets and scales defined by y_ptrs and y_scales, for channel c.
//
// Note that <NeedsXBounding> is a template parameter to avoid a performance
// penalty from dynamically checking it.
template <bool NeedsXBounding>
static void ComputePatchSumOf3Channels(float scale,
const ImageResizerState& st,
const std::vector<const T*>& y_ptrs,
const std::vector<float>& y_scales,
const CachedInterpolation& x_interp,
float* output_ptr) {
#define BOUND_IF_NEEDED(x, y) (NeedsXBounding ? Bound(x, y) : (x))
float sum_0 = 0;
float sum_1 = 0;
float sum_2 = 0;
for (int i = 0; i < y_ptrs.size(); ++i) {
const T* ptr = y_ptrs[i];
float scale_x = x_interp.start_scale;
int64 offset = 3 * BOUND_IF_NEEDED(x_interp.start, st.in_width);
float sum_y_0 = static_cast<float>(ptr[offset + 0]) * scale_x;
float sum_y_1 = static_cast<float>(ptr[offset + 1]) * scale_x;
float sum_y_2 = static_cast<float>(ptr[offset + 2]) * scale_x;
if (x_interp.start + 1 != x_interp.end) {
for (int64 x = x_interp.start + 1; x < x_interp.end - 1; ++x) {
int64 offset = 3 * BOUND_IF_NEEDED(x, st.in_width);
sum_y_0 += static_cast<float>(ptr[offset + 0]);
sum_y_1 += static_cast<float>(ptr[offset + 1]);
sum_y_2 += static_cast<float>(ptr[offset + 2]);
}
scale_x = x_interp.end_minus_one_scale;
offset = 3 * BOUND_IF_NEEDED(x_interp.end - 1, st.in_width);
sum_y_0 += static_cast<float>(ptr[offset + 0]) * scale_x;
sum_y_1 += static_cast<float>(ptr[offset + 1]) * scale_x;
sum_y_2 += static_cast<float>(ptr[offset + 2]) * scale_x;
}
sum_0 += sum_y_0 * y_scales[i];
sum_1 += sum_y_1 * y_scales[i];
sum_2 += sum_y_2 * y_scales[i];
}
output_ptr[0] = sum_0 * scale;
output_ptr[1] = sum_1 * scale;
output_ptr[2] = sum_2 * scale;
#undef BOUND_IF_NEEDED
}
// Computes the sum of all x values defined by <x_interp> taken across
// the y offsets and scales defined by y_ptrs and y_scales, for channel c.
//
// Note that <NeedsXBounding> is a template parameter to avoid a performance
// penalty from dynamically checking it.
template <bool NeedsXBounding>
static void ComputePatchSum(float scale, const ImageResizerState& st,
const std::vector<const T*>& y_ptrs,
const std::vector<float>& y_scales,
const CachedInterpolation& x_interp,
float* output_ptr) {
#define BOUND_IF_NEEDED(x, y) (NeedsXBounding ? Bound(x, y) : (x))
const auto num_channels = st.channels;
for (int64 c = 0; c < num_channels; ++c) {
float sum = 0;
for (int i = 0; i < y_ptrs.size(); ++i) {
const T* ptr = y_ptrs[i];
float scale_x = x_interp.start_scale;
float sum_y = static_cast<float>(
ptr[num_channels *
BOUND_IF_NEEDED(x_interp.start, st.in_width) +
c]) *
scale_x;
if (x_interp.start + 1 != x_interp.end) {
for (int64 x = x_interp.start + 1; x < x_interp.end - 1; ++x) {
sum_y += static_cast<float>(
ptr[num_channels * BOUND_IF_NEEDED(x, st.in_width) + c]);
}
scale_x = x_interp.end_minus_one_scale;
sum_y += static_cast<float>(
ptr[num_channels *
BOUND_IF_NEEDED(x_interp.end - 1, st.in_width) +
c]) *
scale_x;
}
sum += sum_y * y_scales[i];
}
output_ptr[c] = sum * scale;
}
#undef BOUND_IF_NEEDED
}
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
ImageResizerState st(align_corners_);
@ -47,16 +151,49 @@ class ResizeAreaOp : public OpKernel {
if (!context->status().ok()) return;
typename TTypes<T, 4>::ConstTensor input_data = input.tensor<T, 4>();
// Precompute values used when iterating over x coordinates within a row.
// Note that it may be useful to cache x_interps for a given
// ImageResizerState.
std::vector<CachedInterpolation> x_interps(st.out_width);
for (int64 x = 0; x < st.out_width; ++x) {
auto& x_interp = x_interps[x];
const float in_x = x * st.width_scale;
const float in_x1 = (x + 1) * st.width_scale;
// The start and end width indices of all the cells that could
// contribute to the target cell.
int64 v = floor(in_x);
x_interp.start = v;
// TODO(cwhipkey): simplify this logic.
x_interp.start_scale =
v < in_x ? (v + 1 > in_x1 ? st.width_scale : v + 1 - in_x)
: (v + 1 > in_x1 ? in_x1 - v : 1.0);
v = ceil(in_x1);
x_interp.end = ceil(in_x1);
v = x_interp.end - 1;
x_interp.end_minus_one_scale =
v < in_x ? (v + 1 > in_x1 ? st.width_scale : v + 1 - in_x)
: (v + 1 > in_x1 ? in_x1 - v : 1.0);
x_interp.needs_bounding =
Bound(x_interp.start, st.in_width) != x_interp.start ||
Bound(x_interp.end - 1, st.in_width) != (x_interp.end - 1);
}
if (st.channels == 3) {
ComputeLoop<3>(st, x_interps, input_data);
} else {
ComputeLoop<-1>(st, x_interps, input_data);
}
}
template <int64 kKnownNumChannels>
void ComputeLoop(const ImageResizerState& st,
const std::vector<CachedInterpolation>& x_interps,
typename TTypes<T, 4>::ConstTensor input_data) {
typename TTypes<float, 4>::Tensor output_data =
st.output->tensor<float, 4>();
// A temporary tensor for computing the sum.
Tensor sum_tensor;
OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<float>::value,
TensorShape({st.channels}),
&sum_tensor));
typename TTypes<float, 1>::Tensor sum_data = sum_tensor.vec<float>();
// When using this algorithm for downsizing, the target pixel value is the
// weighted average of all the source pixels. The weight is determined by
// the contribution percentage of the source pixel.
@ -76,44 +213,58 @@ class ResizeAreaOp : public OpKernel {
// out[0] = (in[0] * 1.0 + in[1] * 1/3) * scale
// out[1] = (in[1] * 2/3 + in[2] * 2/3 * scale
// out[2] = (in[3] * 1/3 + in[3] * 1.0) * scale
const T* const input_ptr = input_data.data();
std::vector<float> y_scales;
std::vector<const T*> y_ptrs;
float scale = 1.0 / (st.height_scale * st.width_scale);
float* output_ptr = output_data.data();
for (int64 b = 0; b < st.batch_size; ++b) {
for (int64 y = 0; y < st.out_height; ++y) {
const float in_y = y * st.height_scale;
const float in_y1 = (y + 1) * st.height_scale;
// The start and end height indices of all the cells that could
// contribute to the target cell.
int64 y_start = floor(in_y);
int64 y_end = ceil(in_y1);
for (int64 x = 0; x < st.out_width; ++x) {
const float in_x = x * st.width_scale;
const float in_x1 = (x + 1) * st.width_scale;
// The start and end width indices of all the cells that could
// contribute to the target cell.
int64 x_start = floor(in_x);
int64 x_end = ceil(in_x1);
sum_data.setConstant(0.0);
const int64 y_start = floor(in_y);
const int64 y_end = ceil(in_y1);
y_scales.clear();
y_ptrs.clear();
for (int64 i = y_start; i < y_end; ++i) {
float scale_y =
i < in_y ? (i + 1 > in_y1 ? st.height_scale : i + 1 - in_y)
: (i + 1 > in_y1 ? in_y1 - i : 1.0);
for (int64 j = x_start; j < x_end; ++j) {
float scale_x =
j < in_x ? (j + 1 > in_x1 ? st.width_scale : j + 1 - in_x)
: (j + 1 > in_x1 ? in_x1 - j : 1.0);
for (int64 c = 0; c < st.channels; ++c) {
#define BOUND(val, limit) std::min(((limit)-1ll), (std::max(0ll, (val))))
sum_data(c) += float(input_data(b, BOUND(i, st.in_height),
BOUND(j, st.in_width), c)) *
scale_y * scale_x * scale;
#undef BOUND
float scale_y;
if (i < in_y) {
scale_y = (i + 1 > in_y1 ? st.height_scale : i + 1 - in_y);
} else {
scale_y = (i + 1 > in_y1 ? in_y1 - i : 1.0);
}
// TODO(cwhipkey): can this data unified with CachedInterpolation?
y_scales.push_back(scale_y);
y_ptrs.push_back(
input_ptr + (b * st.in_height * st.in_width * st.channels +
Bound(i, st.in_height) * st.in_width * st.channels));
}
if (kKnownNumChannels == 3) {
for (int64 x = 0; x < st.out_width; ++x) {
const CachedInterpolation& x_interp = x_interps[x];
if (x_interp.needs_bounding) {
ComputePatchSumOf3Channels<true>(scale, st, y_ptrs, y_scales,
x_interp, output_ptr);
} else {
ComputePatchSumOf3Channels<false>(scale, st, y_ptrs, y_scales,
x_interp, output_ptr);
}
for (int64 c = 0; c < st.channels; ++c) {
output_data(b, y, x, c) = sum_data(c);
output_ptr += 3;
}
} else {
for (int64 x = 0; x < st.out_width; ++x) {
const CachedInterpolation& x_interp = x_interps[x];
if (x_interp.needs_bounding) {
ComputePatchSum<true>(scale, st, y_ptrs, y_scales, x_interp,
output_ptr);
} else {
ComputePatchSum<false>(scale, st, y_ptrs, y_scales, x_interp,
output_ptr);
}
output_ptr += st.channels;
}
}
}
@ -121,6 +272,10 @@ class ResizeAreaOp : public OpKernel {
}
private:
static EIGEN_ALWAYS_INLINE int64 Bound(int64 val, int64 limit) {
return std::min(limit - 1ll, std::max(0ll, val));
}
bool align_corners_;
};

View File

@ -0,0 +1,197 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
namespace tensorflow {
class ResizeAreaOpTest : public OpsTestBase {
protected:
ResizeAreaOpTest() {
TF_EXPECT_OK(NodeDefBuilder("resize_area_op", "ResizeArea")
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_INT32))
.Attr("align_corners", false)
.Finalize(node_def()));
TF_EXPECT_OK(InitOp());
}
const Tensor* SetRandomImageInput(const TensorShape& shape) {
inputs_.clear();
CHECK_EQ(shape.dims(), 4) << "All images must have 4 dimensions.";
bool is_ref = IsRefType(input_types_[inputs_.size()]);
Tensor* input = new Tensor(device_->GetAllocator(AllocatorAttributes()),
DataTypeToEnum<float>::v(), shape);
input->flat<float>().setZero();
tensors_.push_back(input);
if (is_ref) {
CHECK_EQ(RemoveRefType(input_types_[inputs_.size()]),
DataTypeToEnum<float>::v());
inputs_.push_back({&lock_for_refs_, input});
} else {
CHECK_EQ(input_types_[inputs_.size()], DataTypeToEnum<float>::v());
inputs_.push_back({nullptr, input});
}
return input;
}
private:
// This is the unoptimized implementation of ResizeArea.
// We use this to confirm that the optimized version is exactly identical.
void ResizeAreaBaseline(TTypes<float, 4>::ConstTensor input_data,
TTypes<float, 4>::Tensor output_data) {
const int batch_size = input_data.dimension(0);
const int64 in_height = input_data.dimension(1);
const int64 in_width = input_data.dimension(2);
const int channels = input_data.dimension(3);
ASSERT_EQ(batch_size, output_data.dimension(0));
ASSERT_EQ(channels, output_data.dimension(3));
const int64 out_height = output_data.dimension(1);
const int64 out_width = output_data.dimension(2);
const float height_scale = in_height / static_cast<float>(out_height);
const float width_scale = in_width / static_cast<float>(out_width);
// A temporary tensor for computing the sum.
Tensor sum_tensor(DT_FLOAT, TensorShape({channels}));
typename TTypes<float, 1>::Tensor sum_data = sum_tensor.vec<float>();
// When using this algorithm for downsizing, the target pixel value is the
// weighted average of all the source pixels. The weight is determined by
// the contribution percentage of the source pixel.
//
// Let "scale" be "target_image_size/source_image_size". If 1/n of the
// source pixel contributes to the target pixel, then the weight is (1/n *
// scale); if the complete source pixel contributes to the target pixel,
// then the weight is scale.
//
// To visualize the implementation, use one dimension as an example:
// Resize in[4] to out[3].
// scale = 3/4 = 0.75
// out[0]: in[0] and 1/3 of in[1]
// out[1]: 2/3 of in[1] and 2/3 of in[2]
// out[2]: 1/3 of in[2] and in[1]
// Hence, the output pixel values are:
// out[0] = (in[0] * 1.0 + in[1] * 1/3) * scale
// out[1] = (in[1] * 2/3 + in[2] * 2/3 * scale
// out[2] = (in[3] * 1/3 + in[3] * 1.0) * scale
float scale = 1.0 / (height_scale * width_scale);
for (int64 b = 0; b < batch_size; ++b) {
for (int64 y = 0; y < out_height; ++y) {
const float in_y = y * height_scale;
const float in_y1 = (y + 1) * height_scale;
// The start and end height indices of all the cells that could
// contribute to the target cell.
int64 y_start = floor(in_y);
int64 y_end = ceil(in_y1);
for (int64 x = 0; x < out_width; ++x) {
const float in_x = x * width_scale;
const float in_x1 = (x + 1) * width_scale;
// The start and end width indices of all the cells that could
// contribute to the target cell.
int64 x_start = floor(in_x);
int64 x_end = ceil(in_x1);
sum_data.setConstant(0.0);
for (int64 i = y_start; i < y_end; ++i) {
float scale_y = i < in_y
? (i + 1 > in_y1 ? height_scale : i + 1 - in_y)
: (i + 1 > in_y1 ? in_y1 - i : 1.0);
for (int64 j = x_start; j < x_end; ++j) {
float scale_x = j < in_x
? (j + 1 > in_x1 ? width_scale : j + 1 - in_x)
: (j + 1 > in_x1 ? in_x1 - j : 1.0);
for (int64 c = 0; c < channels; ++c) {
#define BOUND(val, limit) std::min(((limit)-1ll), (std::max(0ll, (val))))
sum_data(c) +=
static_cast<float>(input_data(b, BOUND(i, in_height),
BOUND(j, in_width), c)) *
scale_y * scale_x * scale;
#undef BOUND
}
}
}
for (int64 c = 0; c < channels; ++c) {
output_data(b, y, x, c) = sum_data(c);
}
}
}
}
}
protected:
void RunRandomTest(int in_height, int in_width, int target_height,
int target_width, int channels) {
const Tensor* input =
SetRandomImageInput(TensorShape({1, in_height, in_width, channels}));
AddInputFromArray<int32>(TensorShape({2}), {target_height, target_width});
TF_ASSERT_OK(RunOpKernel());
std::unique_ptr<Tensor> expected(
new Tensor(device_->GetAllocator(AllocatorAttributes()),
DataTypeToEnum<float>::v(),
TensorShape({1, target_height, target_width, channels})));
ResizeAreaBaseline(input->tensor<float, 4>(), expected->tensor<float, 4>());
test::ExpectTensorNear<float>(*expected, *GetOutput(0), 0.00001);
}
void RunManyRandomTests(int channels) {
for (int in_w : {2, 4, 7, 20, 165}) {
for (int in_h : {1, 3, 5, 8, 100, 233}) {
for (int target_height : {1, 2, 3, 50, 113}) {
for (int target_width : {target_height, target_height / 2 + 1}) {
RunRandomTest(in_w, in_h, target_height, target_width, channels);
}
}
}
}
}
};
TEST_F(ResizeAreaOpTest, TestAreaRandom141x186) {
RunRandomTest(141, 186, 299, 299, 3 /* channels */);
}
TEST_F(ResizeAreaOpTest, TestAreaRandom183x229) {
RunRandomTest(183, 229, 299, 299, 3 /* channels */);
}
TEST_F(ResizeAreaOpTest, TestAreaRandom749x603) {
RunRandomTest(749, 603, 299, 299, 3 /* channels */);
}
TEST_F(ResizeAreaOpTest, TestAreaRandomDataSeveralInputsSizes1Channel) {
RunManyRandomTests(1);
}
TEST_F(ResizeAreaOpTest, TestAreaRandomDataSeveralInputsSizes3Channels) {
RunManyRandomTests(3);
}
TEST_F(ResizeAreaOpTest, TestAreaRandomDataSeveralInputsSizes4Channels) {
RunManyRandomTests(4);
}
} // namespace tensorflow

View File

@ -470,9 +470,7 @@ class ResizeBilinearBenchmark(test.Benchmark):
print('Variables initalized for resize_bilinear image size: %s.' %
(image_size,))
benchmark_values = self.run_op_benchmark(
sess,
benchmark_op,
name=('bilinear_%s_%s' % image_size),)
sess, benchmark_op, name=('bilinear_%s_%s' % image_size))
print('Benchmark values:\n%s' % benchmark_values)
def benchmarkSimilar(self):
@ -506,9 +504,7 @@ class ResizeBicubicBenchmark(test.Benchmark):
print('Variables initalized for resize_bicubic image size: %s.' %
(image_size,))
benchmark_values = self.run_op_benchmark(
sess,
benchmark_op,
name=('bicubic_%s_%s' % image_size),)
sess, benchmark_op, name=('bicubic_%s_%s' % image_size))
print('Benchmark values:\n%s' % benchmark_values)
def benchmarkSimilar(self):
@ -521,6 +517,52 @@ class ResizeBicubicBenchmark(test.Benchmark):
self._benchmarkResize((749, 603))
class ResizeAreaBenchmark(test.Benchmark):
def _benchmarkResize(self, image_size, num_channels):
batch_size = 1
num_ops = 1000
img = variables.Variable(
random_ops.random_normal([batch_size, image_size[0],
image_size[1], num_channels]),
name='img')
deps = []
for _ in xrange(num_ops):
with ops.control_dependencies(deps):
resize_op = image_ops.resize_area(img, [299, 299], align_corners=False)
deps = [resize_op]
benchmark_op = control_flow_ops.group(*deps)
with session.Session() as sess:
sess.run(variables.global_variables_initializer())
results = self.run_op_benchmark(
sess, benchmark_op,
name=('resize_area_%s_%s_%s' %
(image_size[0], image_size[1], num_channels)))
print('%s : %.2f ms/img' % (
results['name'],
1000*results['wall_time'] / (batch_size * num_ops)))
def benchmarkSimilar3Channel(self):
self._benchmarkResize((183, 229), 3)
def benchmarkScaleUp3Channel(self):
self._benchmarkResize((141, 186), 3)
def benchmarkScaleDown3Channel(self):
self._benchmarkResize((749, 603), 3)
def benchmarkSimilar1Channel(self):
self._benchmarkResize((183, 229), 1)
def benchmarkScaleUp1Channel(self):
self._benchmarkResize((141, 186), 1)
def benchmarkScaleDown1Channel(self):
self._benchmarkResize((749, 603), 1)
class AdjustSaturationTest(test_util.TensorFlowTestCase):
def testHalfSaturation(self):