Improve performance of ResizeArea op:
- Cache computed values and hoist computations out of loops.
- Avoid bounds checks in many cases.
- Access input data with pointer offsets instead of through a 4-d Eigen
  tensor (which requires more complicated index computations).
- Add a custom accumulation function for 3-channel images.

Added tests to resize_area_op_test.cc, and benchmarks to image_ops_test.py.

Change: 146177262
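To illustrate the "avoid bounds checks" item: the kernel turns the per-pixel clamp decision into a compile-time template parameter, so the common fully-in-bounds case pays no per-element check. Below is a minimal standalone sketch of that pattern; the SumRow helper and sample data are illustrative, not code from this commit:

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Clamp an index into [0, limit - 1], as the op's Bound() helper does.
static inline int64_t Bound(int64_t val, int64_t limit) {
  return std::min(limit - 1, std::max(int64_t{0}, val));
}

// <NeedsBounding> is resolved at compile time, so the in-bounds
// instantiation contains no clamping code at all.
template <bool NeedsBounding>
float SumRow(const float* row, int64_t start, int64_t end, int64_t width) {
  float sum = 0.f;
  for (int64_t x = start; x < end; ++x) {
    sum += row[NeedsBounding ? Bound(x, width) : x];
  }
  return sum;
}

int main() {
  const float row[4] = {1.f, 2.f, 3.f, 4.f};
  // The caller decides once per patch which instantiation to use.
  const bool needs_bounding = false;  // the window [1, 3) is inside width 4
  const float s = needs_bounding ? SumRow<true>(row, 1, 3, 4)
                                 : SumRow<false>(row, 1, 3, 4);
  std::printf("%f\n", s);  // prints 5.000000
  return 0;
}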
parent fd7d78ddf1
commit e2127701a5
@@ -1680,6 +1680,7 @@ tf_cc_tests(
         "colorspace_op_test.cc",
         "crop_and_resize_op_test.cc",
         "non_max_suppression_op_test.cc",
+        "resize_area_op_test.cc",
         "resize_bicubic_op_test.cc",
         "resize_bilinear_op_test.cc",
         "resize_nearest_neighbor_op_test.cc",
@@ -32,6 +32,17 @@ namespace tensorflow {

 typedef Eigen::ThreadPoolDevice CPUDevice;

+namespace {
+
+struct CachedInterpolation {
+  int64 start;
+  int64 end;
+  float start_scale;
+  float end_minus_one_scale;
+  bool needs_bounding;
+};
+};
+
 template <typename Device, typename T>
 class ResizeAreaOp : public OpKernel {
  public:
@@ -39,6 +50,99 @@ class ResizeAreaOp : public OpKernel {
     OP_REQUIRES_OK(context, context->GetAttr("align_corners", &align_corners_));
   }

+  // Computes the sum of all x values defined by <x_interp> taken across
+  // the y offsets and scales defined by y_ptrs and y_scales, for all
+  // three channels at once.
+  //
+  // Note that <NeedsXBounding> is a template parameter to avoid a performance
+  // penalty from dynamically checking it.
+  template <bool NeedsXBounding>
+  static void ComputePatchSumOf3Channels(float scale,
+                                         const ImageResizerState& st,
+                                         const std::vector<const T*>& y_ptrs,
+                                         const std::vector<float>& y_scales,
+                                         const CachedInterpolation& x_interp,
+                                         float* output_ptr) {
+#define BOUND_IF_NEEDED(x, y) (NeedsXBounding ? Bound(x, y) : (x))
+
+    float sum_0 = 0;
+    float sum_1 = 0;
+    float sum_2 = 0;
+    for (int i = 0; i < y_ptrs.size(); ++i) {
+      const T* ptr = y_ptrs[i];
+      float scale_x = x_interp.start_scale;
+      int64 offset = 3 * BOUND_IF_NEEDED(x_interp.start, st.in_width);
+      float sum_y_0 = static_cast<float>(ptr[offset + 0]) * scale_x;
+      float sum_y_1 = static_cast<float>(ptr[offset + 1]) * scale_x;
+      float sum_y_2 = static_cast<float>(ptr[offset + 2]) * scale_x;
+
+      if (x_interp.start + 1 != x_interp.end) {
+        for (int64 x = x_interp.start + 1; x < x_interp.end - 1; ++x) {
+          int64 offset = 3 * BOUND_IF_NEEDED(x, st.in_width);
+          sum_y_0 += static_cast<float>(ptr[offset + 0]);
+          sum_y_1 += static_cast<float>(ptr[offset + 1]);
+          sum_y_2 += static_cast<float>(ptr[offset + 2]);
+        }
+        scale_x = x_interp.end_minus_one_scale;
+        offset = 3 * BOUND_IF_NEEDED(x_interp.end - 1, st.in_width);
+        sum_y_0 += static_cast<float>(ptr[offset + 0]) * scale_x;
+        sum_y_1 += static_cast<float>(ptr[offset + 1]) * scale_x;
+        sum_y_2 += static_cast<float>(ptr[offset + 2]) * scale_x;
+      }
+      sum_0 += sum_y_0 * y_scales[i];
+      sum_1 += sum_y_1 * y_scales[i];
+      sum_2 += sum_y_2 * y_scales[i];
+    }
+
+    output_ptr[0] = sum_0 * scale;
+    output_ptr[1] = sum_1 * scale;
+    output_ptr[2] = sum_2 * scale;
+
+#undef BOUND_IF_NEEDED
+  }
+
+  // Computes the sum of all x values defined by <x_interp> taken across
+  // the y offsets and scales defined by y_ptrs and y_scales, for channel c.
+  //
+  // Note that <NeedsXBounding> is a template parameter to avoid a performance
+  // penalty from dynamically checking it.
+  template <bool NeedsXBounding>
+  static void ComputePatchSum(float scale, const ImageResizerState& st,
+                              const std::vector<const T*>& y_ptrs,
+                              const std::vector<float>& y_scales,
+                              const CachedInterpolation& x_interp,
+                              float* output_ptr) {
+#define BOUND_IF_NEEDED(x, y) (NeedsXBounding ? Bound(x, y) : (x))
+
+    const auto num_channels = st.channels;
+    for (int64 c = 0; c < num_channels; ++c) {
+      float sum = 0;
+      for (int i = 0; i < y_ptrs.size(); ++i) {
+        const T* ptr = y_ptrs[i];
+        float scale_x = x_interp.start_scale;
+        float sum_y = static_cast<float>(
+                          ptr[num_channels *
+                                  BOUND_IF_NEEDED(x_interp.start, st.in_width) +
+                              c]) *
+                      scale_x;
+        if (x_interp.start + 1 != x_interp.end) {
+          for (int64 x = x_interp.start + 1; x < x_interp.end - 1; ++x) {
+            sum_y += static_cast<float>(
+                ptr[num_channels * BOUND_IF_NEEDED(x, st.in_width) + c]);
+          }
+          scale_x = x_interp.end_minus_one_scale;
+          sum_y += static_cast<float>(
+                       ptr[num_channels *
+                               BOUND_IF_NEEDED(x_interp.end - 1, st.in_width) +
+                           c]) *
+                   scale_x;
+        }
+        sum += sum_y * y_scales[i];
+      }
+      output_ptr[c] = sum * scale;
+    }
+#undef BOUND_IF_NEEDED
+  }
+
   void Compute(OpKernelContext* context) override {
     const Tensor& input = context->input(0);
     ImageResizerState st(align_corners_);
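The x_interps precomputation in the next hunk is the "cache computed values" part of the change: the x-axis window and weights depend only on x, so they are computed once per image rather than once per output row. Here is a compact standalone sketch of the same idea, playing the role of the commit's CachedInterpolation/x_interps; the names and the simplified (but equivalent) interval-coverage weight math are illustrative, not the commit's exact code:

#include <cmath>
#include <cstdint>
#include <vector>

// Per-output-x interpolation state: the source window and edge weights.
struct CachedWindow {
  int64_t start, end;         // source cells [start, end) covering this x
  float start_scale;          // fractional coverage of the first cell
  float end_minus_one_scale;  // fractional coverage of the last cell
};

std::vector<CachedWindow> PrecomputeWindows(int64_t out_width,
                                            float width_scale) {
  std::vector<CachedWindow> windows(out_width);
  for (int64_t x = 0; x < out_width; ++x) {
    const float in_x = x * width_scale;
    const float in_x1 = (x + 1) * width_scale;
    CachedWindow& w = windows[x];
    w.start = static_cast<int64_t>(std::floor(in_x));
    w.end = static_cast<int64_t>(std::ceil(in_x1));
    // Coverage of the first and last (possibly partial) source cells;
    // interior cells are fully covered and get weight 1 in the sum loop.
    w.start_scale = std::fmin(static_cast<float>(w.start + 1), in_x1) - in_x;
    w.end_minus_one_scale =
        in_x1 - std::fmax(static_cast<float>(w.end - 1), in_x);
  }
  return windows;
}

int main() {
  // Resizing width 4 -> 3: each output cell covers 4/3 source cells.
  auto windows = PrecomputeWindows(3, 4.0f / 3.0f);
  return windows.size() == 3 ? 0 : 1;
}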
@@ -47,16 +151,49 @@ class ResizeAreaOp : public OpKernel {
     if (!context->status().ok()) return;

     typename TTypes<T, 4>::ConstTensor input_data = input.tensor<T, 4>();
+
+    // Precompute values used when iterating over x coordinates within a row.
+    // Note that it may be useful to cache x_interps for a given
+    // ImageResizerState.
+    std::vector<CachedInterpolation> x_interps(st.out_width);
+    for (int64 x = 0; x < st.out_width; ++x) {
+      auto& x_interp = x_interps[x];
+      const float in_x = x * st.width_scale;
+      const float in_x1 = (x + 1) * st.width_scale;
+      // The start and end width indices of all the cells that could
+      // contribute to the target cell.
+      int64 v = floor(in_x);
+      x_interp.start = v;
+      // TODO(cwhipkey): simplify this logic.
+      x_interp.start_scale =
+          v < in_x ? (v + 1 > in_x1 ? st.width_scale : v + 1 - in_x)
+                   : (v + 1 > in_x1 ? in_x1 - v : 1.0);
+
+      v = ceil(in_x1);
+      x_interp.end = ceil(in_x1);
+      v = x_interp.end - 1;
+      x_interp.end_minus_one_scale =
+          v < in_x ? (v + 1 > in_x1 ? st.width_scale : v + 1 - in_x)
+                   : (v + 1 > in_x1 ? in_x1 - v : 1.0);
+      x_interp.needs_bounding =
+          Bound(x_interp.start, st.in_width) != x_interp.start ||
+          Bound(x_interp.end - 1, st.in_width) != (x_interp.end - 1);
+    }
+
+    if (st.channels == 3) {
+      ComputeLoop<3>(st, x_interps, input_data);
+    } else {
+      ComputeLoop<-1>(st, x_interps, input_data);
+    }
+  }
+
+  template <int64 kKnownNumChannels>
+  void ComputeLoop(const ImageResizerState& st,
+                   const std::vector<CachedInterpolation>& x_interps,
+                   typename TTypes<T, 4>::ConstTensor input_data) {
     typename TTypes<float, 4>::Tensor output_data =
         st.output->tensor<float, 4>();

-    // A temporary tensor for computing the sum.
-    Tensor sum_tensor;
-    OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<float>::value,
-                                                   TensorShape({st.channels}),
-                                                   &sum_tensor));
-    typename TTypes<float, 1>::Tensor sum_data = sum_tensor.vec<float>();
-
     // When using this algorithm for downsizing, the target pixel value is the
     // weighted average of all the source pixels. The weight is determined by
     // the contribution percentage of the source pixel.
@@ -76,44 +213,58 @@ class ResizeAreaOp : public OpKernel {
     // out[0] = (in[0] * 1.0 + in[1] * 1/3) * scale
     // out[1] = (in[1] * 2/3 + in[2] * 2/3) * scale
     // out[2] = (in[2] * 1/3 + in[3] * 1.0) * scale
+    const T* const input_ptr = input_data.data();
+    std::vector<float> y_scales;
+    std::vector<const T*> y_ptrs;
     float scale = 1.0 / (st.height_scale * st.width_scale);
+    float* output_ptr = output_data.data();
     for (int64 b = 0; b < st.batch_size; ++b) {
       for (int64 y = 0; y < st.out_height; ++y) {
         const float in_y = y * st.height_scale;
         const float in_y1 = (y + 1) * st.height_scale;
         // The start and end height indices of all the cells that could
         // contribute to the target cell.
-        int64 y_start = floor(in_y);
-        int64 y_end = ceil(in_y1);
-
-        for (int64 x = 0; x < st.out_width; ++x) {
-          const float in_x = x * st.width_scale;
-          const float in_x1 = (x + 1) * st.width_scale;
-          // The start and end width indices of all the cells that could
-          // contribute to the target cell.
-          int64 x_start = floor(in_x);
-          int64 x_end = ceil(in_x1);
-
-          sum_data.setConstant(0.0);
+        const int64 y_start = floor(in_y);
+        const int64 y_end = ceil(in_y1);
+        y_scales.clear();
+        y_ptrs.clear();
         for (int64 i = y_start; i < y_end; ++i) {
-          float scale_y =
-              i < in_y ? (i + 1 > in_y1 ? st.height_scale : i + 1 - in_y)
-                       : (i + 1 > in_y1 ? in_y1 - i : 1.0);
-          for (int64 j = x_start; j < x_end; ++j) {
-            float scale_x =
-                j < in_x ? (j + 1 > in_x1 ? st.width_scale : j + 1 - in_x)
-                         : (j + 1 > in_x1 ? in_x1 - j : 1.0);
-            for (int64 c = 0; c < st.channels; ++c) {
-#define BOUND(val, limit) std::min(((limit)-1ll), (std::max(0ll, (val))))
-              sum_data(c) += float(input_data(b, BOUND(i, st.in_height),
-                                              BOUND(j, st.in_width), c)) *
-                             scale_y * scale_x * scale;
-#undef BOUND
+          float scale_y;
+          if (i < in_y) {
+            scale_y = (i + 1 > in_y1 ? st.height_scale : i + 1 - in_y);
+          } else {
+            scale_y = (i + 1 > in_y1 ? in_y1 - i : 1.0);
+          }
+          // TODO(cwhipkey): can this data be unified with CachedInterpolation?
+          y_scales.push_back(scale_y);
+          y_ptrs.push_back(
+              input_ptr + (b * st.in_height * st.in_width * st.channels +
+                           Bound(i, st.in_height) * st.in_width * st.channels));
         }
+
+        if (kKnownNumChannels == 3) {
+          for (int64 x = 0; x < st.out_width; ++x) {
+            const CachedInterpolation& x_interp = x_interps[x];
+            if (x_interp.needs_bounding) {
+              ComputePatchSumOf3Channels<true>(scale, st, y_ptrs, y_scales,
+                                               x_interp, output_ptr);
+            } else {
+              ComputePatchSumOf3Channels<false>(scale, st, y_ptrs, y_scales,
+                                                x_interp, output_ptr);
+            }
-          for (int64 c = 0; c < st.channels; ++c) {
-            output_data(b, y, x, c) = sum_data(c);
+            output_ptr += 3;
           }
+        } else {
+          for (int64 x = 0; x < st.out_width; ++x) {
+            const CachedInterpolation& x_interp = x_interps[x];
+            if (x_interp.needs_bounding) {
+              ComputePatchSum<true>(scale, st, y_ptrs, y_scales, x_interp,
+                                    output_ptr);
+            } else {
+              ComputePatchSum<false>(scale, st, y_ptrs, y_scales, x_interp,
+                                     output_ptr);
+            }
+            output_ptr += st.channels;
          }
        }
      }
@@ -121,6 +272,10 @@ class ResizeAreaOp : public OpKernel {
   }

  private:
+  static EIGEN_ALWAYS_INLINE int64 Bound(int64 val, int64 limit) {
+    return std::min(limit - 1ll, std::max(0ll, val));
+  }
+
   bool align_corners_;
 };

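On the "pointer offsets instead of a 4-d Eigen tensor" item: the kernel above computes a row base pointer once per (batch, row) pair (the y_ptrs vector) and then walks it with flat NHWC offsets, so the per-pixel work is a single multiply-add instead of full 4-d index math. A small standalone sketch of that layout arithmetic (illustrative names and sizes, not code from the commit):

#include <cstdint>
#include <cstdio>
#include <vector>

// NHWC layout: index(b, y, x, c) = ((b * H + y) * W + x) * C + c.
// Computing the row base pointer once hoists the b and y terms out of
// the per-pixel loop.
int main() {
  const int64_t B = 1, H = 2, W = 3, C = 3;
  std::vector<float> image(B * H * W * C);
  for (size_t i = 0; i < image.size(); ++i) image[i] = static_cast<float>(i);

  const float* input_ptr = image.data();
  for (int64_t b = 0; b < B; ++b) {
    for (int64_t y = 0; y < H; ++y) {
      // One base-pointer computation per row, as the op does for y_ptrs.
      const float* row = input_ptr + (b * H * W * C + y * W * C);
      float row_sum = 0.f;
      for (int64_t x = 0; x < W; ++x) {
        for (int64_t c = 0; c < C; ++c) {
          row_sum += row[x * C + c];  // flat offset, no 4-d index math
        }
      }
      std::printf("row (%lld, %lld) sum = %f\n", static_cast<long long>(b),
                  static_cast<long long>(y), row_sum);
    }
  }
  return 0;
}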
tensorflow/core/kernels/resize_area_op_test.cc (new file, 197 lines)
@@ -0,0 +1,197 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"

namespace tensorflow {

class ResizeAreaOpTest : public OpsTestBase {
 protected:
  ResizeAreaOpTest() {
    TF_EXPECT_OK(NodeDefBuilder("resize_area_op", "ResizeArea")
                     .Input(FakeInput(DT_FLOAT))
                     .Input(FakeInput(DT_INT32))
                     .Attr("align_corners", false)
                     .Finalize(node_def()));
    TF_EXPECT_OK(InitOp());
  }

  const Tensor* SetRandomImageInput(const TensorShape& shape) {
    inputs_.clear();

    CHECK_EQ(shape.dims(), 4) << "All images must have 4 dimensions.";
    bool is_ref = IsRefType(input_types_[inputs_.size()]);
    Tensor* input = new Tensor(device_->GetAllocator(AllocatorAttributes()),
                               DataTypeToEnum<float>::v(), shape);
    input->flat<float>().setRandom();
    tensors_.push_back(input);
    if (is_ref) {
      CHECK_EQ(RemoveRefType(input_types_[inputs_.size()]),
               DataTypeToEnum<float>::v());
      inputs_.push_back({&lock_for_refs_, input});
    } else {
      CHECK_EQ(input_types_[inputs_.size()], DataTypeToEnum<float>::v());
      inputs_.push_back({nullptr, input});
    }
    return input;
  }

 private:
  // This is the unoptimized implementation of ResizeArea.
  // We use this to confirm that the optimized version is exactly identical.
  void ResizeAreaBaseline(TTypes<float, 4>::ConstTensor input_data,
                          TTypes<float, 4>::Tensor output_data) {
    const int batch_size = input_data.dimension(0);
    const int64 in_height = input_data.dimension(1);
    const int64 in_width = input_data.dimension(2);
    const int channels = input_data.dimension(3);

    ASSERT_EQ(batch_size, output_data.dimension(0));
    ASSERT_EQ(channels, output_data.dimension(3));

    const int64 out_height = output_data.dimension(1);
    const int64 out_width = output_data.dimension(2);

    const float height_scale = in_height / static_cast<float>(out_height);
    const float width_scale = in_width / static_cast<float>(out_width);

    // A temporary tensor for computing the sum.
    Tensor sum_tensor(DT_FLOAT, TensorShape({channels}));
    typename TTypes<float, 1>::Tensor sum_data = sum_tensor.vec<float>();

    // When using this algorithm for downsizing, the target pixel value is the
    // weighted average of all the source pixels. The weight is determined by
    // the contribution percentage of the source pixel.
    //
    // Let "scale" be "target_image_size/source_image_size". If 1/n of the
    // source pixel contributes to the target pixel, then the weight is (1/n *
    // scale); if the complete source pixel contributes to the target pixel,
    // then the weight is scale.
    //
    // To visualize the implementation, use one dimension as an example:
    // Resize in[4] to out[3].
    // scale = 3/4 = 0.75
    // out[0]: in[0] and 1/3 of in[1]
    // out[1]: 2/3 of in[1] and 2/3 of in[2]
    // out[2]: 1/3 of in[2] and in[3]
    // Hence, the output pixel values are:
    // out[0] = (in[0] * 1.0 + in[1] * 1/3) * scale
    // out[1] = (in[1] * 2/3 + in[2] * 2/3) * scale
    // out[2] = (in[2] * 1/3 + in[3] * 1.0) * scale
    float scale = 1.0 / (height_scale * width_scale);
    for (int64 b = 0; b < batch_size; ++b) {
      for (int64 y = 0; y < out_height; ++y) {
        const float in_y = y * height_scale;
        const float in_y1 = (y + 1) * height_scale;
        // The start and end height indices of all the cells that could
        // contribute to the target cell.
        int64 y_start = floor(in_y);
        int64 y_end = ceil(in_y1);

        for (int64 x = 0; x < out_width; ++x) {
          const float in_x = x * width_scale;
          const float in_x1 = (x + 1) * width_scale;
          // The start and end width indices of all the cells that could
          // contribute to the target cell.
          int64 x_start = floor(in_x);
          int64 x_end = ceil(in_x1);

          sum_data.setConstant(0.0);
          for (int64 i = y_start; i < y_end; ++i) {
            float scale_y = i < in_y
                                ? (i + 1 > in_y1 ? height_scale : i + 1 - in_y)
                                : (i + 1 > in_y1 ? in_y1 - i : 1.0);
            for (int64 j = x_start; j < x_end; ++j) {
              float scale_x = j < in_x
                                  ? (j + 1 > in_x1 ? width_scale : j + 1 - in_x)
                                  : (j + 1 > in_x1 ? in_x1 - j : 1.0);
              for (int64 c = 0; c < channels; ++c) {
#define BOUND(val, limit) std::min(((limit)-1ll), (std::max(0ll, (val))))
                sum_data(c) +=
                    static_cast<float>(input_data(b, BOUND(i, in_height),
                                                  BOUND(j, in_width), c)) *
                    scale_y * scale_x * scale;
#undef BOUND
              }
            }
          }
          for (int64 c = 0; c < channels; ++c) {
            output_data(b, y, x, c) = sum_data(c);
          }
        }
      }
    }
  }

 protected:
  void RunRandomTest(int in_height, int in_width, int target_height,
                     int target_width, int channels) {
    const Tensor* input =
        SetRandomImageInput(TensorShape({1, in_height, in_width, channels}));
    AddInputFromArray<int32>(TensorShape({2}), {target_height, target_width});

    TF_ASSERT_OK(RunOpKernel());
    std::unique_ptr<Tensor> expected(
        new Tensor(device_->GetAllocator(AllocatorAttributes()),
                   DataTypeToEnum<float>::v(),
                   TensorShape({1, target_height, target_width, channels})));
    ResizeAreaBaseline(input->tensor<float, 4>(), expected->tensor<float, 4>());
    test::ExpectTensorNear<float>(*expected, *GetOutput(0), 0.00001);
  }

  void RunManyRandomTests(int channels) {
    for (int in_w : {2, 4, 7, 20, 165}) {
      for (int in_h : {1, 3, 5, 8, 100, 233}) {
        for (int target_height : {1, 2, 3, 50, 113}) {
          for (int target_width : {target_height, target_height / 2 + 1}) {
            RunRandomTest(in_w, in_h, target_height, target_width, channels);
          }
        }
      }
    }
  }
};

TEST_F(ResizeAreaOpTest, TestAreaRandom141x186) {
  RunRandomTest(141, 186, 299, 299, 3 /* channels */);
}

TEST_F(ResizeAreaOpTest, TestAreaRandom183x229) {
  RunRandomTest(183, 229, 299, 299, 3 /* channels */);
}

TEST_F(ResizeAreaOpTest, TestAreaRandom749x603) {
  RunRandomTest(749, 603, 299, 299, 3 /* channels */);
}

TEST_F(ResizeAreaOpTest, TestAreaRandomDataSeveralInputsSizes1Channel) {
  RunManyRandomTests(1);
}

TEST_F(ResizeAreaOpTest, TestAreaRandomDataSeveralInputsSizes3Channels) {
  RunManyRandomTests(3);
}

TEST_F(ResizeAreaOpTest, TestAreaRandomDataSeveralInputsSizes4Channels) {
  RunManyRandomTests(4);
}

}  // namespace tensorflow
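As a sanity check on the in[4] -> out[3] example spelled out in the comments above, here is a small standalone program (not part of the commit) that derives the same weights from interval coverage and verifies that each output pixel's weights sum to 1:

#include <algorithm>
#include <cmath>
#include <cstdio>

// Area-resize a 1-D signal of length 4 down to length 3, weighting each
// source cell by how much of it the output interval [in_x, in_x1) covers,
// then normalizing by scale = out/in = 0.75.
int main() {
  const double in[4] = {10.0, 20.0, 30.0, 40.0};
  const int in_size = 4, out_size = 3;
  const double width_scale = static_cast<double>(in_size) / out_size;  // 4/3
  const double scale = 1.0 / width_scale;                              // 0.75

  for (int x = 0; x < out_size; ++x) {
    const double in_x = x * width_scale;
    const double in_x1 = (x + 1) * width_scale;
    double sum = 0.0, weight_total = 0.0;
    for (int j = static_cast<int>(std::floor(in_x));
         j < static_cast<int>(std::ceil(in_x1)); ++j) {
      // Fraction of source cell j covered by [in_x, in_x1).
      const double coverage =
          std::min<double>(j + 1, in_x1) - std::max<double>(j, in_x);
      sum += in[j] * coverage;
      weight_total += coverage * scale;
    }
    // weight_total is 1.0 for every x; for example,
    // out[1] = (in[1] * 2/3 + in[2] * 2/3) * 0.75 = 25.0.
    std::printf("out[%d] = %f (weights sum to %f)\n", x, sum * scale,
                weight_total);
  }
  return 0;
}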
image_ops_test.py

@@ -470,9 +470,7 @@ class ResizeBilinearBenchmark(test.Benchmark):
     print('Variables initalized for resize_bilinear image size: %s.' %
           (image_size,))
     benchmark_values = self.run_op_benchmark(
-        sess,
-        benchmark_op,
-        name=('bilinear_%s_%s' % image_size),)
+        sess, benchmark_op, name=('bilinear_%s_%s' % image_size))
     print('Benchmark values:\n%s' % benchmark_values)

   def benchmarkSimilar(self):
@@ -506,9 +504,7 @@ class ResizeBicubicBenchmark(test.Benchmark):
     print('Variables initalized for resize_bicubic image size: %s.' %
           (image_size,))
     benchmark_values = self.run_op_benchmark(
-        sess,
-        benchmark_op,
-        name=('bicubic_%s_%s' % image_size),)
+        sess, benchmark_op, name=('bicubic_%s_%s' % image_size))
     print('Benchmark values:\n%s' % benchmark_values)

   def benchmarkSimilar(self):
@@ -521,6 +517,52 @@ class ResizeBicubicBenchmark(test.Benchmark):
     self._benchmarkResize((749, 603))


+class ResizeAreaBenchmark(test.Benchmark):
+
+  def _benchmarkResize(self, image_size, num_channels):
+    batch_size = 1
+    num_ops = 1000
+    img = variables.Variable(
+        random_ops.random_normal([batch_size, image_size[0],
+                                  image_size[1], num_channels]),
+        name='img')
+
+    deps = []
+    for _ in xrange(num_ops):
+      with ops.control_dependencies(deps):
+        resize_op = image_ops.resize_area(img, [299, 299], align_corners=False)
+        deps = [resize_op]
+    benchmark_op = control_flow_ops.group(*deps)
+
+    with session.Session() as sess:
+      sess.run(variables.global_variables_initializer())
+      results = self.run_op_benchmark(
+          sess, benchmark_op,
+          name=('resize_area_%s_%s_%s' %
+                (image_size[0], image_size[1], num_channels)))
+      print('%s : %.2f ms/img' % (
+          results['name'],
+          1000*results['wall_time'] / (batch_size * num_ops)))
+
+  def benchmarkSimilar3Channel(self):
+    self._benchmarkResize((183, 229), 3)
+
+  def benchmarkScaleUp3Channel(self):
+    self._benchmarkResize((141, 186), 3)
+
+  def benchmarkScaleDown3Channel(self):
+    self._benchmarkResize((749, 603), 3)
+
+  def benchmarkSimilar1Channel(self):
+    self._benchmarkResize((183, 229), 1)
+
+  def benchmarkScaleUp1Channel(self):
+    self._benchmarkResize((141, 186), 1)
+
+  def benchmarkScaleDown1Channel(self):
+    self._benchmarkResize((749, 603), 1)
+
+
 class AdjustSaturationTest(test_util.TensorFlowTestCase):

   def testHalfSaturation(self):