Fuse resize and mirror padding ops into convolutions
Spatial transformations like padding and bilinear resizing can be merged into the im2col stage of conv2d. This reduces the memory usage considerably (from 338MB to 224MB) and latency (by 15%) on some models, and helps us avoid OOM crashes on iOS. This PR has all the changes needed to fuse these particular ops, including the kernels themselves and integration into the optimize_for_inference script. Change: 132094335
This commit is contained in:
parent
10451eb6cf
commit
cb324446ac
@ -101,6 +101,7 @@ tensorflow/core/kernels/cwise_op_div.cc
|
||||
tensorflow/core/kernels/cwise_op_add.cc
|
||||
tensorflow/core/kernels/ctc_decoder_ops.cc
|
||||
tensorflow/core/kernels/conv_ops_using_gemm.cc
|
||||
tensorflow/core/kernels/conv_ops_fused.cc
|
||||
tensorflow/core/kernels/conv_ops.cc
|
||||
tensorflow/core/kernels/conv_grad_ops.cc
|
||||
tensorflow/core/kernels/control_flow_ops.cc
|
||||
|
@ -491,6 +491,27 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "conv_ops_test",
|
||||
size = "small",
|
||||
deps = [
|
||||
":conv_ops",
|
||||
":image",
|
||||
":ops_testutil",
|
||||
":ops_util",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:tensorflow",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "example_parsing_ops_test",
|
||||
size = "large",
|
||||
@ -1325,6 +1346,7 @@ tf_kernel_library(
|
||||
hdrs = [
|
||||
"conv_grad_ops.h",
|
||||
"deep_conv2d.h",
|
||||
"gemm_functors.h",
|
||||
"winograd_transform.h",
|
||||
],
|
||||
prefix = "conv_ops",
|
||||
@ -1332,6 +1354,7 @@ tf_kernel_library(
|
||||
":bounds_check",
|
||||
":conv_2d",
|
||||
":conv_3d",
|
||||
":image_resizer_state",
|
||||
":ops_util",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_KERNELS_CONV_OPS_H_
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/util/tensor_format.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
@ -38,6 +39,16 @@ class LaunchConv2DOp {
|
||||
TensorFormat data_format);
|
||||
};
|
||||
|
||||
// Used to keep track of persistent memory buffers used within the op.
|
||||
template <class T, size_t size>
|
||||
struct Im2ColBufferResource : public ResourceBase {
|
||||
// This mutex ensures that only a single operation at a time is able to use
|
||||
// the buffer memory held by this resource.
|
||||
mutex mu;
|
||||
T data[size];
|
||||
string DebugString() { return "Im2ColBufferResource"; }
|
||||
};
|
||||
|
||||
#ifdef GOOGLE_CUDA
|
||||
template <typename T>
|
||||
class LaunchConv2DOp<Eigen::GpuDevice, T> {
|
||||
|
486
tensorflow/core/kernels/conv_ops_fused.cc
Normal file
486
tensorflow/core/kernels/conv_ops_fused.cc
Normal file
@ -0,0 +1,486 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Implements convolution operations with other kernels baked into the
|
||||
// processing, to optimize latency and memory usage.
|
||||
|
||||
#include <string.h>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "tensorflow/core/framework/common_shape_fns.h"
|
||||
#include "tensorflow/core/framework/numeric_op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor_slice.h"
|
||||
#include "tensorflow/core/kernels/bounds_check.h"
|
||||
#include "tensorflow/core/kernels/conv_ops.h"
|
||||
#include "tensorflow/core/kernels/gemm_functors.h"
|
||||
#include "tensorflow/core/kernels/image_resizer_state.h"
|
||||
#include "tensorflow/core/util/mirror_pad_mode.h"
|
||||
#include "tensorflow/core/util/padding.h"
|
||||
#include "tensorflow/core/util/tensor_format.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
// Combines bilinear resizing and mirror padding into the im2col transformation
|
||||
// stage of convolution,
|
||||
template <class T1, class T2, class T3, class TGemmFunctor>
|
||||
class FusedResizeAndPadConvFunctor {
|
||||
public:
|
||||
void operator()(OpKernelContext* context, const Tensor& input,
|
||||
int input_batches, int resized_height, int resized_width,
|
||||
int padded_height, int padded_width, int input_depth,
|
||||
const T2* filter_data, int filter_height, int filter_width,
|
||||
int filter_count, int stride_rows, int stride_cols,
|
||||
Padding padding, T3* output_data, int output_height,
|
||||
int output_width, const ImageResizerState& st,
|
||||
int top_padding, int bottom_padding, int left_padding,
|
||||
int right_padding, int pad_offset) {
|
||||
if ((input_batches <= 0) || (padded_width <= 0) || (padded_height <= 0) ||
|
||||
(input_depth <= 0)) {
|
||||
LOG(WARNING) << "Conv2D was called with bad input dimensions: "
|
||||
<< input_batches << ", " << padded_height << ", "
|
||||
<< padded_width << ", " << input_depth;
|
||||
return;
|
||||
}
|
||||
if ((filter_width <= 0) || (filter_height <= 0) || (filter_count <= 0)) {
|
||||
LOG(WARNING) << "Conv2D was called with bad filter dimensions: "
|
||||
<< filter_width << ", " << filter_height << ", "
|
||||
<< filter_count;
|
||||
return;
|
||||
}
|
||||
if ((output_width <= 0) || (output_height <= 0)) {
|
||||
LOG(WARNING) << "Conv2D was called with bad output width or height: "
|
||||
<< output_width << ", " << output_height;
|
||||
return;
|
||||
}
|
||||
|
||||
// These calculations define how the patches will be positioned within the
|
||||
// input image. The actual definitions are quite complex, and rely on the
|
||||
// previously-calculated output size.
|
||||
int filter_left_offset;
|
||||
int filter_top_offset;
|
||||
if (padding == VALID) {
|
||||
filter_left_offset =
|
||||
((output_width - 1) * stride_cols + filter_width - padded_width + 1) /
|
||||
2;
|
||||
filter_top_offset = ((output_height - 1) * stride_rows + filter_height -
|
||||
padded_height + 1) /
|
||||
2;
|
||||
} else {
|
||||
filter_left_offset =
|
||||
((output_width - 1) * stride_cols + filter_width - padded_width) / 2;
|
||||
filter_top_offset =
|
||||
((output_height - 1) * stride_rows + filter_height - padded_height) /
|
||||
2;
|
||||
}
|
||||
|
||||
// The im2col buffer has # of patches rows, and # of filters cols.
|
||||
// It's laid out like this, in row major order in memory:
|
||||
// < filter value count >
|
||||
// ^ +---------------------+
|
||||
// patch | |
|
||||
// count | |
|
||||
// v +---------------------+
|
||||
// Each patch row contains a filter_width x filter_height patch of the
|
||||
// input, with the depth channel as the most contiguous in memory, followed
|
||||
// by the width, then the height. This is the standard memory order in the
|
||||
// image world if it helps to visualize it.
|
||||
const int filter_value_count = filter_width * filter_height * input_depth;
|
||||
|
||||
// We don't want to allocate a buffer to hold all the patches if the size is
|
||||
// going to be extremely large, so break it into chunks if it's bigger than
|
||||
// a limit. Each chunk will be processed serially, so we can refill the
|
||||
// buffer for the next chunk and reuse it, keeping maximum memory size down.
|
||||
// In this case, we've picked 16 megabytes as a reasonable limit.
|
||||
const size_t max_chunk_size = (16 * 1024 * 1024);
|
||||
OP_REQUIRES(context, (filter_value_count * sizeof(T1)) <= max_chunk_size,
|
||||
errors::InvalidArgument("Im2Col patch too large for buffer"));
|
||||
const size_t patches_per_chunk =
|
||||
max_chunk_size / (filter_value_count * sizeof(T1));
|
||||
// Because memory allocation is very expensive on mobile platforms, try to
|
||||
// allocate a persistent buffer that will be kept around between calls. We
|
||||
// use TensorFlow's resource management to ensure that the memory will be
|
||||
// released when the session is over.
|
||||
Im2ColBufferResource<T1, max_chunk_size>* im2col_buffer_resource;
|
||||
std::function<Status(Im2ColBufferResource<T1, max_chunk_size>**)> creator =
|
||||
[](Im2ColBufferResource<T1, max_chunk_size>** resource) {
|
||||
*resource = new Im2ColBufferResource<T1, max_chunk_size>();
|
||||
return Status::OK();
|
||||
};
|
||||
OP_REQUIRES_OK(context, context->resource_manager()->LookupOrCreate(
|
||||
"Conv2d", "im2col_buffer",
|
||||
&im2col_buffer_resource, creator));
|
||||
// This means that multiple ops can't be run simultaneously on different
|
||||
// threads, because we have a single shared resource. The platforms this is
|
||||
// aimed at have intra-op parallelism as their focus though, so it shouldn't
|
||||
// be an issue.
|
||||
mutex_lock lock_buffer(im2col_buffer_resource->mu);
|
||||
core::ScopedUnref unref_buffer(im2col_buffer_resource);
|
||||
T1* im2col_buffer = im2col_buffer_resource->data;
|
||||
|
||||
typename TTypes<T1, 4>::ConstTensor input_data = input.tensor<T1, 4>();
|
||||
|
||||
for (int batch = 0; batch < input_batches; ++batch) {
|
||||
for (int out_y = 0; out_y < output_height; ++out_y) {
|
||||
const int in_y_origin = (out_y * stride_rows) - filter_top_offset;
|
||||
for (int out_x = 0; out_x < output_width; ++out_x) {
|
||||
const int in_x_origin = (out_x * stride_cols) - filter_left_offset;
|
||||
const int patch_index = (batch * output_width * output_height) +
|
||||
(out_y * output_width) + out_x;
|
||||
const int patch_index_within_chunk = patch_index % patches_per_chunk;
|
||||
T1* im2col_patch_start =
|
||||
im2col_buffer + (patch_index_within_chunk * filter_value_count);
|
||||
for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
|
||||
const int conv_in_y = in_y_origin + filter_y;
|
||||
float in_y = (conv_in_y - top_padding);
|
||||
if (in_y < 0) {
|
||||
in_y = -(in_y + 1.0f - pad_offset);
|
||||
} else if (in_y >= resized_height) {
|
||||
in_y = (resized_height * 2.0f) - (in_y + 1.0f + pad_offset);
|
||||
}
|
||||
in_y *= st.height_scale;
|
||||
const int64 top_y_index = static_cast<int64>(std::floor(in_y));
|
||||
const int64 bottom_y_index = std::min(
|
||||
static_cast<int64>(std::ceil(in_y)), (st.in_height - 1));
|
||||
const T1 y_lerp = in_y - top_y_index;
|
||||
T1* im2col_row_start =
|
||||
im2col_patch_start + (filter_y * filter_width * input_depth);
|
||||
for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
|
||||
const int conv_in_x = in_x_origin + filter_x;
|
||||
float in_x = (conv_in_x - left_padding);
|
||||
if (in_x < 0) {
|
||||
in_x = -(in_x + 1.0f - pad_offset);
|
||||
} else if (in_x >= resized_width) {
|
||||
in_x = (resized_width * 2.0f) - (in_x + 1.0f + pad_offset);
|
||||
}
|
||||
in_x *= st.width_scale;
|
||||
const int64 left_x_index = static_cast<int64>(std::floor(in_x));
|
||||
const int64 right_x_index = std::min(
|
||||
static_cast<int64>(std::ceil(in_x)), (st.in_width - 1));
|
||||
const T1 x_lerp = in_x - left_x_index;
|
||||
T1* im2col_row_pixel =
|
||||
im2col_row_start + (filter_x * input_depth);
|
||||
for (int in_channel = 0; in_channel < input_depth; ++in_channel) {
|
||||
T1 in_value;
|
||||
if ((conv_in_x >= 0) && (conv_in_x < padded_width) &&
|
||||
(conv_in_y >= 0) && (conv_in_y < padded_height)) {
|
||||
const T1 top_left(
|
||||
input_data(batch, top_y_index, left_x_index, in_channel));
|
||||
const T1 top_right(input_data(batch, top_y_index,
|
||||
right_x_index, in_channel));
|
||||
const T1 bottom_left(input_data(batch, bottom_y_index,
|
||||
left_x_index, in_channel));
|
||||
const T1 bottom_right(input_data(batch, bottom_y_index,
|
||||
right_x_index, in_channel));
|
||||
const T1 top = top_left + (top_right - top_left) * x_lerp;
|
||||
const T1 bottom =
|
||||
bottom_left + (bottom_right - bottom_left) * x_lerp;
|
||||
in_value = top + (bottom - top) * y_lerp;
|
||||
} else {
|
||||
in_value = T1(0);
|
||||
}
|
||||
im2col_row_pixel[in_channel] = in_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
const bool is_last_in_chunk =
|
||||
(patch_index_within_chunk == (patches_per_chunk - 1));
|
||||
const bool is_last_overall =
|
||||
((batch == (input_batches - 1)) &&
|
||||
(out_y == (output_height - 1)) && (out_x == (output_width - 1)));
|
||||
if (is_last_in_chunk || is_last_overall) {
|
||||
// Now we've assembled a set of image patches into a matrix, apply a
|
||||
// GEMM matrix multiply of the patches as rows, times the filter
|
||||
// weights in columns, to get partial results in the output matrix.
|
||||
const int how_many_patches = patch_index_within_chunk + 1;
|
||||
const int m = how_many_patches;
|
||||
const int n = filter_count;
|
||||
const int k = filter_value_count;
|
||||
const int lda = filter_value_count;
|
||||
const int ldb = filter_count;
|
||||
const int ldc = filter_count;
|
||||
const size_t start_patch_index =
|
||||
patch_index - (how_many_patches - 1);
|
||||
T3* chunk_output_data =
|
||||
output_data + (start_patch_index * filter_count);
|
||||
TGemmFunctor gemm_functor;
|
||||
gemm_functor(m, n, k, im2col_buffer, lda, filter_data, ldb,
|
||||
chunk_output_data, ldc);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
// Implements a version of convolution with bilinear resizing and mirror padding
|
||||
// included.
|
||||
template <class T, class TConvFunctor>
|
||||
class FusedResizeConv2DUsingGemmOp : public OpKernel {
|
||||
public:
|
||||
explicit FusedResizeConv2DUsingGemmOp(OpKernelConstruction* context)
|
||||
: OpKernel(context) {
|
||||
OP_REQUIRES_OK(context,
|
||||
context->GetAttr("resize_align_corners", &align_corners_));
|
||||
MirrorPadMode mode;
|
||||
OP_REQUIRES_OK(context, context->GetAttr("mode", &mode));
|
||||
|
||||
switch (mode) {
|
||||
case MirrorPadMode::SYMMETRIC: {
|
||||
offset_ = 0;
|
||||
break;
|
||||
}
|
||||
case MirrorPadMode::REFLECT: {
|
||||
offset_ = 1;
|
||||
break;
|
||||
}
|
||||
default:
|
||||
OP_REQUIRES(context, false,
|
||||
errors::InvalidArgument(
|
||||
"mode must be either REFLECT or SYMMETRIC."));
|
||||
}
|
||||
OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
|
||||
OP_REQUIRES(context, strides_.size() == 4,
|
||||
errors::InvalidArgument("Sliding window strides field must "
|
||||
"specify 4 dimensions"));
|
||||
const int64 stride_n = GetTensorDim(strides_, FORMAT_NHWC, 'N');
|
||||
const int64 stride_c = GetTensorDim(strides_, FORMAT_NHWC, 'C');
|
||||
OP_REQUIRES(
|
||||
context, stride_n == 1 && stride_c == 1,
|
||||
errors::InvalidArgument("Current implementation does not yet support "
|
||||
"strides in the batch and depth dimensions."));
|
||||
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
|
||||
}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
// Input tensor is of the following dimensions:
|
||||
// [ batch, in_rows, in_cols, in_depth ]
|
||||
const Tensor& input = context->input(0);
|
||||
OP_REQUIRES(context, (input.shape().num_elements() > 0),
|
||||
errors::InvalidArgument("Input tensor can't be empty"));
|
||||
|
||||
ImageResizerState st(align_corners_);
|
||||
st.ValidateAndCalculateOutputSize(context, input);
|
||||
if (!context->status().ok()) return;
|
||||
const TensorShape resized_shape(
|
||||
{input.dim_size(0), st.out_height, st.out_width, input.dim_size(3)});
|
||||
|
||||
const Tensor& paddings = context->input(2);
|
||||
|
||||
const int dims = resized_shape.dims();
|
||||
OP_REQUIRES(
|
||||
context, TensorShapeUtils::IsMatrix(paddings.shape()) &&
|
||||
paddings.dim_size(1) == 2,
|
||||
errors::InvalidArgument("paddings must be a matrix with 2 columns: ",
|
||||
paddings.shape().DebugString()));
|
||||
const int fixed_dims =
|
||||
(allow_legacy_scalars() && dims == 0 && paddings.dim_size(0) == 1)
|
||||
? 1
|
||||
: dims;
|
||||
OP_REQUIRES(
|
||||
context, fixed_dims == paddings.dim_size(0),
|
||||
errors::InvalidArgument(
|
||||
"The first dimension of paddings must be the rank of inputs: ",
|
||||
fixed_dims, " ", paddings.shape().DebugString(), " ",
|
||||
resized_shape.DebugString()));
|
||||
OP_REQUIRES(
|
||||
context, dims == paddings.dim_size(0),
|
||||
errors::InvalidArgument(
|
||||
"The first dimension of paddings must be the rank of inputs: ",
|
||||
dims, " ", paddings.shape().DebugString(), " ",
|
||||
resized_shape.DebugString()));
|
||||
|
||||
OP_REQUIRES(
|
||||
context, dims == 4,
|
||||
errors::InvalidArgument(
|
||||
"Fused mirror padding only supports four-dimensional inputs, but ",
|
||||
dims, " requested"));
|
||||
|
||||
// Compute the shape of the output tensor, and allocate it.
|
||||
TensorShape padded_shape;
|
||||
TTypes<int32>::ConstMatrix paddings_matrix = paddings.matrix<int32>();
|
||||
for (int d = 0; d < dims; ++d) {
|
||||
const int32 before =
|
||||
paddings_matrix(d, 0); // Pad before existing elements.
|
||||
const int32 after =
|
||||
paddings_matrix(d, 1); // Pad after exisitng elements.
|
||||
OP_REQUIRES(context, before >= 0 && after >= 0,
|
||||
errors::InvalidArgument("paddings must be non-negative: ",
|
||||
before, " ", after));
|
||||
if (offset_ == 0) { // SYMMETRIC mode.
|
||||
OP_REQUIRES(
|
||||
context, before <= resized_shape.dim_size(d) &&
|
||||
after <= resized_shape.dim_size(d),
|
||||
errors::InvalidArgument("paddings must be no greater "
|
||||
"than the dimension size: ",
|
||||
before, ", ", after, " greater than ",
|
||||
resized_shape.dim_size(d)));
|
||||
} else if (offset_ == 1) { // REFLECT mode.
|
||||
OP_REQUIRES(
|
||||
context, before < resized_shape.dim_size(d) &&
|
||||
after < resized_shape.dim_size(d),
|
||||
errors::InvalidArgument("paddings must be less than"
|
||||
" the dimension size: ",
|
||||
before, ", ", after, " not less than ",
|
||||
resized_shape.dim_size(d)));
|
||||
}
|
||||
padded_shape.AddDim(before + resized_shape.dim_size(d) + after);
|
||||
}
|
||||
|
||||
OP_REQUIRES(
|
||||
context, ((paddings_matrix(0, 0) == 0) && (paddings_matrix(0, 1) == 0)),
|
||||
errors::InvalidArgument(
|
||||
"Fused mirror padding only support spatial padding, not batches: ",
|
||||
paddings.DebugString()));
|
||||
OP_REQUIRES(
|
||||
context, ((paddings_matrix(3, 0) == 0) && (paddings_matrix(3, 1) == 0)),
|
||||
errors::InvalidArgument(
|
||||
"Fused mirror padding only support spatial padding, not channels: ",
|
||||
paddings.DebugString()));
|
||||
const int32 top_padding = paddings_matrix(1, 0);
|
||||
const int32 bottom_padding = paddings_matrix(1, 1);
|
||||
const int32 left_padding = paddings_matrix(2, 0);
|
||||
const int32 right_padding = paddings_matrix(2, 1);
|
||||
|
||||
// Input filter is of the following dimensions:
|
||||
// [ filter_rows, filter_cols, in_depth, out_depth]
|
||||
const Tensor& filter = context->input(3);
|
||||
|
||||
// For 2D convolution, there should be 4 dimensions.
|
||||
OP_REQUIRES(context, padded_shape.dims() == 4,
|
||||
errors::InvalidArgument("input must be 4-dimensional",
|
||||
padded_shape.DebugString()));
|
||||
OP_REQUIRES(context, filter.dims() == 4,
|
||||
errors::InvalidArgument("filter must be 4-dimensional: ",
|
||||
filter.shape().DebugString()));
|
||||
|
||||
// We only check the first three dims, since the depth is accessed as an
|
||||
// int64 below.
|
||||
for (int i = 0; i < 3; i++) {
|
||||
OP_REQUIRES(context, FastBoundsCheck(filter.dim_size(i),
|
||||
std::numeric_limits<int>::max()),
|
||||
errors::InvalidArgument("filter too large"));
|
||||
}
|
||||
|
||||
// The last dimension for input is in_depth. It must be the same as the
|
||||
// filter's in_depth.
|
||||
const int64 in_depth = padded_shape.dim_size(3);
|
||||
OP_REQUIRES(
|
||||
context, in_depth == filter.dim_size(2),
|
||||
errors::InvalidArgument("input and filter must have the same depth: ",
|
||||
in_depth, " vs ", filter.dim_size(2)));
|
||||
|
||||
// The last dimension for filter is out_depth.
|
||||
const int out_depth = static_cast<int>(filter.dim_size(3));
|
||||
|
||||
// The second dimension for input is rows/height.
|
||||
// The first dimension for filter is rows/height.
|
||||
const int64 padded_rows_raw = padded_shape.dim_size(1);
|
||||
OP_REQUIRES(context, FastBoundsCheck(padded_rows_raw,
|
||||
std::numeric_limits<int>::max()),
|
||||
errors::InvalidArgument("Input rows too large"));
|
||||
const int padded_rows = static_cast<int>(padded_rows_raw);
|
||||
const int filter_rows = static_cast<int>(filter.dim_size(0));
|
||||
const int resized_rows = static_cast<int>(resized_shape.dim_size(1));
|
||||
|
||||
// The third dimension for input is columns/width.
|
||||
// The second dimension for filter is columns/width.
|
||||
const int64 padded_cols_raw = padded_shape.dim_size(2);
|
||||
OP_REQUIRES(context, FastBoundsCheck(padded_cols_raw,
|
||||
std::numeric_limits<int>::max()),
|
||||
errors::InvalidArgument("Input cols too large"));
|
||||
const int padded_cols = static_cast<int>(padded_cols_raw);
|
||||
const int filter_cols = static_cast<int>(filter.dim_size(1));
|
||||
const int resized_cols = static_cast<int>(resized_shape.dim_size(2));
|
||||
|
||||
// The first dimension for input is batch.
|
||||
const int64 batch_raw = padded_shape.dim_size(0);
|
||||
OP_REQUIRES(context,
|
||||
FastBoundsCheck(batch_raw, std::numeric_limits<int>::max()),
|
||||
errors::InvalidArgument("batch is too large"));
|
||||
const int batch = static_cast<int>(batch_raw);
|
||||
|
||||
// For now we take the stride from the second and third dimensions only (we
|
||||
// do not support striding on the batch or depth dimension).
|
||||
const int stride_rows = GetTensorDim(strides_, FORMAT_NHWC, 'H');
|
||||
const int stride_cols = GetTensorDim(strides_, FORMAT_NHWC, 'W');
|
||||
|
||||
int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
|
||||
OP_REQUIRES_OK(context,
|
||||
GetWindowedOutputSize(padded_rows, filter_rows, stride_rows,
|
||||
padding_, &out_rows, &pad_rows));
|
||||
OP_REQUIRES_OK(context,
|
||||
GetWindowedOutputSize(padded_cols, filter_cols, stride_cols,
|
||||
padding_, &out_cols, &pad_cols));
|
||||
TensorShape out_shape =
|
||||
ShapeFromFormat(FORMAT_NHWC, batch, out_rows, out_cols, out_depth);
|
||||
OP_REQUIRES(context, (out_shape.num_elements() > 0),
|
||||
errors::InvalidArgument("Output tensor can't be empty"));
|
||||
|
||||
// Output tensor is of the following dimensions:
|
||||
// [ in_batch, out_rows, out_cols, out_depth ]
|
||||
Tensor* output = nullptr;
|
||||
OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
|
||||
|
||||
VLOG(2) << "Conv2D: in_depth = " << in_depth
|
||||
<< ", padded_cols = " << padded_cols
|
||||
<< ", filter_cols = " << filter_cols
|
||||
<< ", padded_rows = " << padded_rows
|
||||
<< ", filter_rows = " << filter_rows
|
||||
<< ", stride_rows = " << stride_rows
|
||||
<< ", stride_cols = " << stride_cols
|
||||
<< ", out_depth = " << out_depth;
|
||||
|
||||
// If there is nothing to compute, return.
|
||||
if (out_shape.num_elements() == 0) {
|
||||
return;
|
||||
}
|
||||
TConvFunctor conv_functor;
|
||||
conv_functor(context, input, batch, resized_rows, resized_cols, padded_rows,
|
||||
padded_cols, in_depth, filter.flat<T>().data(), filter_rows,
|
||||
filter_cols, out_depth, stride_rows, stride_cols, padding_,
|
||||
output->flat<T>().data(), out_rows, out_cols, st, top_padding,
|
||||
bottom_padding, left_padding, right_padding, offset_);
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<int32> strides_;
|
||||
Padding padding_;
|
||||
bool align_corners_;
|
||||
int offset_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(FusedResizeConv2DUsingGemmOp);
|
||||
};
|
||||
|
||||
#define REGISTER_FUSED(T) \
|
||||
REGISTER_KERNEL_BUILDER( \
|
||||
Name("FusedResizeAndPadConv2D") \
|
||||
.Device(DEVICE_CPU) \
|
||||
.TypeConstraint<T>("T"), \
|
||||
FusedResizeConv2DUsingGemmOp< \
|
||||
T, \
|
||||
FusedResizeAndPadConvFunctor<T, T, T, FastGemmFunctor<T, T, T>>>);
|
||||
|
||||
TF_CALL_float(REGISTER_FUSED);
|
||||
|
||||
} // namespace tensorflow
|
240
tensorflow/core/kernels/conv_ops_test.cc
Normal file
240
tensorflow/core/kernels/conv_ops_test.cc
Normal file
@ -0,0 +1,240 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/cc/ops/const_op.h"
|
||||
#include "tensorflow/cc/ops/image_ops.h"
|
||||
#include "tensorflow/cc/ops/nn_ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
|
||||
#include "tensorflow/core/framework/fake_input.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/kernels/ops_testutil.h"
|
||||
#include "tensorflow/core/kernels/ops_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class FusedResizePadConvOpTest : public OpsTestBase {
|
||||
protected:
|
||||
void HandwrittenConv() {
|
||||
const int stride = 1;
|
||||
TF_EXPECT_OK(NodeDefBuilder("fused_resize_op", "FusedResizeAndPadConv2D")
|
||||
.Input(FakeInput(DT_FLOAT))
|
||||
.Input(FakeInput(DT_INT32))
|
||||
.Input(FakeInput(DT_INT32))
|
||||
.Input(FakeInput(DT_FLOAT))
|
||||
.Attr("T", DT_FLOAT)
|
||||
.Attr("resize_align_corners", false)
|
||||
.Attr("mode", "REFLECT")
|
||||
.Attr("strides", {1, stride, stride, 1})
|
||||
.Attr("padding", "SAME")
|
||||
.Finalize(node_def()));
|
||||
TF_EXPECT_OK(InitOp());
|
||||
const int depth = 1;
|
||||
const int image_width = 4;
|
||||
const int image_height = 3;
|
||||
const int image_batch_count = 1;
|
||||
// The image matrix is:
|
||||
// | 1 | 2 | 3 | 4 |
|
||||
// | 5 | 6 | 7 | 8 |
|
||||
// | 9 | 10 | 11 | 12 |
|
||||
Tensor image(DT_FLOAT,
|
||||
{image_batch_count, image_height, image_width, depth});
|
||||
test::FillValues<float>(&image, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||||
|
||||
// The filter matrix is:
|
||||
// | 1 | 4 | 7 |
|
||||
// | 2 | 5 | 8 |
|
||||
// | 3 | 6 | 9 |
|
||||
const int filter_size = 3;
|
||||
const int filter_count = 1;
|
||||
Tensor filter(DT_FLOAT, {filter_size, filter_size, depth, filter_count});
|
||||
test::FillValues<float>(&filter, {1, 4, 7, 2, 5, 8, 3, 6, 9});
|
||||
|
||||
const int resized_width = image_width;
|
||||
const int resized_height = image_height;
|
||||
|
||||
const int top_padding = 0;
|
||||
const int bottom_padding = 0;
|
||||
const int left_padding = 0;
|
||||
const int right_padding = 0;
|
||||
|
||||
AddInputFromArray<float>(image.shape(), image.flat<float>());
|
||||
AddInputFromArray<int32>(TensorShape({2}), {resized_height, resized_width});
|
||||
AddInputFromArray<int32>(
|
||||
TensorShape({4, 2}),
|
||||
{0, 0, top_padding, bottom_padding, left_padding, right_padding, 0, 0});
|
||||
AddInputFromArray<float>(filter.shape(), filter.flat<float>());
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
// We're sliding the 3x3 filter across the 3x4 image, with accesses outside
|
||||
// the input set to zero because we're using the 'SAME' padding mode.
|
||||
// The calculations behind the expected output are:
|
||||
// (1*0)+(4*0)+(7*0)+(2*0)+(5*1)+(8*2)+(3*0)+(6*5)+(9*6)=105
|
||||
// (1*0)+(4*0)+(7*0)+(2*1)+(5*2)+(8*3)+(3*5)+(6*6)+(9*7)=150
|
||||
// (1*0)+(4*0)+(7*0)+(2*2)+(5*3)+(8*4)+(3*6)+(6*7)+(9*8)=183
|
||||
// (1*0)+(4*0)+(7*0)+(2*3)+(5*4)+(8*0)+(3*7)+(6*8)+(9*0)=95
|
||||
// (1*0)+(4*1)+(7*2)+(2*0)+(5*5)+(8*6)+(3*0)+(6*9)+(9*10)=235
|
||||
// (1*1)+(4*2)+(7*3)+(2*5)+(5*6)+(8*7)+(3*9)+(6*10)+(9*11)=312
|
||||
// (1*2)+(4*3)+(7*4)+(2*6)+(5*7)+(8*8)+(3*10)+(6*11)+(9*12)=357
|
||||
// (1*3)+(4*4)+(7*0)+(2*7)+(5*8)+(8*0)+(3*11)+(6*12)+(9*0)=178
|
||||
// (1*0)+(4*5)+(7*6)+(2*0)+(5*9)+(8*10)+(3*0)+(6*0)+(9*0)=187
|
||||
// (1*5)+(4*6)+(7*7)+(2*9)+(5*10)+(8*11)+(3*0)+(6*0)+(9*0)=234
|
||||
// (1*6)+(4*7)+(7*8)+(2*10)+(5*11)+(8*12)+(3*0)+(6*0)+(9*0)=261
|
||||
// (1*7)+(4*11)+(7*0)+(2*8)+(5*12)+(8*0)+(3*0)+(6*0)+(9*0)=121
|
||||
// This means we should end up with this matrix:
|
||||
// | 105 | 150 | 183 | 95 |
|
||||
// | 235 | 312 | 357 | 178 |
|
||||
// | 187 | 234 | 261 | 121 |
|
||||
const int expected_width = image_width;
|
||||
const int expected_height = image_height * filter_count;
|
||||
Tensor expected(DT_FLOAT, TensorShape({image_batch_count, expected_height,
|
||||
expected_width, filter_count}));
|
||||
test::FillValues<float>(
|
||||
&expected, {105, 150, 183, 95, 235, 312, 357, 178, 187, 234, 261, 121});
|
||||
const Tensor& output = *GetOutput(0);
|
||||
test::ExpectTensorNear<float>(expected, output, 1e-5);
|
||||
}
|
||||
|
||||
void CompareFusedAndSeparate(int input_width, int input_height,
|
||||
int input_depth, int resize_width,
|
||||
int resize_height, int y_padding, int x_padding,
|
||||
int filter_size, int filter_count,
|
||||
bool resize_align_corners, string pad_mode,
|
||||
int stride, string padding) {
|
||||
auto root = tensorflow::Scope::NewRootScope();
|
||||
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
|
||||
|
||||
const size_t input_data_size = input_height * input_width * input_depth;
|
||||
Tensor input_data(DT_FLOAT,
|
||||
TensorShape({1, input_height, input_width, input_depth}));
|
||||
for (int i = 0; i < input_data_size; ++i) {
|
||||
input_data.flat<float>()(i) = i + 1.0f;
|
||||
}
|
||||
Output input =
|
||||
Const(root.WithOpName("input"), Input::Initializer(input_data));
|
||||
|
||||
const size_t filter_data_size =
|
||||
filter_size * filter_size * filter_count * input_depth;
|
||||
Tensor filter_data(DT_FLOAT, TensorShape({filter_size, filter_size,
|
||||
input_depth, filter_count}));
|
||||
for (int i = 0; i < filter_data_size; ++i) {
|
||||
filter_data.flat<float>()(i) = i + 1.0f;
|
||||
}
|
||||
Output filter =
|
||||
Const(root.WithOpName("filter"), Input::Initializer(filter_data));
|
||||
|
||||
Output resize_size =
|
||||
Const(root.WithOpName("resize_size"), {resize_height, resize_width});
|
||||
Output resize =
|
||||
ResizeBilinear(root.WithOpName("resize"), input, resize_size,
|
||||
ResizeBilinear::AlignCorners(resize_align_corners));
|
||||
Output paddings =
|
||||
Const(root.WithOpName("paddings"),
|
||||
{{0, 0}, {y_padding, y_padding}, {x_padding, x_padding}, {0, 0}});
|
||||
Output mirror_pad =
|
||||
MirrorPad(root.WithOpName("mirror_pad"), resize, paddings, pad_mode);
|
||||
Output conv = Conv2D(root.WithOpName("conv"), mirror_pad, filter,
|
||||
{1, stride, stride, 1}, padding);
|
||||
|
||||
Output fused_conv = FusedResizeAndPadConv2D(
|
||||
root.WithOpName("fused_conv"), input, resize_size, paddings, filter,
|
||||
pad_mode, {1, stride, stride, 1}, padding,
|
||||
FusedResizeAndPadConv2D::ResizeAlignCorners(resize_align_corners));
|
||||
|
||||
tensorflow::GraphDef graph;
|
||||
TF_ASSERT_OK(root.ToGraphDef(&graph));
|
||||
|
||||
std::unique_ptr<tensorflow::Session> session(
|
||||
tensorflow::NewSession(tensorflow::SessionOptions()));
|
||||
TF_ASSERT_OK(session->Create(graph));
|
||||
|
||||
std::vector<Tensor> unfused_tensors;
|
||||
TF_ASSERT_OK(session->Run({}, {"conv"}, {}, &unfused_tensors));
|
||||
|
||||
std::vector<Tensor> fused_tensors;
|
||||
TF_ASSERT_OK(session->Run({}, {"fused_conv"}, {}, &fused_tensors));
|
||||
|
||||
test::ExpectTensorNear<float>(unfused_tensors[0], fused_tensors[0], 1e-5);
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(FusedResizePadConvOpTest, HandwrittenConv) { HandwrittenConv(); }
|
||||
|
||||
TEST_F(FusedResizePadConvOpTest, IdentityComparative) {
|
||||
CompareFusedAndSeparate(10, 10, 1, 10, 10, 0, 0, 1, 1, false, "REFLECT", 1,
|
||||
"SAME");
|
||||
}
|
||||
|
||||
TEST_F(FusedResizePadConvOpTest, ConvOnlyComparative) {
|
||||
CompareFusedAndSeparate(10, 10, 3, 10, 10, 0, 0, 4, 4, false, "REFLECT", 1,
|
||||
"SAME");
|
||||
}
|
||||
|
||||
TEST_F(FusedResizePadConvOpTest, ResizeOnlyComparative) {
|
||||
CompareFusedAndSeparate(10, 10, 1, 20, 20, 0, 0, 1, 1, false, "REFLECT", 1,
|
||||
"SAME");
|
||||
}
|
||||
|
||||
TEST_F(FusedResizePadConvOpTest, ResizeAndConvComparative) {
|
||||
CompareFusedAndSeparate(2, 2, 4, 4, 2, 0, 0, 2, 2, false, "REFLECT", 1,
|
||||
"SAME");
|
||||
}
|
||||
|
||||
TEST_F(FusedResizePadConvOpTest, ResizeAlignAndConvComparative) {
|
||||
CompareFusedAndSeparate(2, 2, 4, 4, 2, 0, 0, 2, 2, true, "REFLECT", 1,
|
||||
"SAME");
|
||||
}
|
||||
|
||||
TEST_F(FusedResizePadConvOpTest, ResizeAndConvStridedComparative) {
|
||||
CompareFusedAndSeparate(2, 2, 4, 4, 2, 0, 0, 2, 2, false, "REFLECT", 2,
|
||||
"SAME");
|
||||
}
|
||||
|
||||
TEST_F(FusedResizePadConvOpTest, ResizeAlignAndConvValidComparative) {
|
||||
CompareFusedAndSeparate(2, 2, 4, 4, 2, 0, 0, 2, 2, true, "REFLECT", 1,
|
||||
"VALID");
|
||||
}
|
||||
|
||||
TEST_F(FusedResizePadConvOpTest, PadOnlyComparative) {
|
||||
CompareFusedAndSeparate(4, 4, 1, 4, 4, 2, 2, 1, 1, false, "REFLECT", 1,
|
||||
"SAME");
|
||||
}
|
||||
|
||||
TEST_F(FusedResizePadConvOpTest, PadOnlyWithChannelsComparative) {
|
||||
CompareFusedAndSeparate(4, 4, 3, 4, 4, 2, 2, 1, 1, false, "REFLECT", 1,
|
||||
"SAME");
|
||||
}
|
||||
|
||||
TEST_F(FusedResizePadConvOpTest, ResizeAndPadComparative) {
|
||||
CompareFusedAndSeparate(4, 4, 1, 6, 6, 2, 2, 1, 1, false, "REFLECT", 1,
|
||||
"SAME");
|
||||
}
|
||||
|
||||
TEST_F(FusedResizePadConvOpTest, PadOnlySymmetricComparative) {
|
||||
CompareFusedAndSeparate(4, 4, 1, 4, 4, 2, 2, 1, 1, false, "SYMMETRIC", 1,
|
||||
"SAME");
|
||||
}
|
||||
|
||||
TEST_F(FusedResizePadConvOpTest, ResizeAndPadSymmetricComparative) {
|
||||
CompareFusedAndSeparate(4, 4, 3, 6, 6, 2, 2, 1, 1, false, "SYMMETRIC", 1,
|
||||
"SAME");
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -56,14 +56,13 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor_slice.h"
|
||||
#include "tensorflow/core/kernels/bounds_check.h"
|
||||
#include "tensorflow/core/kernels/conv_ops.h"
|
||||
#include "tensorflow/core/kernels/gemm_functors.h"
|
||||
#include "tensorflow/core/kernels/image_resizer_state.h"
|
||||
#include "tensorflow/core/util/mirror_pad_mode.h"
|
||||
#include "tensorflow/core/util/padding.h"
|
||||
#include "tensorflow/core/util/tensor_format.h"
|
||||
|
||||
#if defined(__APPLE__)
|
||||
#include <Accelerate/Accelerate.h>
|
||||
#define USE_ACCELERATE_GEMM
|
||||
#endif // __APPLE__
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
@ -189,87 +188,6 @@ class ReferenceConvFunctor {
|
||||
}
|
||||
};
|
||||
|
||||
// A readable but slow implementation of matrix multiplication, useful for
|
||||
// debugging and understanding the algorithm. Use instead of FastGemmFunctor in
|
||||
// the Im2ColConvFunctor template definition inside the op registration to
|
||||
// enable. Assumes row-major ordering of the values in memory.
|
||||
template <class T1, class T2, class T3>
|
||||
class ReferenceGemmFunctor {
|
||||
public:
|
||||
void operator()(size_t m, size_t n, size_t k, const T1* a, size_t lda,
|
||||
const T2* b, size_t ldb, T3* c, size_t ldc) {
|
||||
const size_t a_i_stride = lda;
|
||||
const size_t a_l_stride = 1;
|
||||
const size_t b_j_stride = 1;
|
||||
const size_t b_l_stride = ldb;
|
||||
const size_t c_i_stride = ldc;
|
||||
const size_t c_j_stride = 1;
|
||||
size_t i, j, l;
|
||||
for (j = 0; j < n; j++) {
|
||||
for (i = 0; i < m; i++) {
|
||||
T3 total(0);
|
||||
for (l = 0; l < k; l++) {
|
||||
const size_t a_index = ((i * a_i_stride) + (l * a_l_stride));
|
||||
const T1 a_value = a[a_index];
|
||||
const size_t b_index = ((j * b_j_stride) + (l * b_l_stride));
|
||||
const T2 b_value = b[b_index];
|
||||
total += (a_value * b_value);
|
||||
}
|
||||
const size_t c_index = ((i * c_i_stride) + (j * c_j_stride));
|
||||
c[c_index] = total;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Uses the optimized Eigen library to implement the matrix multiplication
|
||||
// required by the Im2ColConvFunctor class. We supply the two input and one
|
||||
// output types so that the accumulator can potentially be higher-precision than
|
||||
// the inputs, even though we don't currently take advantage of this.
|
||||
template <class T1, class T2, class T3>
|
||||
class FastGemmFunctor {
|
||||
public:
|
||||
// Convenience wrappers for the Eigen matrix types we'll be using.
|
||||
typedef Eigen::Map<
|
||||
const Eigen::Matrix<T1, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
|
||||
ConstMatrixT1;
|
||||
typedef Eigen::Map<
|
||||
const Eigen::Matrix<T2, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
|
||||
ConstMatrixT2;
|
||||
typedef Eigen::Map<
|
||||
Eigen::Matrix<T3, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
|
||||
MatrixT3;
|
||||
void operator()(size_t m, size_t n, size_t k, const T1* a, size_t lda,
|
||||
const T2* b, size_t ldb, T3* c, size_t ldc) {
|
||||
ConstMatrixT1 a_matrix(a, m, k);
|
||||
ConstMatrixT2 b_matrix(b, k, n);
|
||||
MatrixT3 c_matrix(c, m, n);
|
||||
c_matrix.noalias() = a_matrix * b_matrix;
|
||||
}
|
||||
};
|
||||
|
||||
// If we have Apple's Accelerate framework, use their implementation of GEMM to
|
||||
// get a performance boost for float.
|
||||
#if defined(USE_ACCELERATE_GEMM)
|
||||
template <>
|
||||
class FastGemmFunctor<float, float, float> {
|
||||
public:
|
||||
void operator()(size_t m, size_t n, size_t k, const float* a, size_t lda,
|
||||
const float* b, size_t ldb, float* c, size_t ldc) {
|
||||
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, 1.0f, a,
|
||||
lda, b, ldb, 0.0f, c, ldc);
|
||||
}
|
||||
};
|
||||
#endif // USE_ACCELERATE_GEMM
|
||||
|
||||
// Used to keep track of persistent memory buffers used within the op.
|
||||
template <class T, size_t size>
|
||||
struct Im2ColBufferResource : public ResourceBase {
|
||||
mutex mu;
|
||||
T data[size];
|
||||
string DebugString() { return "Im2ColBufferResource"; }
|
||||
};
|
||||
|
||||
// Implements convolution as a two stage process, first packing the patches of
|
||||
// the input image into columns (im2col) and then running GEMM to produce the
|
||||
// final result.
|
||||
@ -344,7 +262,6 @@ class Im2ColConvFunctor {
|
||||
errors::InvalidArgument("Im2Col patch too large for buffer"));
|
||||
const size_t patches_per_chunk =
|
||||
max_chunk_size / (filter_value_count * sizeof(T1));
|
||||
|
||||
// Because memory allocation is very expensive on mobile platforms, try to
|
||||
// allocate a persistent buffer that will be kept around between calls. We
|
||||
// use TensorFlow's resource management to ensure that the memory will be
|
||||
|
105
tensorflow/core/kernels/gemm_functors.h
Normal file
105
tensorflow/core/kernels/gemm_functors.h
Normal file
@ -0,0 +1,105 @@
|
||||
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This is a set of different implementations for the basic matrix by matrix
|
||||
// multiply function, commonly known as GEMM after the BLAS library's naming.
|
||||
// Having a standard interface enables us to swap out implementations on
|
||||
// different platforms, to make sure we're using the optimal version. They are
|
||||
// implemented as C++ template functors, so they're easy to swap into all of the
|
||||
// different kernels that use them.
|
||||
|
||||
#include <string.h>
|
||||
#include <map>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
|
||||
#if defined(__APPLE__) && defined(USE_GEMM_FOR_CONV)
|
||||
#include <Accelerate/Accelerate.h>
|
||||
#define USE_ACCELERATE_GEMM
|
||||
#endif // __APPLE__
|
||||
|
||||
// A readable but slow implementation of matrix multiplication, useful for
|
||||
// debugging and understanding the algorithm. Use instead of FastGemmFunctor in
|
||||
// the Im2ColConvFunctor template definition inside the op registration to
|
||||
// enable. Assumes row-major ordering of the values in memory.
|
||||
template <class T1, class T2, class T3>
|
||||
class ReferenceGemmFunctor {
|
||||
public:
|
||||
void operator()(size_t m, size_t n, size_t k, const T1* a, size_t lda,
|
||||
const T2* b, size_t ldb, T3* c, size_t ldc) {
|
||||
const size_t a_i_stride = lda;
|
||||
const size_t a_l_stride = 1;
|
||||
const size_t b_j_stride = 1;
|
||||
const size_t b_l_stride = ldb;
|
||||
const size_t c_i_stride = ldc;
|
||||
const size_t c_j_stride = 1;
|
||||
size_t i, j, l;
|
||||
for (j = 0; j < n; j++) {
|
||||
for (i = 0; i < m; i++) {
|
||||
T3 total(0);
|
||||
for (l = 0; l < k; l++) {
|
||||
const size_t a_index = ((i * a_i_stride) + (l * a_l_stride));
|
||||
const T1 a_value = a[a_index];
|
||||
const size_t b_index = ((j * b_j_stride) + (l * b_l_stride));
|
||||
const T2 b_value = b[b_index];
|
||||
total += (a_value * b_value);
|
||||
}
|
||||
const size_t c_index = ((i * c_i_stride) + (j * c_j_stride));
|
||||
c[c_index] = total;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Uses the optimized Eigen library to implement the matrix multiplication
|
||||
// required by the Im2ColConvFunctor class. We supply the two input and one
|
||||
// output types so that the accumulator can potentially be higher-precision than
|
||||
// the inputs, even though we don't currently take advantage of this.
|
||||
template <class T1, class T2, class T3>
|
||||
class FastGemmFunctor {
|
||||
public:
|
||||
// Convenience wrappers for the Eigen matrix types we'll be using.
|
||||
typedef Eigen::Map<
|
||||
const Eigen::Matrix<T1, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
|
||||
ConstMatrixT1;
|
||||
typedef Eigen::Map<
|
||||
const Eigen::Matrix<T2, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
|
||||
ConstMatrixT2;
|
||||
typedef Eigen::Map<
|
||||
Eigen::Matrix<T3, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
|
||||
MatrixT3;
|
||||
void operator()(size_t m, size_t n, size_t k, const T1* a, size_t lda,
|
||||
const T2* b, size_t ldb, T3* c, size_t ldc) {
|
||||
ConstMatrixT1 a_matrix(a, m, k);
|
||||
ConstMatrixT2 b_matrix(b, k, n);
|
||||
MatrixT3 c_matrix(c, m, n);
|
||||
c_matrix.noalias() = a_matrix * b_matrix;
|
||||
}
|
||||
};
|
||||
|
||||
// If we have Apple's Accelerate framework, use their implementation of GEMM to
|
||||
// get a performance boost for float.
|
||||
#if defined(USE_ACCELERATE_GEMM)
|
||||
template <>
|
||||
class FastGemmFunctor<float, float, float> {
|
||||
public:
|
||||
void operator()(size_t m, size_t n, size_t k, const float* a, size_t lda,
|
||||
const float* b, size_t ldb, float* c, size_t ldc) {
|
||||
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, 1.0f, a,
|
||||
lda, b, ldb, 0.0f, c, ldc);
|
||||
}
|
||||
};
|
||||
#endif // USE_ACCELERATE_GEMM
|
@ -49,12 +49,13 @@ struct ImageResizerState {
|
||||
explicit ImageResizerState(bool align_corners)
|
||||
: align_corners_(align_corners) {}
|
||||
|
||||
// ValidateAndCreateOutput checks the bounds on the input tensors
|
||||
// ValidateAndCalculateOutputSize checks the bounds on the input tensors
|
||||
// and requested size, sets up some of the resizing state such as the
|
||||
// height_scale and width_scale, and allocates the output.
|
||||
// height_scale and width_scale, and calculates the output size.
|
||||
// If any of these operations fails, it sets an error status in
|
||||
// the context, which the caller must check.
|
||||
void ValidateAndCreateOutput(OpKernelContext* context, const Tensor& input) {
|
||||
void ValidateAndCalculateOutputSize(OpKernelContext* context,
|
||||
const Tensor& input) {
|
||||
OP_REQUIRES(context, input.dims() == 4,
|
||||
errors::InvalidArgument("input must be 4-dimensional",
|
||||
input.shape().DebugString()));
|
||||
@ -87,12 +88,18 @@ struct ImageResizerState {
|
||||
OP_REQUIRES(
|
||||
context, input.dim_size(1) > 0 && input.dim_size(2) > 0,
|
||||
errors::InvalidArgument("input image must be of non-zero size"));
|
||||
height_scale = CalculateResizeScale(in_height, out_height, align_corners_);
|
||||
width_scale = CalculateResizeScale(in_width, out_width, align_corners_);
|
||||
}
|
||||
|
||||
// Calculates all the required variables, and allocates the output.
|
||||
void ValidateAndCreateOutput(OpKernelContext* context, const Tensor& input) {
|
||||
ValidateAndCalculateOutputSize(context, input);
|
||||
if (!context->status().ok()) return;
|
||||
OP_REQUIRES_OK(context, context->allocate_output(
|
||||
0, TensorShape({input.dim_size(0), out_height,
|
||||
out_width, input.dim_size(3)}),
|
||||
&output));
|
||||
height_scale = CalculateResizeScale(in_height, out_height, align_corners_);
|
||||
width_scale = CalculateResizeScale(in_width, out_width, align_corners_);
|
||||
}
|
||||
|
||||
int64 batch_size;
|
||||
|
@ -185,6 +185,7 @@ class OpsTestBase : public ::testing::Test {
|
||||
test::SetOutputAttrs(params_.get(), &attrs);
|
||||
checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_wrapper;
|
||||
params_.get()->slice_reader_cache = &slice_reader_cache_wrapper;
|
||||
params_.get()->resource_manager = device_.get()->resource_manager();
|
||||
|
||||
context_.reset(new OpKernelContext(params_.get()));
|
||||
device_->Compute(kernel_.get(), context_.get());
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/numeric_op.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
#include "tensorflow/core/util/mirror_pad_mode.h"
|
||||
#include "tensorflow/core/util/padding.h"
|
||||
#include "tensorflow/core/util/tensor_format.h"
|
||||
|
||||
@ -425,6 +426,46 @@ data_format: Specify the data format of the input and output data. With the
|
||||
[batch, in_channels, in_height, in_width].
|
||||
)doc");
|
||||
|
||||
REGISTER_OP("FusedResizeAndPadConv2D")
|
||||
.Input("input: T")
|
||||
.Input("size: int32")
|
||||
.Input("paddings: int32")
|
||||
.Input("filter: T")
|
||||
.Output("output: T")
|
||||
.Attr("T: {half, float, double}")
|
||||
.Attr("resize_align_corners: bool = false")
|
||||
.Attr(GetMirrorPadModeAttrString())
|
||||
.Attr("strides: list(int)")
|
||||
.Attr(GetPaddingAttrString())
|
||||
.Doc(R"doc(
|
||||
Performs a resize and padding as a preprocess during a convolution.
|
||||
|
||||
It's often possible to do spatial transformations more efficiently as part of
|
||||
the packing stage of a convolution, so this op allows for an optimized
|
||||
implementation where these stages are fused together. This prevents the need to
|
||||
write out the intermediate results as whole tensors, reducing memory pressure,
|
||||
and we can get some latency gains by merging the transformation calculations.
|
||||
The data_format attribute for Conv2D isn't supported by this op, and defaults to
|
||||
'NHWC' order.
|
||||
Internally this op uses a single per-graph scratch buffer, which means that it
|
||||
will block if multiple versions are being run in parallel. This is because this
|
||||
operator is primarily an optimization to minimize memory usage.
|
||||
|
||||
input: 4-D with shape `[batch, in_height, in_width, in_channels]`.
|
||||
size: A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The
|
||||
new size for the images.
|
||||
paddings: A two-column matrix specifying the padding sizes. The number of
|
||||
rows must be the same as the rank of `input`.
|
||||
filter: 4-D with shape
|
||||
`[filter_height, filter_width, in_channels, out_channels]`.
|
||||
resize_align_corners: If true, rescale input by (new_height - 1) / (height - 1),
|
||||
which exactly aligns the 4 corners of images and resized images. If false, rescale
|
||||
by new_height / height. Treat similarly the width dimension.
|
||||
strides: 1-D of length 4. The stride of the sliding window for each dimension
|
||||
of `input`. Must be in the same order as the dimension specified with format.
|
||||
padding: The type of padding algorithm to use.
|
||||
)doc");
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
|
||||
REGISTER_OP("DepthwiseConv2dNative")
|
||||
|
@ -844,6 +844,71 @@ ops.RegisterShape("AvgPool")(common_shapes.avg_pool_shape)
|
||||
ops.RegisterShape("MaxPool")(common_shapes.max_pool_shape)
|
||||
|
||||
|
||||
@ops.RegisterShape("FusedResizeAndPadConv2D")
|
||||
def _FusedResizeAndPadConv2DShape(op):
|
||||
"""Shape function for FusedResizeAndPadConv2D op."""
|
||||
# The bilinear resize shape calculation.
|
||||
input_shape = op.inputs[0].get_shape().with_rank(4)
|
||||
unused_size_shape = op.inputs[1].get_shape().merge_with([2])
|
||||
size = tensor_util.constant_value(op.inputs[1])
|
||||
if size is not None:
|
||||
height = size[0]
|
||||
width = size[1]
|
||||
else:
|
||||
height = None
|
||||
width = None
|
||||
resized_shape = tensor_shape.TensorShape(
|
||||
[input_shape[0], height, width, input_shape[3]])
|
||||
|
||||
# Calculates the effect of the padding.
|
||||
paddings_shape = op.inputs[2].get_shape().with_rank(2)
|
||||
resized_shape = resized_shape.with_rank(paddings_shape[0].value)
|
||||
paddings_shape = paddings_shape.merge_with(
|
||||
tensor_shape.matrix(resized_shape.ndims, 2))
|
||||
paddings = tensor_util.constant_value(op.inputs[2])
|
||||
if paddings is None:
|
||||
padded_shape = tensor_shape.unknown_shape(ndims=resized_shape.ndims)
|
||||
else:
|
||||
output_dims = []
|
||||
for i, dim in enumerate(resized_shape.dims):
|
||||
if paddings[i, 0] < 0 or paddings[i, 1] < 0:
|
||||
raise ValueError("paddings must be non-negative")
|
||||
output_dims.append(dim + paddings[i, 0] + paddings[i, 1])
|
||||
padded_shape = tensor_shape.TensorShape(output_dims)
|
||||
|
||||
# Finally work out the convolution's effect.
|
||||
filter_shape = op.inputs[3].get_shape().with_rank(4)
|
||||
|
||||
batch_size = padded_shape[0]
|
||||
in_rows = padded_shape[1]
|
||||
in_cols = padded_shape[2]
|
||||
|
||||
filter_rows = filter_shape[0]
|
||||
filter_cols = filter_shape[1]
|
||||
depth_out = filter_shape[3]
|
||||
# Check that the input depths are compatible.
|
||||
padded_shape[3].assert_is_compatible_with(filter_shape[2])
|
||||
|
||||
stride_b, stride_r, stride_c, stride_d = op.get_attr("strides")
|
||||
|
||||
if stride_b != 1 or stride_d != 1:
|
||||
raise ValueError("Current implementation does not yet support "
|
||||
"strides in the batch and depth dimensions.")
|
||||
# TODO(mrry,shlens): Raise an error if the stride would cause
|
||||
# information in the input to be ignored. This will require a change
|
||||
# in the kernel implementation.
|
||||
padding = op.get_attr("padding")
|
||||
out_rows, out_cols = common_shapes.get2d_conv_output_size(in_rows, in_cols,
|
||||
filter_rows,
|
||||
filter_cols,
|
||||
stride_r,
|
||||
stride_c,
|
||||
padding)
|
||||
|
||||
output_shape = [batch_size, out_rows, out_cols, depth_out]
|
||||
return [tensor_shape.TensorShape(output_shape)]
|
||||
|
||||
|
||||
@ops.RegisterShape("MaxPoolWithArgmax")
|
||||
def _MaxPoolWithArgMaxShape(op):
|
||||
"""Shape function for MaxPoolWithArgmax op."""
|
||||
|
@ -27,6 +27,8 @@ the network is used only for inference. These include:
|
||||
|
||||
- Folding batch normalization ops into the pre-calculated weights.
|
||||
|
||||
- Fusing common operations into unified versions.
|
||||
|
||||
This script takes a frozen GraphDef file (where the weight variables have been
|
||||
converted into constants by the freeze_graph script) and outputs a new GraphDef
|
||||
with the optimizations applied.
|
||||
|
@ -27,6 +27,8 @@ the network is used only for inference. These include:
|
||||
|
||||
- Folding batch normalization ops into the pre-calculated weights.
|
||||
|
||||
- Fusing common operations into unified versions.
|
||||
|
||||
This script takes a frozen GraphDef file (where the weight variables have been
|
||||
converted into constants by the freeze_graph script) and outputs a new GraphDef
|
||||
with the optimizations applied.
|
||||
@ -37,8 +39,8 @@ bazel build tensorflow/python/tools:optimize_for_inference && \
|
||||
bazel-bin/tensorflow/python/tools/optimize_for_inference \
|
||||
--input_graph=some_graph_def.pb \
|
||||
--output_graph=/tmp/optimized_graph.pb \
|
||||
--input_node_names=Mul
|
||||
--output_node_names=softmax
|
||||
--input_names=Mul \
|
||||
--output_names=softmax
|
||||
|
||||
"""
|
||||
|
||||
@ -74,13 +76,42 @@ def optimize_for_inference(input_graph_def, input_node_names,
|
||||
Returns:
|
||||
An optimized version of the input graph.
|
||||
"""
|
||||
stripped_graph_def = strip_unused_lib.strip_unused(input_graph_def,
|
||||
input_node_names,
|
||||
output_node_names,
|
||||
placeholder_type_enum)
|
||||
detrained_graph_def = graph_util.remove_training_nodes(stripped_graph_def)
|
||||
folded_graph_def = fold_batch_norms(detrained_graph_def)
|
||||
return folded_graph_def
|
||||
ensure_graph_is_valid(input_graph_def)
|
||||
optimized_graph_def = input_graph_def
|
||||
optimized_graph_def = strip_unused_lib.strip_unused(optimized_graph_def,
|
||||
input_node_names,
|
||||
output_node_names,
|
||||
placeholder_type_enum)
|
||||
optimized_graph_def = graph_util.remove_training_nodes(optimized_graph_def)
|
||||
optimized_graph_def = fold_batch_norms(optimized_graph_def)
|
||||
optimized_graph_def = fuse_resize_and_conv(optimized_graph_def)
|
||||
ensure_graph_is_valid(optimized_graph_def)
|
||||
return optimized_graph_def
|
||||
|
||||
|
||||
def ensure_graph_is_valid(graph_def):
|
||||
"""Makes sure that the graph is internally consistent.
|
||||
|
||||
Checks basic properties of the graph def and raises an exception if there are
|
||||
input references to missing nodes, duplicated names, or other logic errors.
|
||||
|
||||
Args:
|
||||
graph_def: Definition of a graph to be checked.
|
||||
|
||||
Raises:
|
||||
ValueError: If the graph is incorrectly constructed.
|
||||
"""
|
||||
node_map = {}
|
||||
for node in graph_def.node:
|
||||
if node.name not in node_map.keys():
|
||||
node_map[node.name] = node
|
||||
else:
|
||||
raise ValueError("Duplicate node names detected for ", node.name)
|
||||
for node in graph_def.node:
|
||||
for input_name in node.input:
|
||||
input_node_name = node_name_from_input(input_name)
|
||||
if input_node_name not in node_map.keys():
|
||||
raise ValueError("Input for ", node.name, " not found: ", input_name)
|
||||
|
||||
|
||||
def node_name_from_input(node_name):
|
||||
@ -161,7 +192,7 @@ def fold_batch_norms(input_graph_def):
|
||||
if node.name not in input_node_map.keys():
|
||||
input_node_map[node.name] = node
|
||||
else:
|
||||
raise ValueError("Duplicate node names detected.")
|
||||
raise ValueError("Duplicate node names detected for ", node.name)
|
||||
|
||||
nodes_to_skip = {}
|
||||
new_ops = []
|
||||
@ -303,3 +334,94 @@ def fold_batch_norms(input_graph_def):
|
||||
|
||||
result_graph_def.node.extend(new_ops)
|
||||
return result_graph_def
|
||||
|
||||
|
||||
def fuse_resize_and_conv(input_graph_def):
|
||||
"""Merges preceding resize and mirror pad ops into a specialized convolution.
|
||||
|
||||
There's a common pattern of enlarging the input to a convolution using a
|
||||
resize operation, and also using MirrorPad to extend the boundaries to that
|
||||
zero edge pixels don't bleed inwards when convolving. This routine looks for
|
||||
that pattern of operations, and fuses them together into a Conv2DWithResizeOp.
|
||||
|
||||
Args:
|
||||
input_graph_def: A GraphDef containing a model.
|
||||
|
||||
Returns:
|
||||
Modified graph with resize and pad ops merged.
|
||||
|
||||
Raises:
|
||||
ValueError: If the graph is badly formed with duplicate node names.
|
||||
"""
|
||||
|
||||
input_node_map = {}
|
||||
for node in input_graph_def.node:
|
||||
if node.name not in input_node_map.keys():
|
||||
input_node_map[node.name] = node
|
||||
else:
|
||||
raise ValueError("Duplicate node names detected for ", node.name)
|
||||
|
||||
nodes_to_skip = {}
|
||||
new_ops = []
|
||||
for node in input_graph_def.node:
|
||||
|
||||
if node.op != "Conv2D":
|
||||
continue
|
||||
conv_op = node
|
||||
|
||||
input_op = node_from_map(input_node_map, conv_op.input[0])
|
||||
if input_op.op == "MirrorPad":
|
||||
mirror_pad_op = input_op
|
||||
resize_op = node_from_map(input_node_map, mirror_pad_op.input[0])
|
||||
else:
|
||||
mirror_pad_op = None
|
||||
resize_op = input_op
|
||||
|
||||
if resize_op.op != "ResizeBilinear":
|
||||
continue
|
||||
|
||||
nodes_to_skip[conv_op.name] = True
|
||||
if mirror_pad_op:
|
||||
nodes_to_skip[mirror_pad_op.name] = True
|
||||
nodes_to_skip[resize_op.name] = True
|
||||
|
||||
fused_conv_op = tf.NodeDef()
|
||||
fused_conv_op.op = "FusedResizeAndPadConv2D"
|
||||
fused_conv_op.name = conv_op.name
|
||||
if mirror_pad_op:
|
||||
mirror_paddings_name = mirror_pad_op.input[1]
|
||||
mirror_paddings_mode = mirror_pad_op.attr["mode"]
|
||||
else:
|
||||
# If there was no MirrorPad op, then create settings that make the padding
|
||||
# stage of the fused operation a no-op.
|
||||
paddings_op = tf.NodeDef()
|
||||
paddings_op.op = "Const"
|
||||
paddings_op.name = conv_op.name + "_dummy_paddings"
|
||||
paddings_op.attr["dtype"].CopyFrom(tf.AttrValue(
|
||||
type=tf.int32.as_datatype_enum))
|
||||
paddings_op.attr["value"].CopyFrom(tf.AttrValue(
|
||||
tensor=tensor_util.make_tensor_proto(
|
||||
[0, 0, 0, 0, 0, 0, 0, 0], tf.int32, [4, 2])))
|
||||
new_ops.extend([paddings_op])
|
||||
mirror_paddings_name = paddings_op.name
|
||||
mirror_paddings_mode = tf.AttrValue(s=b"REFLECT")
|
||||
fused_conv_op.input.extend([resize_op.input[0], resize_op.input[1],
|
||||
mirror_paddings_name, conv_op.input[1]])
|
||||
fused_conv_op.attr["T"].CopyFrom(conv_op.attr["T"])
|
||||
fused_conv_op.attr["resize_align_corners"].CopyFrom(
|
||||
resize_op.attr["align_corners"])
|
||||
fused_conv_op.attr["mode"].CopyFrom(mirror_paddings_mode)
|
||||
fused_conv_op.attr["strides"].CopyFrom(conv_op.attr["strides"])
|
||||
fused_conv_op.attr["padding"].CopyFrom(conv_op.attr["padding"])
|
||||
new_ops.extend([fused_conv_op])
|
||||
|
||||
result_graph_def = tf.GraphDef()
|
||||
for node in input_graph_def.node:
|
||||
if node.name in nodes_to_skip:
|
||||
continue
|
||||
new_node = tf.NodeDef()
|
||||
new_node.CopyFrom(node)
|
||||
result_graph_def.node.extend([new_node])
|
||||
|
||||
result_graph_def.node.extend(new_ops)
|
||||
return result_graph_def
|
||||
|
@ -54,6 +54,7 @@ class OptimizeForInferenceTest(tf.test.TestCase):
|
||||
shape=shape)))
|
||||
|
||||
def testOptimizeForInference(self):
|
||||
unused_constant_name = "unused_constant"
|
||||
unconnected_add_name = "unconnected_add"
|
||||
a_constant_name = "a_constant"
|
||||
b_constant_name = "b_constant"
|
||||
@ -64,9 +65,14 @@ class OptimizeForInferenceTest(tf.test.TestCase):
|
||||
add_name = "add"
|
||||
unused_output_add_name = "unused_output_add"
|
||||
graph_def = tf.GraphDef()
|
||||
unused_constant = self.create_constant_node_def(unused_constant_name,
|
||||
value=0,
|
||||
dtype=tf.float32,
|
||||
shape=[])
|
||||
graph_def.node.extend([unused_constant])
|
||||
unconnected_add_node = self.create_node_def("Add", unconnected_add_name,
|
||||
["no_such_node",
|
||||
"no_such_node"])
|
||||
[unused_constant_name,
|
||||
unused_constant_name])
|
||||
self.set_attr_dtype(unconnected_add_node, "T", tf.float32)
|
||||
graph_def.node.extend([unconnected_add_node])
|
||||
a_constant = self.create_constant_node_def(a_constant_name,
|
||||
@ -160,6 +166,65 @@ class OptimizeForInferenceTest(tf.test.TestCase):
|
||||
for node in optimized_graph_def.node:
|
||||
self.assertNotEqual("BatchNormWithGlobalNormalization", node.op)
|
||||
|
||||
def testFuseResizePadAndConv(self):
|
||||
with self.test_session() as sess:
|
||||
inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6]
|
||||
input_op = tf.constant(np.array(inputs), shape=[1, 2, 3, 2],
|
||||
dtype=tf.float32)
|
||||
resize_op = tf.image.resize_bilinear(input_op, [12, 4],
|
||||
align_corners=False)
|
||||
pad_op = tf.pad(resize_op, [[0, 0], [1, 1], [2, 2], [0, 0]],
|
||||
mode="REFLECT")
|
||||
weights = [1, 2, 3, 4, 0.1, 0.2, 0.3, 0.4]
|
||||
weights_op = tf.constant(np.array(weights), shape=[1, 2, 2, 2],
|
||||
dtype=tf.float32)
|
||||
tf.nn.conv2d(pad_op, weights_op, [1, 1, 1, 1],
|
||||
padding="VALID", name="output")
|
||||
original_graph_def = sess.graph_def
|
||||
original_result = sess.run(["output:0"])
|
||||
optimized_graph_def = optimize_for_inference_lib.fuse_resize_and_conv(
|
||||
original_graph_def)
|
||||
|
||||
with self.test_session() as sess:
|
||||
_ = tf.import_graph_def(optimized_graph_def, input_map={},
|
||||
name="optimized")
|
||||
optimized_result = sess.run(["optimized/output:0"])
|
||||
|
||||
self.assertAllClose(original_result, optimized_result)
|
||||
|
||||
for node in optimized_graph_def.node:
|
||||
self.assertNotEqual("Conv2D", node.op)
|
||||
self.assertNotEqual("MirrorPad", node.op)
|
||||
self.assertNotEqual("ResizeBilinear", node.op)
|
||||
|
||||
def testFuseResizeAndConv(self):
|
||||
with self.test_session() as sess:
|
||||
inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6]
|
||||
input_op = tf.constant(np.array(inputs), shape=[1, 2, 3, 2],
|
||||
dtype=tf.float32)
|
||||
resize_op = tf.image.resize_bilinear(input_op, [12, 4],
|
||||
align_corners=False)
|
||||
weights = [1, 2, 3, 4, 0.1, 0.2, 0.3, 0.4]
|
||||
weights_op = tf.constant(np.array(weights), shape=[1, 2, 2, 2],
|
||||
dtype=tf.float32)
|
||||
tf.nn.conv2d(resize_op, weights_op, [1, 1, 1, 1],
|
||||
padding="VALID", name="output")
|
||||
original_graph_def = sess.graph_def
|
||||
original_result = sess.run(["output:0"])
|
||||
optimized_graph_def = optimize_for_inference_lib.fuse_resize_and_conv(
|
||||
original_graph_def)
|
||||
|
||||
with self.test_session() as sess:
|
||||
_ = tf.import_graph_def(optimized_graph_def, input_map={},
|
||||
name="optimized")
|
||||
optimized_result = sess.run(["optimized/output:0"])
|
||||
|
||||
self.assertAllClose(original_result, optimized_result)
|
||||
|
||||
for node in optimized_graph_def.node:
|
||||
self.assertNotEqual("Conv2D", node.op)
|
||||
self.assertNotEqual("ResizeBilinear", node.op)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tf.test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user