diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt index bc9838ec743..0478adab2a9 100644 --- a/tensorflow/contrib/makefile/tf_op_files.txt +++ b/tensorflow/contrib/makefile/tf_op_files.txt @@ -101,6 +101,7 @@ tensorflow/core/kernels/cwise_op_div.cc tensorflow/core/kernels/cwise_op_add.cc tensorflow/core/kernels/ctc_decoder_ops.cc tensorflow/core/kernels/conv_ops_using_gemm.cc +tensorflow/core/kernels/conv_ops_fused.cc tensorflow/core/kernels/conv_ops.cc tensorflow/core/kernels/conv_grad_ops.cc tensorflow/core/kernels/control_flow_ops.cc diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index adbe5545074..6a1967eaf57 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -491,6 +491,27 @@ tf_cc_test( ], ) +tf_cc_test( + name = "conv_ops_test", + size = "small", + deps = [ + ":conv_ops", + ":image", + ":ops_testutil", + ":ops_util", + "//tensorflow/cc:cc_ops", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + tf_cc_test( name = "example_parsing_ops_test", size = "large", @@ -1325,6 +1346,7 @@ tf_kernel_library( hdrs = [ "conv_grad_ops.h", "deep_conv2d.h", + "gemm_functors.h", "winograd_transform.h", ], prefix = "conv_ops", @@ -1332,6 +1354,7 @@ tf_kernel_library( ":bounds_check", ":conv_2d", ":conv_3d", + ":image_resizer_state", ":ops_util", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", diff --git a/tensorflow/core/kernels/conv_ops.h b/tensorflow/core/kernels/conv_ops.h index d09db3dc15f..858be520b07 100644 --- a/tensorflow/core/kernels/conv_ops.h +++ b/tensorflow/core/kernels/conv_ops.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_KERNELS_CONV_OPS_H_ #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/util/tensor_format.h" #if GOOGLE_CUDA @@ -38,6 +39,16 @@ class LaunchConv2DOp { TensorFormat data_format); }; +// Used to keep track of persistent memory buffers used within the op. +template <class T, size_t size> +struct Im2ColBufferResource : public ResourceBase { + // This mutex ensures that only a single operation at a time is able to use + // the buffer memory held by this resource. + mutex mu; + T data[size]; + string DebugString() { return "Im2ColBufferResource"; } +}; + #ifdef GOOGLE_CUDA template <typename T> class LaunchConv2DOp<Eigen::GpuDevice, T> { diff --git a/tensorflow/core/kernels/conv_ops_fused.cc b/tensorflow/core/kernels/conv_ops_fused.cc new file mode 100644 index 00000000000..865021405ac --- /dev/null +++ b/tensorflow/core/kernels/conv_ops_fused.cc @@ -0,0 +1,486 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// Implements convolution operations with other kernels baked into the +// processing, to optimize latency and memory usage. + +#include <string.h> +#include <map> +#include <vector> +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_slice.h" +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/kernels/conv_ops.h" +#include "tensorflow/core/kernels/gemm_functors.h" +#include "tensorflow/core/kernels/image_resizer_state.h" +#include "tensorflow/core/util/mirror_pad_mode.h" +#include "tensorflow/core/util/padding.h" +#include "tensorflow/core/util/tensor_format.h" + +namespace tensorflow { + +namespace { + +// Combines bilinear resizing and mirror padding into the im2col transformation +// stage of convolution, +template <class T1, class T2, class T3, class TGemmFunctor> +class FusedResizeAndPadConvFunctor { + public: + void operator()(OpKernelContext* context, const Tensor& input, + int input_batches, int resized_height, int resized_width, + int padded_height, int padded_width, int input_depth, + const T2* filter_data, int filter_height, int filter_width, + int filter_count, int stride_rows, int stride_cols, + Padding padding, T3* output_data, int output_height, + int output_width, const ImageResizerState& st, + int top_padding, int bottom_padding, int left_padding, + int right_padding, int pad_offset) { + if ((input_batches <= 0) || (padded_width <= 0) || (padded_height <= 0) || + (input_depth <= 0)) { + LOG(WARNING) << "Conv2D was called with bad input dimensions: " + << input_batches << ", " << padded_height << ", " + << padded_width << ", " << input_depth; + return; + } + if ((filter_width <= 0) || (filter_height <= 0) || (filter_count <= 0)) { + LOG(WARNING) << "Conv2D was called with bad filter dimensions: " + << filter_width << ", " << filter_height << ", " + << filter_count; + return; + } + if ((output_width <= 0) || (output_height <= 0)) { + LOG(WARNING) << "Conv2D was called with bad output width or height: " + << output_width << ", " << output_height; + return; + } + + // These calculations define how the patches will be positioned within the + // input image. The actual definitions are quite complex, and rely on the + // previously-calculated output size. + int filter_left_offset; + int filter_top_offset; + if (padding == VALID) { + filter_left_offset = + ((output_width - 1) * stride_cols + filter_width - padded_width + 1) / + 2; + filter_top_offset = ((output_height - 1) * stride_rows + filter_height - + padded_height + 1) / + 2; + } else { + filter_left_offset = + ((output_width - 1) * stride_cols + filter_width - padded_width) / 2; + filter_top_offset = + ((output_height - 1) * stride_rows + filter_height - padded_height) / + 2; + } + + // The im2col buffer has # of patches rows, and # of filters cols. 
+ // It's laid out like this, in row major order in memory: + // < filter value count > + // ^ +---------------------+ + // patch | | + // count | | + // v +---------------------+ + // Each patch row contains a filter_width x filter_height patch of the + // input, with the depth channel as the most contiguous in memory, followed + // by the width, then the height. This is the standard memory order in the + // image world if it helps to visualize it. + const int filter_value_count = filter_width * filter_height * input_depth; + + // We don't want to allocate a buffer to hold all the patches if the size is + // going to be extremely large, so break it into chunks if it's bigger than + // a limit. Each chunk will be processed serially, so we can refill the + // buffer for the next chunk and reuse it, keeping maximum memory size down. + // In this case, we've picked 16 megabytes as a reasonable limit. + const size_t max_chunk_size = (16 * 1024 * 1024); + OP_REQUIRES(context, (filter_value_count * sizeof(T1)) <= max_chunk_size, + errors::InvalidArgument("Im2Col patch too large for buffer")); + const size_t patches_per_chunk = + max_chunk_size / (filter_value_count * sizeof(T1)); + // Because memory allocation is very expensive on mobile platforms, try to + // allocate a persistent buffer that will be kept around between calls. We + // use TensorFlow's resource management to ensure that the memory will be + // released when the session is over. + Im2ColBufferResource<T1, max_chunk_size>* im2col_buffer_resource; + std::function<Status(Im2ColBufferResource<T1, max_chunk_size>**)> creator = + [](Im2ColBufferResource<T1, max_chunk_size>** resource) { + *resource = new Im2ColBufferResource<T1, max_chunk_size>(); + return Status::OK(); + }; + OP_REQUIRES_OK(context, context->resource_manager()->LookupOrCreate( + "Conv2d", "im2col_buffer", + &im2col_buffer_resource, creator)); + // This means that multiple ops can't be run simultaneously on different + // threads, because we have a single shared resource. The platforms this is + // aimed at have intra-op parallelism as their focus though, so it shouldn't + // be an issue. 
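+    // Hold the buffer's lock for the remainder of this call, and drop the
+    // reference that LookupOrCreate added once the op completes.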
+ mutex_lock lock_buffer(im2col_buffer_resource->mu); + core::ScopedUnref unref_buffer(im2col_buffer_resource); + T1* im2col_buffer = im2col_buffer_resource->data; + + typename TTypes<T1, 4>::ConstTensor input_data = input.tensor<T1, 4>(); + + for (int batch = 0; batch < input_batches; ++batch) { + for (int out_y = 0; out_y < output_height; ++out_y) { + const int in_y_origin = (out_y * stride_rows) - filter_top_offset; + for (int out_x = 0; out_x < output_width; ++out_x) { + const int in_x_origin = (out_x * stride_cols) - filter_left_offset; + const int patch_index = (batch * output_width * output_height) + + (out_y * output_width) + out_x; + const int patch_index_within_chunk = patch_index % patches_per_chunk; + T1* im2col_patch_start = + im2col_buffer + (patch_index_within_chunk * filter_value_count); + for (int filter_y = 0; filter_y < filter_height; ++filter_y) { + const int conv_in_y = in_y_origin + filter_y; + float in_y = (conv_in_y - top_padding); + if (in_y < 0) { + in_y = -(in_y + 1.0f - pad_offset); + } else if (in_y >= resized_height) { + in_y = (resized_height * 2.0f) - (in_y + 1.0f + pad_offset); + } + in_y *= st.height_scale; + const int64 top_y_index = static_cast<int64>(std::floor(in_y)); + const int64 bottom_y_index = std::min( + static_cast<int64>(std::ceil(in_y)), (st.in_height - 1)); + const T1 y_lerp = in_y - top_y_index; + T1* im2col_row_start = + im2col_patch_start + (filter_y * filter_width * input_depth); + for (int filter_x = 0; filter_x < filter_width; ++filter_x) { + const int conv_in_x = in_x_origin + filter_x; + float in_x = (conv_in_x - left_padding); + if (in_x < 0) { + in_x = -(in_x + 1.0f - pad_offset); + } else if (in_x >= resized_width) { + in_x = (resized_width * 2.0f) - (in_x + 1.0f + pad_offset); + } + in_x *= st.width_scale; + const int64 left_x_index = static_cast<int64>(std::floor(in_x)); + const int64 right_x_index = std::min( + static_cast<int64>(std::ceil(in_x)), (st.in_width - 1)); + const T1 x_lerp = in_x - left_x_index; + T1* im2col_row_pixel = + im2col_row_start + (filter_x * input_depth); + for (int in_channel = 0; in_channel < input_depth; ++in_channel) { + T1 in_value; + if ((conv_in_x >= 0) && (conv_in_x < padded_width) && + (conv_in_y >= 0) && (conv_in_y < padded_height)) { + const T1 top_left( + input_data(batch, top_y_index, left_x_index, in_channel)); + const T1 top_right(input_data(batch, top_y_index, + right_x_index, in_channel)); + const T1 bottom_left(input_data(batch, bottom_y_index, + left_x_index, in_channel)); + const T1 bottom_right(input_data(batch, bottom_y_index, + right_x_index, in_channel)); + const T1 top = top_left + (top_right - top_left) * x_lerp; + const T1 bottom = + bottom_left + (bottom_right - bottom_left) * x_lerp; + in_value = top + (bottom - top) * y_lerp; + } else { + in_value = T1(0); + } + im2col_row_pixel[in_channel] = in_value; + } + } + } + const bool is_last_in_chunk = + (patch_index_within_chunk == (patches_per_chunk - 1)); + const bool is_last_overall = + ((batch == (input_batches - 1)) && + (out_y == (output_height - 1)) && (out_x == (output_width - 1))); + if (is_last_in_chunk || is_last_overall) { + // Now we've assembled a set of image patches into a matrix, apply a + // GEMM matrix multiply of the patches as rows, times the filter + // weights in columns, to get partial results in the output matrix. 
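+            // Dimensions for the row-major GEMM call below: the buffered
+            // patches form an (m x k) matrix, the filter weights a (k x n)
+            // matrix, and the (m x n) product lands directly in the output.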
+ const int how_many_patches = patch_index_within_chunk + 1; + const int m = how_many_patches; + const int n = filter_count; + const int k = filter_value_count; + const int lda = filter_value_count; + const int ldb = filter_count; + const int ldc = filter_count; + const size_t start_patch_index = + patch_index - (how_many_patches - 1); + T3* chunk_output_data = + output_data + (start_patch_index * filter_count); + TGemmFunctor gemm_functor; + gemm_functor(m, n, k, im2col_buffer, lda, filter_data, ldb, + chunk_output_data, ldc); + } + } + } + } + } +}; + +} // namespace + +// Implements a version of convolution with bilinear resizing and mirror padding +// included. +template <class T, class TConvFunctor> +class FusedResizeConv2DUsingGemmOp : public OpKernel { + public: + explicit FusedResizeConv2DUsingGemmOp(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES_OK(context, + context->GetAttr("resize_align_corners", &align_corners_)); + MirrorPadMode mode; + OP_REQUIRES_OK(context, context->GetAttr("mode", &mode)); + + switch (mode) { + case MirrorPadMode::SYMMETRIC: { + offset_ = 0; + break; + } + case MirrorPadMode::REFLECT: { + offset_ = 1; + break; + } + default: + OP_REQUIRES(context, false, + errors::InvalidArgument( + "mode must be either REFLECT or SYMMETRIC.")); + } + OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_)); + OP_REQUIRES(context, strides_.size() == 4, + errors::InvalidArgument("Sliding window strides field must " + "specify 4 dimensions")); + const int64 stride_n = GetTensorDim(strides_, FORMAT_NHWC, 'N'); + const int64 stride_c = GetTensorDim(strides_, FORMAT_NHWC, 'C'); + OP_REQUIRES( + context, stride_n == 1 && stride_c == 1, + errors::InvalidArgument("Current implementation does not yet support " + "strides in the batch and depth dimensions.")); + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + } + + void Compute(OpKernelContext* context) override { + // Input tensor is of the following dimensions: + // [ batch, in_rows, in_cols, in_depth ] + const Tensor& input = context->input(0); + OP_REQUIRES(context, (input.shape().num_elements() > 0), + errors::InvalidArgument("Input tensor can't be empty")); + + ImageResizerState st(align_corners_); + st.ValidateAndCalculateOutputSize(context, input); + if (!context->status().ok()) return; + const TensorShape resized_shape( + {input.dim_size(0), st.out_height, st.out_width, input.dim_size(3)}); + + const Tensor& paddings = context->input(2); + + const int dims = resized_shape.dims(); + OP_REQUIRES( + context, TensorShapeUtils::IsMatrix(paddings.shape()) && + paddings.dim_size(1) == 2, + errors::InvalidArgument("paddings must be a matrix with 2 columns: ", + paddings.shape().DebugString())); + const int fixed_dims = + (allow_legacy_scalars() && dims == 0 && paddings.dim_size(0) == 1) + ? 
1
+            : dims;
+    OP_REQUIRES(
+        context, fixed_dims == paddings.dim_size(0),
+        errors::InvalidArgument(
+            "The first dimension of paddings must be the rank of inputs: ",
+            fixed_dims, " ", paddings.shape().DebugString(), " ",
+            resized_shape.DebugString()));
+    OP_REQUIRES(
+        context, dims == paddings.dim_size(0),
+        errors::InvalidArgument(
+            "The first dimension of paddings must be the rank of inputs: ",
+            dims, " ", paddings.shape().DebugString(), " ",
+            resized_shape.DebugString()));
+
+    OP_REQUIRES(
+        context, dims == 4,
+        errors::InvalidArgument(
+            "Fused mirror padding only supports four-dimensional inputs, but ",
+            dims, " requested"));
+
+    // Compute the shape of the output tensor, and allocate it.
+    TensorShape padded_shape;
+    TTypes<int32>::ConstMatrix paddings_matrix = paddings.matrix<int32>();
+    for (int d = 0; d < dims; ++d) {
+      const int32 before =
+          paddings_matrix(d, 0);  // Pad before existing elements.
+      const int32 after =
+          paddings_matrix(d, 1);  // Pad after existing elements.
+      OP_REQUIRES(context, before >= 0 && after >= 0,
+                  errors::InvalidArgument("paddings must be non-negative: ",
+                                          before, " ", after));
+      if (offset_ == 0) {  // SYMMETRIC mode.
+        OP_REQUIRES(
+            context, before <= resized_shape.dim_size(d) &&
+                         after <= resized_shape.dim_size(d),
+            errors::InvalidArgument("paddings must be no greater "
+                                    "than the dimension size: ",
+                                    before, ", ", after, " greater than ",
+                                    resized_shape.dim_size(d)));
+      } else if (offset_ == 1) {  // REFLECT mode.
+        OP_REQUIRES(
+            context, before < resized_shape.dim_size(d) &&
+                         after < resized_shape.dim_size(d),
+            errors::InvalidArgument("paddings must be less than"
+                                    " the dimension size: ",
+                                    before, ", ", after, " not less than ",
+                                    resized_shape.dim_size(d)));
+      }
+      padded_shape.AddDim(before + resized_shape.dim_size(d) + after);
+    }
+
+    OP_REQUIRES(
+        context, ((paddings_matrix(0, 0) == 0) && (paddings_matrix(0, 1) == 0)),
+        errors::InvalidArgument(
+            "Fused mirror padding only supports spatial padding, not batches: ",
+            paddings.DebugString()));
+    OP_REQUIRES(
+        context, ((paddings_matrix(3, 0) == 0) && (paddings_matrix(3, 1) == 0)),
+        errors::InvalidArgument(
+            "Fused mirror padding only supports spatial padding, not channels: ",
+            paddings.DebugString()));
+    const int32 top_padding = paddings_matrix(1, 0);
+    const int32 bottom_padding = paddings_matrix(1, 1);
+    const int32 left_padding = paddings_matrix(2, 0);
+    const int32 right_padding = paddings_matrix(2, 1);
+
+    // Input filter is of the following dimensions:
+    // [ filter_rows, filter_cols, in_depth, out_depth]
+    const Tensor& filter = context->input(3);
+
+    // For 2D convolution, there should be 4 dimensions.
+    OP_REQUIRES(context, padded_shape.dims() == 4,
+                errors::InvalidArgument("input must be 4-dimensional",
+                                        padded_shape.DebugString()));
+    OP_REQUIRES(context, filter.dims() == 4,
+                errors::InvalidArgument("filter must be 4-dimensional: ",
+                                        filter.shape().DebugString()));
+
+    // We only check the first three dims, since the depth is accessed as an
+    // int64 below.
+    for (int i = 0; i < 3; i++) {
+      OP_REQUIRES(context, FastBoundsCheck(filter.dim_size(i),
+                                           std::numeric_limits<int>::max()),
+                  errors::InvalidArgument("filter too large"));
+    }
+
+    // The last dimension for input is in_depth. It must be the same as the
+    // filter's in_depth.
+ const int64 in_depth = padded_shape.dim_size(3); + OP_REQUIRES( + context, in_depth == filter.dim_size(2), + errors::InvalidArgument("input and filter must have the same depth: ", + in_depth, " vs ", filter.dim_size(2))); + + // The last dimension for filter is out_depth. + const int out_depth = static_cast<int>(filter.dim_size(3)); + + // The second dimension for input is rows/height. + // The first dimension for filter is rows/height. + const int64 padded_rows_raw = padded_shape.dim_size(1); + OP_REQUIRES(context, FastBoundsCheck(padded_rows_raw, + std::numeric_limits<int>::max()), + errors::InvalidArgument("Input rows too large")); + const int padded_rows = static_cast<int>(padded_rows_raw); + const int filter_rows = static_cast<int>(filter.dim_size(0)); + const int resized_rows = static_cast<int>(resized_shape.dim_size(1)); + + // The third dimension for input is columns/width. + // The second dimension for filter is columns/width. + const int64 padded_cols_raw = padded_shape.dim_size(2); + OP_REQUIRES(context, FastBoundsCheck(padded_cols_raw, + std::numeric_limits<int>::max()), + errors::InvalidArgument("Input cols too large")); + const int padded_cols = static_cast<int>(padded_cols_raw); + const int filter_cols = static_cast<int>(filter.dim_size(1)); + const int resized_cols = static_cast<int>(resized_shape.dim_size(2)); + + // The first dimension for input is batch. + const int64 batch_raw = padded_shape.dim_size(0); + OP_REQUIRES(context, + FastBoundsCheck(batch_raw, std::numeric_limits<int>::max()), + errors::InvalidArgument("batch is too large")); + const int batch = static_cast<int>(batch_raw); + + // For now we take the stride from the second and third dimensions only (we + // do not support striding on the batch or depth dimension). + const int stride_rows = GetTensorDim(strides_, FORMAT_NHWC, 'H'); + const int stride_cols = GetTensorDim(strides_, FORMAT_NHWC, 'W'); + + int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0; + OP_REQUIRES_OK(context, + GetWindowedOutputSize(padded_rows, filter_rows, stride_rows, + padding_, &out_rows, &pad_rows)); + OP_REQUIRES_OK(context, + GetWindowedOutputSize(padded_cols, filter_cols, stride_cols, + padding_, &out_cols, &pad_cols)); + TensorShape out_shape = + ShapeFromFormat(FORMAT_NHWC, batch, out_rows, out_cols, out_depth); + OP_REQUIRES(context, (out_shape.num_elements() > 0), + errors::InvalidArgument("Output tensor can't be empty")); + + // Output tensor is of the following dimensions: + // [ in_batch, out_rows, out_cols, out_depth ] + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output)); + + VLOG(2) << "Conv2D: in_depth = " << in_depth + << ", padded_cols = " << padded_cols + << ", filter_cols = " << filter_cols + << ", padded_rows = " << padded_rows + << ", filter_rows = " << filter_rows + << ", stride_rows = " << stride_rows + << ", stride_cols = " << stride_cols + << ", out_depth = " << out_depth; + + // If there is nothing to compute, return. 
+ if (out_shape.num_elements() == 0) { + return; + } + TConvFunctor conv_functor; + conv_functor(context, input, batch, resized_rows, resized_cols, padded_rows, + padded_cols, in_depth, filter.flat<T>().data(), filter_rows, + filter_cols, out_depth, stride_rows, stride_cols, padding_, + output->flat<T>().data(), out_rows, out_cols, st, top_padding, + bottom_padding, left_padding, right_padding, offset_); + } + + private: + std::vector<int32> strides_; + Padding padding_; + bool align_corners_; + int offset_; + + TF_DISALLOW_COPY_AND_ASSIGN(FusedResizeConv2DUsingGemmOp); +}; + +#define REGISTER_FUSED(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("FusedResizeAndPadConv2D") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<T>("T"), \ + FusedResizeConv2DUsingGemmOp< \ + T, \ + FusedResizeAndPadConvFunctor<T, T, T, FastGemmFunctor<T, T, T>>>); + +TF_CALL_float(REGISTER_FUSED); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/conv_ops_test.cc b/tensorflow/core/kernels/conv_ops_test.cc new file mode 100644 index 00000000000..228f2d5defa --- /dev/null +++ b/tensorflow/core/kernels/conv_ops_test.cc @@ -0,0 +1,240 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/image_ops.h" +#include "tensorflow/cc/ops/nn_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/session.h" + +namespace tensorflow { + +class FusedResizePadConvOpTest : public OpsTestBase { + protected: + void HandwrittenConv() { + const int stride = 1; + TF_EXPECT_OK(NodeDefBuilder("fused_resize_op", "FusedResizeAndPadConv2D") + .Input(FakeInput(DT_FLOAT)) + .Input(FakeInput(DT_INT32)) + .Input(FakeInput(DT_INT32)) + .Input(FakeInput(DT_FLOAT)) + .Attr("T", DT_FLOAT) + .Attr("resize_align_corners", false) + .Attr("mode", "REFLECT") + .Attr("strides", {1, stride, stride, 1}) + .Attr("padding", "SAME") + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + const int depth = 1; + const int image_width = 4; + const int image_height = 3; + const int image_batch_count = 1; + // The image matrix is: + // | 1 | 2 | 3 | 4 | + // | 5 | 6 | 7 | 8 | + // | 9 | 10 | 11 | 12 | + Tensor image(DT_FLOAT, + {image_batch_count, image_height, image_width, depth}); + test::FillValues<float>(&image, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + + // The filter matrix is: + // | 1 | 4 | 7 | + // | 2 | 5 | 8 | + // | 3 | 6 | 9 | + const int filter_size = 3; + const int 
filter_count = 1; + Tensor filter(DT_FLOAT, {filter_size, filter_size, depth, filter_count}); + test::FillValues<float>(&filter, {1, 4, 7, 2, 5, 8, 3, 6, 9}); + + const int resized_width = image_width; + const int resized_height = image_height; + + const int top_padding = 0; + const int bottom_padding = 0; + const int left_padding = 0; + const int right_padding = 0; + + AddInputFromArray<float>(image.shape(), image.flat<float>()); + AddInputFromArray<int32>(TensorShape({2}), {resized_height, resized_width}); + AddInputFromArray<int32>( + TensorShape({4, 2}), + {0, 0, top_padding, bottom_padding, left_padding, right_padding, 0, 0}); + AddInputFromArray<float>(filter.shape(), filter.flat<float>()); + TF_ASSERT_OK(RunOpKernel()); + + // We're sliding the 3x3 filter across the 3x4 image, with accesses outside + // the input set to zero because we're using the 'SAME' padding mode. + // The calculations behind the expected output are: + // (1*0)+(4*0)+(7*0)+(2*0)+(5*1)+(8*2)+(3*0)+(6*5)+(9*6)=105 + // (1*0)+(4*0)+(7*0)+(2*1)+(5*2)+(8*3)+(3*5)+(6*6)+(9*7)=150 + // (1*0)+(4*0)+(7*0)+(2*2)+(5*3)+(8*4)+(3*6)+(6*7)+(9*8)=183 + // (1*0)+(4*0)+(7*0)+(2*3)+(5*4)+(8*0)+(3*7)+(6*8)+(9*0)=95 + // (1*0)+(4*1)+(7*2)+(2*0)+(5*5)+(8*6)+(3*0)+(6*9)+(9*10)=235 + // (1*1)+(4*2)+(7*3)+(2*5)+(5*6)+(8*7)+(3*9)+(6*10)+(9*11)=312 + // (1*2)+(4*3)+(7*4)+(2*6)+(5*7)+(8*8)+(3*10)+(6*11)+(9*12)=357 + // (1*3)+(4*4)+(7*0)+(2*7)+(5*8)+(8*0)+(3*11)+(6*12)+(9*0)=178 + // (1*0)+(4*5)+(7*6)+(2*0)+(5*9)+(8*10)+(3*0)+(6*0)+(9*0)=187 + // (1*5)+(4*6)+(7*7)+(2*9)+(5*10)+(8*11)+(3*0)+(6*0)+(9*0)=234 + // (1*6)+(4*7)+(7*8)+(2*10)+(5*11)+(8*12)+(3*0)+(6*0)+(9*0)=261 + // (1*7)+(4*11)+(7*0)+(2*8)+(5*12)+(8*0)+(3*0)+(6*0)+(9*0)=121 + // This means we should end up with this matrix: + // | 105 | 150 | 183 | 95 | + // | 235 | 312 | 357 | 178 | + // | 187 | 234 | 261 | 121 | + const int expected_width = image_width; + const int expected_height = image_height * filter_count; + Tensor expected(DT_FLOAT, TensorShape({image_batch_count, expected_height, + expected_width, filter_count})); + test::FillValues<float>( + &expected, {105, 150, 183, 95, 235, 312, 357, 178, 187, 234, 261, 121}); + const Tensor& output = *GetOutput(0); + test::ExpectTensorNear<float>(expected, output, 1e-5); + } + + void CompareFusedAndSeparate(int input_width, int input_height, + int input_depth, int resize_width, + int resize_height, int y_padding, int x_padding, + int filter_size, int filter_count, + bool resize_align_corners, string pad_mode, + int stride, string padding) { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + const size_t input_data_size = input_height * input_width * input_depth; + Tensor input_data(DT_FLOAT, + TensorShape({1, input_height, input_width, input_depth})); + for (int i = 0; i < input_data_size; ++i) { + input_data.flat<float>()(i) = i + 1.0f; + } + Output input = + Const(root.WithOpName("input"), Input::Initializer(input_data)); + + const size_t filter_data_size = + filter_size * filter_size * filter_count * input_depth; + Tensor filter_data(DT_FLOAT, TensorShape({filter_size, filter_size, + input_depth, filter_count})); + for (int i = 0; i < filter_data_size; ++i) { + filter_data.flat<float>()(i) = i + 1.0f; + } + Output filter = + Const(root.WithOpName("filter"), Input::Initializer(filter_data)); + + Output resize_size = + Const(root.WithOpName("resize_size"), {resize_height, resize_width}); + Output resize = + ResizeBilinear(root.WithOpName("resize"), input, 
resize_size, + ResizeBilinear::AlignCorners(resize_align_corners)); + Output paddings = + Const(root.WithOpName("paddings"), + {{0, 0}, {y_padding, y_padding}, {x_padding, x_padding}, {0, 0}}); + Output mirror_pad = + MirrorPad(root.WithOpName("mirror_pad"), resize, paddings, pad_mode); + Output conv = Conv2D(root.WithOpName("conv"), mirror_pad, filter, + {1, stride, stride, 1}, padding); + + Output fused_conv = FusedResizeAndPadConv2D( + root.WithOpName("fused_conv"), input, resize_size, paddings, filter, + pad_mode, {1, stride, stride, 1}, padding, + FusedResizeAndPadConv2D::ResizeAlignCorners(resize_align_corners)); + + tensorflow::GraphDef graph; + TF_ASSERT_OK(root.ToGraphDef(&graph)); + + std::unique_ptr<tensorflow::Session> session( + tensorflow::NewSession(tensorflow::SessionOptions())); + TF_ASSERT_OK(session->Create(graph)); + + std::vector<Tensor> unfused_tensors; + TF_ASSERT_OK(session->Run({}, {"conv"}, {}, &unfused_tensors)); + + std::vector<Tensor> fused_tensors; + TF_ASSERT_OK(session->Run({}, {"fused_conv"}, {}, &fused_tensors)); + + test::ExpectTensorNear<float>(unfused_tensors[0], fused_tensors[0], 1e-5); + } +}; + +TEST_F(FusedResizePadConvOpTest, HandwrittenConv) { HandwrittenConv(); } + +TEST_F(FusedResizePadConvOpTest, IdentityComparative) { + CompareFusedAndSeparate(10, 10, 1, 10, 10, 0, 0, 1, 1, false, "REFLECT", 1, + "SAME"); +} + +TEST_F(FusedResizePadConvOpTest, ConvOnlyComparative) { + CompareFusedAndSeparate(10, 10, 3, 10, 10, 0, 0, 4, 4, false, "REFLECT", 1, + "SAME"); +} + +TEST_F(FusedResizePadConvOpTest, ResizeOnlyComparative) { + CompareFusedAndSeparate(10, 10, 1, 20, 20, 0, 0, 1, 1, false, "REFLECT", 1, + "SAME"); +} + +TEST_F(FusedResizePadConvOpTest, ResizeAndConvComparative) { + CompareFusedAndSeparate(2, 2, 4, 4, 2, 0, 0, 2, 2, false, "REFLECT", 1, + "SAME"); +} + +TEST_F(FusedResizePadConvOpTest, ResizeAlignAndConvComparative) { + CompareFusedAndSeparate(2, 2, 4, 4, 2, 0, 0, 2, 2, true, "REFLECT", 1, + "SAME"); +} + +TEST_F(FusedResizePadConvOpTest, ResizeAndConvStridedComparative) { + CompareFusedAndSeparate(2, 2, 4, 4, 2, 0, 0, 2, 2, false, "REFLECT", 2, + "SAME"); +} + +TEST_F(FusedResizePadConvOpTest, ResizeAlignAndConvValidComparative) { + CompareFusedAndSeparate(2, 2, 4, 4, 2, 0, 0, 2, 2, true, "REFLECT", 1, + "VALID"); +} + +TEST_F(FusedResizePadConvOpTest, PadOnlyComparative) { + CompareFusedAndSeparate(4, 4, 1, 4, 4, 2, 2, 1, 1, false, "REFLECT", 1, + "SAME"); +} + +TEST_F(FusedResizePadConvOpTest, PadOnlyWithChannelsComparative) { + CompareFusedAndSeparate(4, 4, 3, 4, 4, 2, 2, 1, 1, false, "REFLECT", 1, + "SAME"); +} + +TEST_F(FusedResizePadConvOpTest, ResizeAndPadComparative) { + CompareFusedAndSeparate(4, 4, 1, 6, 6, 2, 2, 1, 1, false, "REFLECT", 1, + "SAME"); +} + +TEST_F(FusedResizePadConvOpTest, PadOnlySymmetricComparative) { + CompareFusedAndSeparate(4, 4, 1, 4, 4, 2, 2, 1, 1, false, "SYMMETRIC", 1, + "SAME"); +} + +TEST_F(FusedResizePadConvOpTest, ResizeAndPadSymmetricComparative) { + CompareFusedAndSeparate(4, 4, 3, 6, 6, 2, 2, 1, 1, false, "SYMMETRIC", 1, + "SAME"); +} + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/conv_ops_using_gemm.cc b/tensorflow/core/kernels/conv_ops_using_gemm.cc index c39510a11a2..6da6da846b4 100644 --- a/tensorflow/core/kernels/conv_ops_using_gemm.cc +++ b/tensorflow/core/kernels/conv_ops_using_gemm.cc @@ -56,14 +56,13 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_slice.h" #include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/kernels/conv_ops.h" +#include "tensorflow/core/kernels/gemm_functors.h" +#include "tensorflow/core/kernels/image_resizer_state.h" +#include "tensorflow/core/util/mirror_pad_mode.h" #include "tensorflow/core/util/padding.h" #include "tensorflow/core/util/tensor_format.h" -#if defined(__APPLE__) -#include <Accelerate/Accelerate.h> -#define USE_ACCELERATE_GEMM -#endif // __APPLE__ - namespace tensorflow { namespace { @@ -189,87 +188,6 @@ class ReferenceConvFunctor { } }; -// A readable but slow implementation of matrix multiplication, useful for -// debugging and understanding the algorithm. Use instead of FastGemmFunctor in -// the Im2ColConvFunctor template definition inside the op registration to -// enable. Assumes row-major ordering of the values in memory. -template <class T1, class T2, class T3> -class ReferenceGemmFunctor { - public: - void operator()(size_t m, size_t n, size_t k, const T1* a, size_t lda, - const T2* b, size_t ldb, T3* c, size_t ldc) { - const size_t a_i_stride = lda; - const size_t a_l_stride = 1; - const size_t b_j_stride = 1; - const size_t b_l_stride = ldb; - const size_t c_i_stride = ldc; - const size_t c_j_stride = 1; - size_t i, j, l; - for (j = 0; j < n; j++) { - for (i = 0; i < m; i++) { - T3 total(0); - for (l = 0; l < k; l++) { - const size_t a_index = ((i * a_i_stride) + (l * a_l_stride)); - const T1 a_value = a[a_index]; - const size_t b_index = ((j * b_j_stride) + (l * b_l_stride)); - const T2 b_value = b[b_index]; - total += (a_value * b_value); - } - const size_t c_index = ((i * c_i_stride) + (j * c_j_stride)); - c[c_index] = total; - } - } - } -}; - -// Uses the optimized Eigen library to implement the matrix multiplication -// required by the Im2ColConvFunctor class. We supply the two input and one -// output types so that the accumulator can potentially be higher-precision than -// the inputs, even though we don't currently take advantage of this. -template <class T1, class T2, class T3> -class FastGemmFunctor { - public: - // Convenience wrappers for the Eigen matrix types we'll be using. - typedef Eigen::Map< - const Eigen::Matrix<T1, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> - ConstMatrixT1; - typedef Eigen::Map< - const Eigen::Matrix<T2, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> - ConstMatrixT2; - typedef Eigen::Map< - Eigen::Matrix<T3, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> - MatrixT3; - void operator()(size_t m, size_t n, size_t k, const T1* a, size_t lda, - const T2* b, size_t ldb, T3* c, size_t ldc) { - ConstMatrixT1 a_matrix(a, m, k); - ConstMatrixT2 b_matrix(b, k, n); - MatrixT3 c_matrix(c, m, n); - c_matrix.noalias() = a_matrix * b_matrix; - } -}; - -// If we have Apple's Accelerate framework, use their implementation of GEMM to -// get a performance boost for float. -#if defined(USE_ACCELERATE_GEMM) -template <> -class FastGemmFunctor<float, float, float> { - public: - void operator()(size_t m, size_t n, size_t k, const float* a, size_t lda, - const float* b, size_t ldb, float* c, size_t ldc) { - cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, 1.0f, a, - lda, b, ldb, 0.0f, c, ldc); - } -}; -#endif // USE_ACCELERATE_GEMM - -// Used to keep track of persistent memory buffers used within the op. 
-template <class T, size_t size> -struct Im2ColBufferResource : public ResourceBase { - mutex mu; - T data[size]; - string DebugString() { return "Im2ColBufferResource"; } -}; - // Implements convolution as a two stage process, first packing the patches of // the input image into columns (im2col) and then running GEMM to produce the // final result. @@ -344,7 +262,6 @@ class Im2ColConvFunctor { errors::InvalidArgument("Im2Col patch too large for buffer")); const size_t patches_per_chunk = max_chunk_size / (filter_value_count * sizeof(T1)); - // Because memory allocation is very expensive on mobile platforms, try to // allocate a persistent buffer that will be kept around between calls. We // use TensorFlow's resource management to ensure that the memory will be diff --git a/tensorflow/core/kernels/gemm_functors.h b/tensorflow/core/kernels/gemm_functors.h new file mode 100644 index 00000000000..d37008d5cfb --- /dev/null +++ b/tensorflow/core/kernels/gemm_functors.h @@ -0,0 +1,105 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This is a set of different implementations for the basic matrix by matrix +// multiply function, commonly known as GEMM after the BLAS library's naming. +// Having a standard interface enables us to swap out implementations on +// different platforms, to make sure we're using the optimal version. They are +// implemented as C++ template functors, so they're easy to swap into all of the +// different kernels that use them. + +#include <string.h> +#include <map> +#include <vector> + +#include "tensorflow/core/framework/tensor.h" + +#if defined(__APPLE__) && defined(USE_GEMM_FOR_CONV) +#include <Accelerate/Accelerate.h> +#define USE_ACCELERATE_GEMM +#endif // __APPLE__ + +// A readable but slow implementation of matrix multiplication, useful for +// debugging and understanding the algorithm. Use instead of FastGemmFunctor in +// the Im2ColConvFunctor template definition inside the op registration to +// enable. Assumes row-major ordering of the values in memory. 
+template <class T1, class T2, class T3> +class ReferenceGemmFunctor { + public: + void operator()(size_t m, size_t n, size_t k, const T1* a, size_t lda, + const T2* b, size_t ldb, T3* c, size_t ldc) { + const size_t a_i_stride = lda; + const size_t a_l_stride = 1; + const size_t b_j_stride = 1; + const size_t b_l_stride = ldb; + const size_t c_i_stride = ldc; + const size_t c_j_stride = 1; + size_t i, j, l; + for (j = 0; j < n; j++) { + for (i = 0; i < m; i++) { + T3 total(0); + for (l = 0; l < k; l++) { + const size_t a_index = ((i * a_i_stride) + (l * a_l_stride)); + const T1 a_value = a[a_index]; + const size_t b_index = ((j * b_j_stride) + (l * b_l_stride)); + const T2 b_value = b[b_index]; + total += (a_value * b_value); + } + const size_t c_index = ((i * c_i_stride) + (j * c_j_stride)); + c[c_index] = total; + } + } + } +}; + +// Uses the optimized Eigen library to implement the matrix multiplication +// required by the Im2ColConvFunctor class. We supply the two input and one +// output types so that the accumulator can potentially be higher-precision than +// the inputs, even though we don't currently take advantage of this. +template <class T1, class T2, class T3> +class FastGemmFunctor { + public: + // Convenience wrappers for the Eigen matrix types we'll be using. + typedef Eigen::Map< + const Eigen::Matrix<T1, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> + ConstMatrixT1; + typedef Eigen::Map< + const Eigen::Matrix<T2, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> + ConstMatrixT2; + typedef Eigen::Map< + Eigen::Matrix<T3, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> + MatrixT3; + void operator()(size_t m, size_t n, size_t k, const T1* a, size_t lda, + const T2* b, size_t ldb, T3* c, size_t ldc) { + ConstMatrixT1 a_matrix(a, m, k); + ConstMatrixT2 b_matrix(b, k, n); + MatrixT3 c_matrix(c, m, n); + c_matrix.noalias() = a_matrix * b_matrix; + } +}; + +// If we have Apple's Accelerate framework, use their implementation of GEMM to +// get a performance boost for float. +#if defined(USE_ACCELERATE_GEMM) +template <> +class FastGemmFunctor<float, float, float> { + public: + void operator()(size_t m, size_t n, size_t k, const float* a, size_t lda, + const float* b, size_t ldb, float* c, size_t ldc) { + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, 1.0f, a, + lda, b, ldb, 0.0f, c, ldc); + } +}; +#endif // USE_ACCELERATE_GEMM diff --git a/tensorflow/core/kernels/image_resizer_state.h b/tensorflow/core/kernels/image_resizer_state.h index a7acb5e649b..8870937422a 100644 --- a/tensorflow/core/kernels/image_resizer_state.h +++ b/tensorflow/core/kernels/image_resizer_state.h @@ -49,12 +49,13 @@ struct ImageResizerState { explicit ImageResizerState(bool align_corners) : align_corners_(align_corners) {} - // ValidateAndCreateOutput checks the bounds on the input tensors + // ValidateAndCalculateOutputSize checks the bounds on the input tensors // and requested size, sets up some of the resizing state such as the - // height_scale and width_scale, and allocates the output. + // height_scale and width_scale, and calculates the output size. // If any of these operations fails, it sets an error status in // the context, which the caller must check. 
- void ValidateAndCreateOutput(OpKernelContext* context, const Tensor& input) { + void ValidateAndCalculateOutputSize(OpKernelContext* context, + const Tensor& input) { OP_REQUIRES(context, input.dims() == 4, errors::InvalidArgument("input must be 4-dimensional", input.shape().DebugString())); @@ -87,12 +88,18 @@ struct ImageResizerState { OP_REQUIRES( context, input.dim_size(1) > 0 && input.dim_size(2) > 0, errors::InvalidArgument("input image must be of non-zero size")); + height_scale = CalculateResizeScale(in_height, out_height, align_corners_); + width_scale = CalculateResizeScale(in_width, out_width, align_corners_); + } + + // Calculates all the required variables, and allocates the output. + void ValidateAndCreateOutput(OpKernelContext* context, const Tensor& input) { + ValidateAndCalculateOutputSize(context, input); + if (!context->status().ok()) return; OP_REQUIRES_OK(context, context->allocate_output( 0, TensorShape({input.dim_size(0), out_height, out_width, input.dim_size(3)}), &output)); - height_scale = CalculateResizeScale(in_height, out_height, align_corners_); - width_scale = CalculateResizeScale(in_width, out_width, align_corners_); } int64 batch_size; diff --git a/tensorflow/core/kernels/ops_testutil.h b/tensorflow/core/kernels/ops_testutil.h index eae5187896e..3baae914cbf 100644 --- a/tensorflow/core/kernels/ops_testutil.h +++ b/tensorflow/core/kernels/ops_testutil.h @@ -185,6 +185,7 @@ class OpsTestBase : public ::testing::Test { test::SetOutputAttrs(params_.get(), &attrs); checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_wrapper; params_.get()->slice_reader_cache = &slice_reader_cache_wrapper; + params_.get()->resource_manager = device_.get()->resource_manager(); context_.reset(new OpKernelContext(params_.get())); device_->Compute(kernel_.get(), context_.get()); diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index cc374278e7f..5daaf83133a 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/util/mirror_pad_mode.h" #include "tensorflow/core/util/padding.h" #include "tensorflow/core/util/tensor_format.h" @@ -425,6 +426,46 @@ data_format: Specify the data format of the input and output data. With the [batch, in_channels, in_height, in_width]. )doc"); +REGISTER_OP("FusedResizeAndPadConv2D") + .Input("input: T") + .Input("size: int32") + .Input("paddings: int32") + .Input("filter: T") + .Output("output: T") + .Attr("T: {half, float, double}") + .Attr("resize_align_corners: bool = false") + .Attr(GetMirrorPadModeAttrString()) + .Attr("strides: list(int)") + .Attr(GetPaddingAttrString()) + .Doc(R"doc( +Performs a resize and padding as a preprocess during a convolution. + +It's often possible to do spatial transformations more efficiently as part of +the packing stage of a convolution, so this op allows for an optimized +implementation where these stages are fused together. This prevents the need to +write out the intermediate results as whole tensors, reducing memory pressure, +and we can get some latency gains by merging the transformation calculations. +The data_format attribute for Conv2D isn't supported by this op, and defaults to +'NHWC' order. +Internally this op uses a single per-graph scratch buffer, which means that it +will block if multiple versions are being run in parallel. 
This is because this +operator is primarily an optimization to minimize memory usage. + +input: 4-D with shape `[batch, in_height, in_width, in_channels]`. +size: A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The + new size for the images. +paddings: A two-column matrix specifying the padding sizes. The number of + rows must be the same as the rank of `input`. +filter: 4-D with shape + `[filter_height, filter_width, in_channels, out_channels]`. +resize_align_corners: If true, rescale input by (new_height - 1) / (height - 1), + which exactly aligns the 4 corners of images and resized images. If false, rescale + by new_height / height. Treat similarly the width dimension. +strides: 1-D of length 4. The stride of the sliding window for each dimension + of `input`. Must be in the same order as the dimension specified with format. +padding: The type of padding algorithm to use. + )doc"); + // -------------------------------------------------------------------------- REGISTER_OP("DepthwiseConv2dNative") diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index e14ca1a5593..7bc3ffe25b3 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -844,6 +844,71 @@ ops.RegisterShape("AvgPool")(common_shapes.avg_pool_shape) ops.RegisterShape("MaxPool")(common_shapes.max_pool_shape) +@ops.RegisterShape("FusedResizeAndPadConv2D") +def _FusedResizeAndPadConv2DShape(op): + """Shape function for FusedResizeAndPadConv2D op.""" + # The bilinear resize shape calculation. + input_shape = op.inputs[0].get_shape().with_rank(4) + unused_size_shape = op.inputs[1].get_shape().merge_with([2]) + size = tensor_util.constant_value(op.inputs[1]) + if size is not None: + height = size[0] + width = size[1] + else: + height = None + width = None + resized_shape = tensor_shape.TensorShape( + [input_shape[0], height, width, input_shape[3]]) + + # Calculates the effect of the padding. + paddings_shape = op.inputs[2].get_shape().with_rank(2) + resized_shape = resized_shape.with_rank(paddings_shape[0].value) + paddings_shape = paddings_shape.merge_with( + tensor_shape.matrix(resized_shape.ndims, 2)) + paddings = tensor_util.constant_value(op.inputs[2]) + if paddings is None: + padded_shape = tensor_shape.unknown_shape(ndims=resized_shape.ndims) + else: + output_dims = [] + for i, dim in enumerate(resized_shape.dims): + if paddings[i, 0] < 0 or paddings[i, 1] < 0: + raise ValueError("paddings must be non-negative") + output_dims.append(dim + paddings[i, 0] + paddings[i, 1]) + padded_shape = tensor_shape.TensorShape(output_dims) + + # Finally work out the convolution's effect. + filter_shape = op.inputs[3].get_shape().with_rank(4) + + batch_size = padded_shape[0] + in_rows = padded_shape[1] + in_cols = padded_shape[2] + + filter_rows = filter_shape[0] + filter_cols = filter_shape[1] + depth_out = filter_shape[3] + # Check that the input depths are compatible. + padded_shape[3].assert_is_compatible_with(filter_shape[2]) + + stride_b, stride_r, stride_c, stride_d = op.get_attr("strides") + + if stride_b != 1 or stride_d != 1: + raise ValueError("Current implementation does not yet support " + "strides in the batch and depth dimensions.") + # TODO(mrry,shlens): Raise an error if the stride would cause + # information in the input to be ignored. This will require a change + # in the kernel implementation. 
+ padding = op.get_attr("padding") + out_rows, out_cols = common_shapes.get2d_conv_output_size(in_rows, in_cols, + filter_rows, + filter_cols, + stride_r, + stride_c, + padding) + + output_shape = [batch_size, out_rows, out_cols, depth_out] + return [tensor_shape.TensorShape(output_shape)] + + @ops.RegisterShape("MaxPoolWithArgmax") def _MaxPoolWithArgMaxShape(op): """Shape function for MaxPoolWithArgmax op.""" diff --git a/tensorflow/python/tools/optimize_for_inference.py b/tensorflow/python/tools/optimize_for_inference.py index a330ff7c508..9c115f53bec 100644 --- a/tensorflow/python/tools/optimize_for_inference.py +++ b/tensorflow/python/tools/optimize_for_inference.py @@ -27,6 +27,8 @@ the network is used only for inference. These include: - Folding batch normalization ops into the pre-calculated weights. + - Fusing common operations into unified versions. + This script takes a frozen GraphDef file (where the weight variables have been converted into constants by the freeze_graph script) and outputs a new GraphDef with the optimizations applied. diff --git a/tensorflow/python/tools/optimize_for_inference_lib.py b/tensorflow/python/tools/optimize_for_inference_lib.py index 4eb138d97d9..1cb5ba16256 100644 --- a/tensorflow/python/tools/optimize_for_inference_lib.py +++ b/tensorflow/python/tools/optimize_for_inference_lib.py @@ -27,6 +27,8 @@ the network is used only for inference. These include: - Folding batch normalization ops into the pre-calculated weights. + - Fusing common operations into unified versions. + This script takes a frozen GraphDef file (where the weight variables have been converted into constants by the freeze_graph script) and outputs a new GraphDef with the optimizations applied. @@ -37,8 +39,8 @@ bazel build tensorflow/python/tools:optimize_for_inference && \ bazel-bin/tensorflow/python/tools/optimize_for_inference \ --input_graph=some_graph_def.pb \ --output_graph=/tmp/optimized_graph.pb \ ---input_node_names=Mul ---output_node_names=softmax +--input_names=Mul \ +--output_names=softmax """ @@ -74,13 +76,42 @@ def optimize_for_inference(input_graph_def, input_node_names, Returns: An optimized version of the input graph. """ - stripped_graph_def = strip_unused_lib.strip_unused(input_graph_def, - input_node_names, - output_node_names, - placeholder_type_enum) - detrained_graph_def = graph_util.remove_training_nodes(stripped_graph_def) - folded_graph_def = fold_batch_norms(detrained_graph_def) - return folded_graph_def + ensure_graph_is_valid(input_graph_def) + optimized_graph_def = input_graph_def + optimized_graph_def = strip_unused_lib.strip_unused(optimized_graph_def, + input_node_names, + output_node_names, + placeholder_type_enum) + optimized_graph_def = graph_util.remove_training_nodes(optimized_graph_def) + optimized_graph_def = fold_batch_norms(optimized_graph_def) + optimized_graph_def = fuse_resize_and_conv(optimized_graph_def) + ensure_graph_is_valid(optimized_graph_def) + return optimized_graph_def + + +def ensure_graph_is_valid(graph_def): + """Makes sure that the graph is internally consistent. + + Checks basic properties of the graph def and raises an exception if there are + input references to missing nodes, duplicated names, or other logic errors. + + Args: + graph_def: Definition of a graph to be checked. + + Raises: + ValueError: If the graph is incorrectly constructed. 
+ """ + node_map = {} + for node in graph_def.node: + if node.name not in node_map.keys(): + node_map[node.name] = node + else: + raise ValueError("Duplicate node names detected for ", node.name) + for node in graph_def.node: + for input_name in node.input: + input_node_name = node_name_from_input(input_name) + if input_node_name not in node_map.keys(): + raise ValueError("Input for ", node.name, " not found: ", input_name) def node_name_from_input(node_name): @@ -161,7 +192,7 @@ def fold_batch_norms(input_graph_def): if node.name not in input_node_map.keys(): input_node_map[node.name] = node else: - raise ValueError("Duplicate node names detected.") + raise ValueError("Duplicate node names detected for ", node.name) nodes_to_skip = {} new_ops = [] @@ -303,3 +334,94 @@ def fold_batch_norms(input_graph_def): result_graph_def.node.extend(new_ops) return result_graph_def + + +def fuse_resize_and_conv(input_graph_def): + """Merges preceding resize and mirror pad ops into a specialized convolution. + + There's a common pattern of enlarging the input to a convolution using a + resize operation, and also using MirrorPad to extend the boundaries to that + zero edge pixels don't bleed inwards when convolving. This routine looks for + that pattern of operations, and fuses them together into a Conv2DWithResizeOp. + + Args: + input_graph_def: A GraphDef containing a model. + + Returns: + Modified graph with resize and pad ops merged. + + Raises: + ValueError: If the graph is badly formed with duplicate node names. + """ + + input_node_map = {} + for node in input_graph_def.node: + if node.name not in input_node_map.keys(): + input_node_map[node.name] = node + else: + raise ValueError("Duplicate node names detected for ", node.name) + + nodes_to_skip = {} + new_ops = [] + for node in input_graph_def.node: + + if node.op != "Conv2D": + continue + conv_op = node + + input_op = node_from_map(input_node_map, conv_op.input[0]) + if input_op.op == "MirrorPad": + mirror_pad_op = input_op + resize_op = node_from_map(input_node_map, mirror_pad_op.input[0]) + else: + mirror_pad_op = None + resize_op = input_op + + if resize_op.op != "ResizeBilinear": + continue + + nodes_to_skip[conv_op.name] = True + if mirror_pad_op: + nodes_to_skip[mirror_pad_op.name] = True + nodes_to_skip[resize_op.name] = True + + fused_conv_op = tf.NodeDef() + fused_conv_op.op = "FusedResizeAndPadConv2D" + fused_conv_op.name = conv_op.name + if mirror_pad_op: + mirror_paddings_name = mirror_pad_op.input[1] + mirror_paddings_mode = mirror_pad_op.attr["mode"] + else: + # If there was no MirrorPad op, then create settings that make the padding + # stage of the fused operation a no-op. 
+ paddings_op = tf.NodeDef() + paddings_op.op = "Const" + paddings_op.name = conv_op.name + "_dummy_paddings" + paddings_op.attr["dtype"].CopyFrom(tf.AttrValue( + type=tf.int32.as_datatype_enum)) + paddings_op.attr["value"].CopyFrom(tf.AttrValue( + tensor=tensor_util.make_tensor_proto( + [0, 0, 0, 0, 0, 0, 0, 0], tf.int32, [4, 2]))) + new_ops.extend([paddings_op]) + mirror_paddings_name = paddings_op.name + mirror_paddings_mode = tf.AttrValue(s=b"REFLECT") + fused_conv_op.input.extend([resize_op.input[0], resize_op.input[1], + mirror_paddings_name, conv_op.input[1]]) + fused_conv_op.attr["T"].CopyFrom(conv_op.attr["T"]) + fused_conv_op.attr["resize_align_corners"].CopyFrom( + resize_op.attr["align_corners"]) + fused_conv_op.attr["mode"].CopyFrom(mirror_paddings_mode) + fused_conv_op.attr["strides"].CopyFrom(conv_op.attr["strides"]) + fused_conv_op.attr["padding"].CopyFrom(conv_op.attr["padding"]) + new_ops.extend([fused_conv_op]) + + result_graph_def = tf.GraphDef() + for node in input_graph_def.node: + if node.name in nodes_to_skip: + continue + new_node = tf.NodeDef() + new_node.CopyFrom(node) + result_graph_def.node.extend([new_node]) + + result_graph_def.node.extend(new_ops) + return result_graph_def diff --git a/tensorflow/python/tools/optimize_for_inference_test.py b/tensorflow/python/tools/optimize_for_inference_test.py index 61644fe9c91..d92d7ab8c7d 100644 --- a/tensorflow/python/tools/optimize_for_inference_test.py +++ b/tensorflow/python/tools/optimize_for_inference_test.py @@ -54,6 +54,7 @@ class OptimizeForInferenceTest(tf.test.TestCase): shape=shape))) def testOptimizeForInference(self): + unused_constant_name = "unused_constant" unconnected_add_name = "unconnected_add" a_constant_name = "a_constant" b_constant_name = "b_constant" @@ -64,9 +65,14 @@ class OptimizeForInferenceTest(tf.test.TestCase): add_name = "add" unused_output_add_name = "unused_output_add" graph_def = tf.GraphDef() + unused_constant = self.create_constant_node_def(unused_constant_name, + value=0, + dtype=tf.float32, + shape=[]) + graph_def.node.extend([unused_constant]) unconnected_add_node = self.create_node_def("Add", unconnected_add_name, - ["no_such_node", - "no_such_node"]) + [unused_constant_name, + unused_constant_name]) self.set_attr_dtype(unconnected_add_node, "T", tf.float32) graph_def.node.extend([unconnected_add_node]) a_constant = self.create_constant_node_def(a_constant_name, @@ -160,6 +166,65 @@ class OptimizeForInferenceTest(tf.test.TestCase): for node in optimized_graph_def.node: self.assertNotEqual("BatchNormWithGlobalNormalization", node.op) + def testFuseResizePadAndConv(self): + with self.test_session() as sess: + inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6] + input_op = tf.constant(np.array(inputs), shape=[1, 2, 3, 2], + dtype=tf.float32) + resize_op = tf.image.resize_bilinear(input_op, [12, 4], + align_corners=False) + pad_op = tf.pad(resize_op, [[0, 0], [1, 1], [2, 2], [0, 0]], + mode="REFLECT") + weights = [1, 2, 3, 4, 0.1, 0.2, 0.3, 0.4] + weights_op = tf.constant(np.array(weights), shape=[1, 2, 2, 2], + dtype=tf.float32) + tf.nn.conv2d(pad_op, weights_op, [1, 1, 1, 1], + padding="VALID", name="output") + original_graph_def = sess.graph_def + original_result = sess.run(["output:0"]) + optimized_graph_def = optimize_for_inference_lib.fuse_resize_and_conv( + original_graph_def) + + with self.test_session() as sess: + _ = tf.import_graph_def(optimized_graph_def, input_map={}, + name="optimized") + optimized_result = sess.run(["optimized/output:0"]) + + 
self.assertAllClose(original_result, optimized_result) + + for node in optimized_graph_def.node: + self.assertNotEqual("Conv2D", node.op) + self.assertNotEqual("MirrorPad", node.op) + self.assertNotEqual("ResizeBilinear", node.op) + + def testFuseResizeAndConv(self): + with self.test_session() as sess: + inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6] + input_op = tf.constant(np.array(inputs), shape=[1, 2, 3, 2], + dtype=tf.float32) + resize_op = tf.image.resize_bilinear(input_op, [12, 4], + align_corners=False) + weights = [1, 2, 3, 4, 0.1, 0.2, 0.3, 0.4] + weights_op = tf.constant(np.array(weights), shape=[1, 2, 2, 2], + dtype=tf.float32) + tf.nn.conv2d(resize_op, weights_op, [1, 1, 1, 1], + padding="VALID", name="output") + original_graph_def = sess.graph_def + original_result = sess.run(["output:0"]) + optimized_graph_def = optimize_for_inference_lib.fuse_resize_and_conv( + original_graph_def) + + with self.test_session() as sess: + _ = tf.import_graph_def(optimized_graph_def, input_map={}, + name="optimized") + optimized_result = sess.run(["optimized/output:0"]) + + self.assertAllClose(original_result, optimized_result) + + for node in optimized_graph_def.node: + self.assertNotEqual("Conv2D", node.op) + self.assertNotEqual("ResizeBilinear", node.op) + if __name__ == "__main__": tf.test.main()
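Usage sketch (not part of the patch): the new fuse_resize_and_conv pass runs automatically inside the optimize_for_inference entry point added above. The graph path and the "input"/"softmax" node names below are placeholders.

import tensorflow as tf
from tensorflow.python.tools import optimize_for_inference_lib

# Load a frozen GraphDef (path and node names are hypothetical).
graph_def = tf.GraphDef()
with open("/tmp/frozen_graph.pb", "rb") as f:
  graph_def.ParseFromString(f.read())

# strip_unused, remove_training_nodes, fold_batch_norms and the new
# fuse_resize_and_conv pass all run inside this one call.
optimized = optimize_for_inference_lib.optimize_for_inference(
    graph_def, ["input"], ["softmax"], tf.float32.as_datatype_enum)

# Any ResizeBilinear -> MirrorPad -> Conv2D chains are now single
# FusedResizeAndPadConv2D nodes.
with open("/tmp/optimized_graph.pb", "wb") as f:
  f.write(optimized.SerializeToString())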