diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt
index bc9838ec743..0478adab2a9 100644
--- a/tensorflow/contrib/makefile/tf_op_files.txt
+++ b/tensorflow/contrib/makefile/tf_op_files.txt
@@ -101,6 +101,7 @@ tensorflow/core/kernels/cwise_op_div.cc
 tensorflow/core/kernels/cwise_op_add.cc
 tensorflow/core/kernels/ctc_decoder_ops.cc
 tensorflow/core/kernels/conv_ops_using_gemm.cc
+tensorflow/core/kernels/conv_ops_fused.cc
 tensorflow/core/kernels/conv_ops.cc
 tensorflow/core/kernels/conv_grad_ops.cc
 tensorflow/core/kernels/control_flow_ops.cc
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index adbe5545074..6a1967eaf57 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -491,6 +491,27 @@ tf_cc_test(
     ],
 )
 
+tf_cc_test(
+    name = "conv_ops_test",
+    size = "small",
+    deps = [
+        ":conv_ops",
+        ":image",
+        ":ops_testutil",
+        ":ops_util",
+        "//tensorflow/cc:cc_ops",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:framework_internal",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:tensorflow",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+    ],
+)
+
 tf_cc_test(
     name = "example_parsing_ops_test",
     size = "large",
@@ -1325,6 +1346,7 @@ tf_kernel_library(
     hdrs = [
         "conv_grad_ops.h",
         "deep_conv2d.h",
+        "gemm_functors.h",
         "winograd_transform.h",
     ],
     prefix = "conv_ops",
@@ -1332,6 +1354,7 @@ tf_kernel_library(
         ":bounds_check",
         ":conv_2d",
         ":conv_3d",
+        ":image_resizer_state",
         ":ops_util",
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:framework",
diff --git a/tensorflow/core/kernels/conv_ops.h b/tensorflow/core/kernels/conv_ops.h
index d09db3dc15f..858be520b07 100644
--- a/tensorflow/core/kernels/conv_ops.h
+++ b/tensorflow/core/kernels/conv_ops.h
@@ -17,6 +17,7 @@ limitations under the License.
 #define TENSORFLOW_KERNELS_CONV_OPS_H_
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/resource_mgr.h"
 #include "tensorflow/core/util/tensor_format.h"
 
 #if GOOGLE_CUDA
@@ -38,6 +39,16 @@ class LaunchConv2DOp {
               TensorFormat data_format);
 };
 
+// Used to keep track of persistent memory buffers used within the op.
+template <class T, size_t size>
+struct Im2ColBufferResource : public ResourceBase {
+  // This mutex ensures that only a single operation at a time is able to use
+  // the buffer memory held by this resource.
+  mutex mu;
+  T data[size];
+  string DebugString() { return "Im2ColBufferResource"; }
+};
+
 #ifdef GOOGLE_CUDA
 template <typename T>
 class LaunchConv2DOp<Eigen::GpuDevice, T> {
diff --git a/tensorflow/core/kernels/conv_ops_fused.cc b/tensorflow/core/kernels/conv_ops_fused.cc
new file mode 100644
index 00000000000..865021405ac
--- /dev/null
+++ b/tensorflow/core/kernels/conv_ops_fused.cc
@@ -0,0 +1,486 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Implements convolution operations with other kernels baked into the
+// processing, to optimize latency and memory usage.
+
+#include <string.h>
+#include <map>
+#include <vector>
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_slice.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/conv_ops.h"
+#include "tensorflow/core/kernels/gemm_functors.h"
+#include "tensorflow/core/kernels/image_resizer_state.h"
+#include "tensorflow/core/util/mirror_pad_mode.h"
+#include "tensorflow/core/util/padding.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+namespace tensorflow {
+
+namespace {
+
+// Combines bilinear resizing and mirror padding into the im2col transformation
+// stage of convolution,
+template <class T1, class T2, class T3, class TGemmFunctor>
+class FusedResizeAndPadConvFunctor {
+ public:
+  void operator()(OpKernelContext* context, const Tensor& input,
+                  int input_batches, int resized_height, int resized_width,
+                  int padded_height, int padded_width, int input_depth,
+                  const T2* filter_data, int filter_height, int filter_width,
+                  int filter_count, int stride_rows, int stride_cols,
+                  Padding padding, T3* output_data, int output_height,
+                  int output_width, const ImageResizerState& st,
+                  int top_padding, int bottom_padding, int left_padding,
+                  int right_padding, int pad_offset) {
+    if ((input_batches <= 0) || (padded_width <= 0) || (padded_height <= 0) ||
+        (input_depth <= 0)) {
+      LOG(WARNING) << "Conv2D was called with bad input dimensions: "
+                   << input_batches << ", " << padded_height << ", "
+                   << padded_width << ", " << input_depth;
+      return;
+    }
+    if ((filter_width <= 0) || (filter_height <= 0) || (filter_count <= 0)) {
+      LOG(WARNING) << "Conv2D was called with bad filter dimensions: "
+                   << filter_width << ", " << filter_height << ", "
+                   << filter_count;
+      return;
+    }
+    if ((output_width <= 0) || (output_height <= 0)) {
+      LOG(WARNING) << "Conv2D was called with bad output width or height: "
+                   << output_width << ", " << output_height;
+      return;
+    }
+
+    // These calculations define how the patches will be positioned within the
+    // input image. The actual definitions are quite complex, and rely on the
+    // previously-calculated output size.
+    int filter_left_offset;
+    int filter_top_offset;
+    if (padding == VALID) {
+      filter_left_offset =
+          ((output_width - 1) * stride_cols + filter_width - padded_width + 1) /
+          2;
+      filter_top_offset = ((output_height - 1) * stride_rows + filter_height -
+                           padded_height + 1) /
+                          2;
+    } else {
+      filter_left_offset =
+          ((output_width - 1) * stride_cols + filter_width - padded_width) / 2;
+      filter_top_offset =
+          ((output_height - 1) * stride_rows + filter_height - padded_height) /
+          2;
+    }
+
+    // The im2col buffer has # of patches rows, and # of filters cols.
+    // It's laid out like this, in row major order in memory:
+    //        < filter value count >
+    //   ^   +---------------------+
+    // patch |                     |
+    // count |                     |
+    //   v   +---------------------+
+    // Each patch row contains a filter_width x filter_height patch of the
+    // input, with the depth channel as the most contiguous in memory, followed
+    // by the width, then the height. This is the standard memory order in the
+    // image world if it helps to visualize it.
+    const int filter_value_count = filter_width * filter_height * input_depth;
+
+    // We don't want to allocate a buffer to hold all the patches if the size is
+    // going to be extremely large, so break it into chunks if it's bigger than
+    // a limit. Each chunk will be processed serially, so we can refill the
+    // buffer for the next chunk and reuse it, keeping maximum memory size down.
+    // In this case, we've picked 16 megabytes as a reasonable limit.
+    const size_t max_chunk_size = (16 * 1024 * 1024);
+    OP_REQUIRES(context, (filter_value_count * sizeof(T1)) <= max_chunk_size,
+                errors::InvalidArgument("Im2Col patch too large for buffer"));
+    const size_t patches_per_chunk =
+        max_chunk_size / (filter_value_count * sizeof(T1));
+    // Because memory allocation is very expensive on mobile platforms, try to
+    // allocate a persistent buffer that will be kept around between calls. We
+    // use TensorFlow's resource management to ensure that the memory will be
+    // released when the session is over.
+    Im2ColBufferResource<T1, max_chunk_size>* im2col_buffer_resource;
+    std::function<Status(Im2ColBufferResource<T1, max_chunk_size>**)> creator =
+        [](Im2ColBufferResource<T1, max_chunk_size>** resource) {
+          *resource = new Im2ColBufferResource<T1, max_chunk_size>();
+          return Status::OK();
+        };
+    OP_REQUIRES_OK(context, context->resource_manager()->LookupOrCreate(
+                                "Conv2d", "im2col_buffer",
+                                &im2col_buffer_resource, creator));
+    // This means that multiple ops can't be run simultaneously on different
+    // threads, because we have a single shared resource. The platforms this is
+    // aimed at have intra-op parallelism as their focus though, so it shouldn't
+    // be an issue.
+    mutex_lock lock_buffer(im2col_buffer_resource->mu);
+    core::ScopedUnref unref_buffer(im2col_buffer_resource);
+    T1* im2col_buffer = im2col_buffer_resource->data;
+
+    typename TTypes<T1, 4>::ConstTensor input_data = input.tensor<T1, 4>();
+
+    for (int batch = 0; batch < input_batches; ++batch) {
+      for (int out_y = 0; out_y < output_height; ++out_y) {
+        const int in_y_origin = (out_y * stride_rows) - filter_top_offset;
+        for (int out_x = 0; out_x < output_width; ++out_x) {
+          const int in_x_origin = (out_x * stride_cols) - filter_left_offset;
+          const int patch_index = (batch * output_width * output_height) +
+                                  (out_y * output_width) + out_x;
+          const int patch_index_within_chunk = patch_index % patches_per_chunk;
+          T1* im2col_patch_start =
+              im2col_buffer + (patch_index_within_chunk * filter_value_count);
+          for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
+            const int conv_in_y = in_y_origin + filter_y;
+            float in_y = (conv_in_y - top_padding);
+            if (in_y < 0) {
+              in_y = -(in_y + 1.0f - pad_offset);
+            } else if (in_y >= resized_height) {
+              in_y = (resized_height * 2.0f) - (in_y + 1.0f + pad_offset);
+            }
+            in_y *= st.height_scale;
+            const int64 top_y_index = static_cast<int64>(std::floor(in_y));
+            const int64 bottom_y_index = std::min(
+                static_cast<int64>(std::ceil(in_y)), (st.in_height - 1));
+            const T1 y_lerp = in_y - top_y_index;
+            T1* im2col_row_start =
+                im2col_patch_start + (filter_y * filter_width * input_depth);
+            for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
+              const int conv_in_x = in_x_origin + filter_x;
+              float in_x = (conv_in_x - left_padding);
+              if (in_x < 0) {
+                in_x = -(in_x + 1.0f - pad_offset);
+              } else if (in_x >= resized_width) {
+                in_x = (resized_width * 2.0f) - (in_x + 1.0f + pad_offset);
+              }
+              in_x *= st.width_scale;
+              const int64 left_x_index = static_cast<int64>(std::floor(in_x));
+              const int64 right_x_index = std::min(
+                  static_cast<int64>(std::ceil(in_x)), (st.in_width - 1));
+              const T1 x_lerp = in_x - left_x_index;
+              T1* im2col_row_pixel =
+                  im2col_row_start + (filter_x * input_depth);
+              for (int in_channel = 0; in_channel < input_depth; ++in_channel) {
+                T1 in_value;
+                if ((conv_in_x >= 0) && (conv_in_x < padded_width) &&
+                    (conv_in_y >= 0) && (conv_in_y < padded_height)) {
+                  const T1 top_left(
+                      input_data(batch, top_y_index, left_x_index, in_channel));
+                  const T1 top_right(input_data(batch, top_y_index,
+                                                right_x_index, in_channel));
+                  const T1 bottom_left(input_data(batch, bottom_y_index,
+                                                  left_x_index, in_channel));
+                  const T1 bottom_right(input_data(batch, bottom_y_index,
+                                                   right_x_index, in_channel));
+                  const T1 top = top_left + (top_right - top_left) * x_lerp;
+                  const T1 bottom =
+                      bottom_left + (bottom_right - bottom_left) * x_lerp;
+                  in_value = top + (bottom - top) * y_lerp;
+                } else {
+                  in_value = T1(0);
+                }
+                im2col_row_pixel[in_channel] = in_value;
+              }
+            }
+          }
+          const bool is_last_in_chunk =
+              (patch_index_within_chunk == (patches_per_chunk - 1));
+          const bool is_last_overall =
+              ((batch == (input_batches - 1)) &&
+               (out_y == (output_height - 1)) && (out_x == (output_width - 1)));
+          if (is_last_in_chunk || is_last_overall) {
+            // Now we've assembled a set of image patches into a matrix, apply a
+            // GEMM matrix multiply of the patches as rows, times the filter
+            // weights in columns, to get partial results in the output matrix.
+            const int how_many_patches = patch_index_within_chunk + 1;
+            const int m = how_many_patches;
+            const int n = filter_count;
+            const int k = filter_value_count;
+            const int lda = filter_value_count;
+            const int ldb = filter_count;
+            const int ldc = filter_count;
+            const size_t start_patch_index =
+                patch_index - (how_many_patches - 1);
+            T3* chunk_output_data =
+                output_data + (start_patch_index * filter_count);
+            TGemmFunctor gemm_functor;
+            gemm_functor(m, n, k, im2col_buffer, lda, filter_data, ldb,
+                         chunk_output_data, ldc);
+          }
+        }
+      }
+    }
+  }
+};
+
+}  // namespace
+
+// Implements a version of convolution with bilinear resizing and mirror padding
+// included.
+template <class T, class TConvFunctor>
+class FusedResizeConv2DUsingGemmOp : public OpKernel {
+ public:
+  explicit FusedResizeConv2DUsingGemmOp(OpKernelConstruction* context)
+      : OpKernel(context) {
+    OP_REQUIRES_OK(context,
+                   context->GetAttr("resize_align_corners", &align_corners_));
+    MirrorPadMode mode;
+    OP_REQUIRES_OK(context, context->GetAttr("mode", &mode));
+
+    switch (mode) {
+      case MirrorPadMode::SYMMETRIC: {
+        offset_ = 0;
+        break;
+      }
+      case MirrorPadMode::REFLECT: {
+        offset_ = 1;
+        break;
+      }
+      default:
+        OP_REQUIRES(context, false,
+                    errors::InvalidArgument(
+                        "mode must be either REFLECT or SYMMETRIC."));
+    }
+    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
+    OP_REQUIRES(context, strides_.size() == 4,
+                errors::InvalidArgument("Sliding window strides field must "
+                                        "specify 4 dimensions"));
+    const int64 stride_n = GetTensorDim(strides_, FORMAT_NHWC, 'N');
+    const int64 stride_c = GetTensorDim(strides_, FORMAT_NHWC, 'C');
+    OP_REQUIRES(
+        context, stride_n == 1 && stride_c == 1,
+        errors::InvalidArgument("Current implementation does not yet support "
+                                "strides in the batch and depth dimensions."));
+    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+  }
+
+  void Compute(OpKernelContext* context) override {
+    // Input tensor is of the following dimensions:
+    // [ batch, in_rows, in_cols, in_depth ]
+    const Tensor& input = context->input(0);
+    OP_REQUIRES(context, (input.shape().num_elements() > 0),
+                errors::InvalidArgument("Input tensor can't be empty"));
+
+    ImageResizerState st(align_corners_);
+    st.ValidateAndCalculateOutputSize(context, input);
+    if (!context->status().ok()) return;
+    const TensorShape resized_shape(
+        {input.dim_size(0), st.out_height, st.out_width, input.dim_size(3)});
+
+    const Tensor& paddings = context->input(2);
+
+    const int dims = resized_shape.dims();
+    OP_REQUIRES(
+        context, TensorShapeUtils::IsMatrix(paddings.shape()) &&
+                     paddings.dim_size(1) == 2,
+        errors::InvalidArgument("paddings must be a matrix with 2 columns: ",
+                                paddings.shape().DebugString()));
+    const int fixed_dims =
+        (allow_legacy_scalars() && dims == 0 && paddings.dim_size(0) == 1)
+            ? 1
+            : dims;
+    OP_REQUIRES(
+        context, fixed_dims == paddings.dim_size(0),
+        errors::InvalidArgument(
+            "The first dimension of paddings must be the rank of inputs: ",
+            fixed_dims, " ", paddings.shape().DebugString(), " ",
+            resized_shape.DebugString()));
+    OP_REQUIRES(
+        context, dims == paddings.dim_size(0),
+        errors::InvalidArgument(
+            "The first dimension of paddings must be the rank of inputs: ",
+            dims, " ", paddings.shape().DebugString(), " ",
+            resized_shape.DebugString()));
+
+    OP_REQUIRES(
+        context, dims == 4,
+        errors::InvalidArgument(
+            "Fused mirror padding only supports four-dimensional inputs, but ",
+            dims, " requested"));
+
+    // Compute the shape of the output tensor, and allocate it.
+    TensorShape padded_shape;
+    TTypes<int32>::ConstMatrix paddings_matrix = paddings.matrix<int32>();
+    for (int d = 0; d < dims; ++d) {
+      const int32 before =
+          paddings_matrix(d, 0);  // Pad before existing elements.
+      const int32 after =
+          paddings_matrix(d, 1);  // Pad after exisitng elements.
+      OP_REQUIRES(context, before >= 0 && after >= 0,
+                  errors::InvalidArgument("paddings must be non-negative: ",
+                                          before, " ", after));
+      if (offset_ == 0) {  // SYMMETRIC mode.
+        OP_REQUIRES(
+            context, before <= resized_shape.dim_size(d) &&
+                         after <= resized_shape.dim_size(d),
+            errors::InvalidArgument("paddings must be no greater "
+                                    "than the dimension size: ",
+                                    before, ", ", after, " greater than ",
+                                    resized_shape.dim_size(d)));
+      } else if (offset_ == 1) {  // REFLECT mode.
+        OP_REQUIRES(
+            context, before < resized_shape.dim_size(d) &&
+                         after < resized_shape.dim_size(d),
+            errors::InvalidArgument("paddings must be less than"
+                                    " the dimension size: ",
+                                    before, ", ", after, " not less than ",
+                                    resized_shape.dim_size(d)));
+      }
+      padded_shape.AddDim(before + resized_shape.dim_size(d) + after);
+    }
+
+    OP_REQUIRES(
+        context, ((paddings_matrix(0, 0) == 0) && (paddings_matrix(0, 1) == 0)),
+        errors::InvalidArgument(
+            "Fused mirror padding only support spatial padding, not batches: ",
+            paddings.DebugString()));
+    OP_REQUIRES(
+        context, ((paddings_matrix(3, 0) == 0) && (paddings_matrix(3, 1) == 0)),
+        errors::InvalidArgument(
+            "Fused mirror padding only support spatial padding, not channels: ",
+            paddings.DebugString()));
+    const int32 top_padding = paddings_matrix(1, 0);
+    const int32 bottom_padding = paddings_matrix(1, 1);
+    const int32 left_padding = paddings_matrix(2, 0);
+    const int32 right_padding = paddings_matrix(2, 1);
+
+    // Input filter is of the following dimensions:
+    // [ filter_rows, filter_cols, in_depth, out_depth]
+    const Tensor& filter = context->input(3);
+
+    // For 2D convolution, there should be 4 dimensions.
+    OP_REQUIRES(context, padded_shape.dims() == 4,
+                errors::InvalidArgument("input must be 4-dimensional",
+                                        padded_shape.DebugString()));
+    OP_REQUIRES(context, filter.dims() == 4,
+                errors::InvalidArgument("filter must be 4-dimensional: ",
+                                        filter.shape().DebugString()));
+
+    // We only check the first three dims, since the depth is accessed as an
+    // int64 below.
+    for (int i = 0; i < 3; i++) {
+      OP_REQUIRES(context, FastBoundsCheck(filter.dim_size(i),
+                                           std::numeric_limits<int>::max()),
+                  errors::InvalidArgument("filter too large"));
+    }
+
+    // The last dimension for input is in_depth. It must be the same as the
+    // filter's in_depth.
+    const int64 in_depth = padded_shape.dim_size(3);
+    OP_REQUIRES(
+        context, in_depth == filter.dim_size(2),
+        errors::InvalidArgument("input and filter must have the same depth: ",
+                                in_depth, " vs ", filter.dim_size(2)));
+
+    // The last dimension for filter is out_depth.
+    const int out_depth = static_cast<int>(filter.dim_size(3));
+
+    // The second dimension for input is rows/height.
+    // The first dimension for filter is rows/height.
+    const int64 padded_rows_raw = padded_shape.dim_size(1);
+    OP_REQUIRES(context, FastBoundsCheck(padded_rows_raw,
+                                         std::numeric_limits<int>::max()),
+                errors::InvalidArgument("Input rows too large"));
+    const int padded_rows = static_cast<int>(padded_rows_raw);
+    const int filter_rows = static_cast<int>(filter.dim_size(0));
+    const int resized_rows = static_cast<int>(resized_shape.dim_size(1));
+
+    // The third dimension for input is columns/width.
+    // The second dimension for filter is columns/width.
+    const int64 padded_cols_raw = padded_shape.dim_size(2);
+    OP_REQUIRES(context, FastBoundsCheck(padded_cols_raw,
+                                         std::numeric_limits<int>::max()),
+                errors::InvalidArgument("Input cols too large"));
+    const int padded_cols = static_cast<int>(padded_cols_raw);
+    const int filter_cols = static_cast<int>(filter.dim_size(1));
+    const int resized_cols = static_cast<int>(resized_shape.dim_size(2));
+
+    // The first dimension for input is batch.
+    const int64 batch_raw = padded_shape.dim_size(0);
+    OP_REQUIRES(context,
+                FastBoundsCheck(batch_raw, std::numeric_limits<int>::max()),
+                errors::InvalidArgument("batch is too large"));
+    const int batch = static_cast<int>(batch_raw);
+
+    // For now we take the stride from the second and third dimensions only (we
+    // do not support striding on the batch or depth dimension).
+    const int stride_rows = GetTensorDim(strides_, FORMAT_NHWC, 'H');
+    const int stride_cols = GetTensorDim(strides_, FORMAT_NHWC, 'W');
+
+    int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
+    OP_REQUIRES_OK(context,
+                   GetWindowedOutputSize(padded_rows, filter_rows, stride_rows,
+                                         padding_, &out_rows, &pad_rows));
+    OP_REQUIRES_OK(context,
+                   GetWindowedOutputSize(padded_cols, filter_cols, stride_cols,
+                                         padding_, &out_cols, &pad_cols));
+    TensorShape out_shape =
+        ShapeFromFormat(FORMAT_NHWC, batch, out_rows, out_cols, out_depth);
+    OP_REQUIRES(context, (out_shape.num_elements() > 0),
+                errors::InvalidArgument("Output tensor can't be empty"));
+
+    // Output tensor is of the following dimensions:
+    // [ in_batch, out_rows, out_cols, out_depth ]
+    Tensor* output = nullptr;
+    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
+
+    VLOG(2) << "Conv2D: in_depth = " << in_depth
+            << ", padded_cols = " << padded_cols
+            << ", filter_cols = " << filter_cols
+            << ", padded_rows = " << padded_rows
+            << ", filter_rows = " << filter_rows
+            << ", stride_rows = " << stride_rows
+            << ", stride_cols = " << stride_cols
+            << ", out_depth = " << out_depth;
+
+    // If there is nothing to compute, return.
+    if (out_shape.num_elements() == 0) {
+      return;
+    }
+    TConvFunctor conv_functor;
+    conv_functor(context, input, batch, resized_rows, resized_cols, padded_rows,
+                 padded_cols, in_depth, filter.flat<T>().data(), filter_rows,
+                 filter_cols, out_depth, stride_rows, stride_cols, padding_,
+                 output->flat<T>().data(), out_rows, out_cols, st, top_padding,
+                 bottom_padding, left_padding, right_padding, offset_);
+  }
+
+ private:
+  std::vector<int32> strides_;
+  Padding padding_;
+  bool align_corners_;
+  int offset_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(FusedResizeConv2DUsingGemmOp);
+};
+
+#define REGISTER_FUSED(T)             \
+  REGISTER_KERNEL_BUILDER(            \
+      Name("FusedResizeAndPadConv2D") \
+          .Device(DEVICE_CPU)         \
+          .TypeConstraint<T>("T"),    \
+      FusedResizeConv2DUsingGemmOp<   \
+          T,                          \
+          FusedResizeAndPadConvFunctor<T, T, T, FastGemmFunctor<T, T, T>>>);
+
+TF_CALL_float(REGISTER_FUSED);
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/conv_ops_test.cc b/tensorflow/core/kernels/conv_ops_test.cc
new file mode 100644
index 00000000000..228f2d5defa
--- /dev/null
+++ b/tensorflow/core/kernels/conv_ops_test.cc
@@ -0,0 +1,240 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/ops/const_op.h"
+#include "tensorflow/cc/ops/image_ops.h"
+#include "tensorflow/cc/ops/nn_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/public/session.h"
+
+namespace tensorflow {
+
+class FusedResizePadConvOpTest : public OpsTestBase {
+ protected:
+  void HandwrittenConv() {
+    const int stride = 1;
+    TF_EXPECT_OK(NodeDefBuilder("fused_resize_op", "FusedResizeAndPadConv2D")
+                     .Input(FakeInput(DT_FLOAT))
+                     .Input(FakeInput(DT_INT32))
+                     .Input(FakeInput(DT_INT32))
+                     .Input(FakeInput(DT_FLOAT))
+                     .Attr("T", DT_FLOAT)
+                     .Attr("resize_align_corners", false)
+                     .Attr("mode", "REFLECT")
+                     .Attr("strides", {1, stride, stride, 1})
+                     .Attr("padding", "SAME")
+                     .Finalize(node_def()));
+    TF_EXPECT_OK(InitOp());
+    const int depth = 1;
+    const int image_width = 4;
+    const int image_height = 3;
+    const int image_batch_count = 1;
+    // The image matrix is:
+    // |  1 |  2 |  3 |  4 |
+    // |  5 |  6 |  7 |  8 |
+    // |  9 | 10 | 11 | 12 |
+    Tensor image(DT_FLOAT,
+                 {image_batch_count, image_height, image_width, depth});
+    test::FillValues<float>(&image, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+
+    // The filter matrix is:
+    // | 1 | 4 | 7 |
+    // | 2 | 5 | 8 |
+    // | 3 | 6 | 9 |
+    const int filter_size = 3;
+    const int filter_count = 1;
+    Tensor filter(DT_FLOAT, {filter_size, filter_size, depth, filter_count});
+    test::FillValues<float>(&filter, {1, 4, 7, 2, 5, 8, 3, 6, 9});
+
+    const int resized_width = image_width;
+    const int resized_height = image_height;
+
+    const int top_padding = 0;
+    const int bottom_padding = 0;
+    const int left_padding = 0;
+    const int right_padding = 0;
+
+    AddInputFromArray<float>(image.shape(), image.flat<float>());
+    AddInputFromArray<int32>(TensorShape({2}), {resized_height, resized_width});
+    AddInputFromArray<int32>(
+        TensorShape({4, 2}),
+        {0, 0, top_padding, bottom_padding, left_padding, right_padding, 0, 0});
+    AddInputFromArray<float>(filter.shape(), filter.flat<float>());
+    TF_ASSERT_OK(RunOpKernel());
+
+    // We're sliding the 3x3 filter across the 3x4 image, with accesses outside
+    // the input set to zero because we're using the 'SAME' padding mode.
+    // The calculations behind the expected output are:
+    // (1*0)+(4*0)+(7*0)+(2*0)+(5*1)+(8*2)+(3*0)+(6*5)+(9*6)=105
+    // (1*0)+(4*0)+(7*0)+(2*1)+(5*2)+(8*3)+(3*5)+(6*6)+(9*7)=150
+    // (1*0)+(4*0)+(7*0)+(2*2)+(5*3)+(8*4)+(3*6)+(6*7)+(9*8)=183
+    // (1*0)+(4*0)+(7*0)+(2*3)+(5*4)+(8*0)+(3*7)+(6*8)+(9*0)=95
+    // (1*0)+(4*1)+(7*2)+(2*0)+(5*5)+(8*6)+(3*0)+(6*9)+(9*10)=235
+    // (1*1)+(4*2)+(7*3)+(2*5)+(5*6)+(8*7)+(3*9)+(6*10)+(9*11)=312
+    // (1*2)+(4*3)+(7*4)+(2*6)+(5*7)+(8*8)+(3*10)+(6*11)+(9*12)=357
+    // (1*3)+(4*4)+(7*0)+(2*7)+(5*8)+(8*0)+(3*11)+(6*12)+(9*0)=178
+    // (1*0)+(4*5)+(7*6)+(2*0)+(5*9)+(8*10)+(3*0)+(6*0)+(9*0)=187
+    // (1*5)+(4*6)+(7*7)+(2*9)+(5*10)+(8*11)+(3*0)+(6*0)+(9*0)=234
+    // (1*6)+(4*7)+(7*8)+(2*10)+(5*11)+(8*12)+(3*0)+(6*0)+(9*0)=261
+    // (1*7)+(4*11)+(7*0)+(2*8)+(5*12)+(8*0)+(3*0)+(6*0)+(9*0)=121
+    // This means we should end up with this matrix:
+    // |  105  |  150  |  183  |   95  |
+    // |  235  |  312  |  357  |  178  |
+    // |  187  |  234  |  261  |  121  |
+    const int expected_width = image_width;
+    const int expected_height = image_height * filter_count;
+    Tensor expected(DT_FLOAT, TensorShape({image_batch_count, expected_height,
+                                           expected_width, filter_count}));
+    test::FillValues<float>(
+        &expected, {105, 150, 183, 95, 235, 312, 357, 178, 187, 234, 261, 121});
+    const Tensor& output = *GetOutput(0);
+    test::ExpectTensorNear<float>(expected, output, 1e-5);
+  }
+
+  void CompareFusedAndSeparate(int input_width, int input_height,
+                               int input_depth, int resize_width,
+                               int resize_height, int y_padding, int x_padding,
+                               int filter_size, int filter_count,
+                               bool resize_align_corners, string pad_mode,
+                               int stride, string padding) {
+    auto root = tensorflow::Scope::NewRootScope();
+    using namespace ::tensorflow::ops;  // NOLINT(build/namespaces)
+
+    const size_t input_data_size = input_height * input_width * input_depth;
+    Tensor input_data(DT_FLOAT,
+                      TensorShape({1, input_height, input_width, input_depth}));
+    for (int i = 0; i < input_data_size; ++i) {
+      input_data.flat<float>()(i) = i + 1.0f;
+    }
+    Output input =
+        Const(root.WithOpName("input"), Input::Initializer(input_data));
+
+    const size_t filter_data_size =
+        filter_size * filter_size * filter_count * input_depth;
+    Tensor filter_data(DT_FLOAT, TensorShape({filter_size, filter_size,
+                                              input_depth, filter_count}));
+    for (int i = 0; i < filter_data_size; ++i) {
+      filter_data.flat<float>()(i) = i + 1.0f;
+    }
+    Output filter =
+        Const(root.WithOpName("filter"), Input::Initializer(filter_data));
+
+    Output resize_size =
+        Const(root.WithOpName("resize_size"), {resize_height, resize_width});
+    Output resize =
+        ResizeBilinear(root.WithOpName("resize"), input, resize_size,
+                       ResizeBilinear::AlignCorners(resize_align_corners));
+    Output paddings =
+        Const(root.WithOpName("paddings"),
+              {{0, 0}, {y_padding, y_padding}, {x_padding, x_padding}, {0, 0}});
+    Output mirror_pad =
+        MirrorPad(root.WithOpName("mirror_pad"), resize, paddings, pad_mode);
+    Output conv = Conv2D(root.WithOpName("conv"), mirror_pad, filter,
+                         {1, stride, stride, 1}, padding);
+
+    Output fused_conv = FusedResizeAndPadConv2D(
+        root.WithOpName("fused_conv"), input, resize_size, paddings, filter,
+        pad_mode, {1, stride, stride, 1}, padding,
+        FusedResizeAndPadConv2D::ResizeAlignCorners(resize_align_corners));
+
+    tensorflow::GraphDef graph;
+    TF_ASSERT_OK(root.ToGraphDef(&graph));
+
+    std::unique_ptr<tensorflow::Session> session(
+        tensorflow::NewSession(tensorflow::SessionOptions()));
+    TF_ASSERT_OK(session->Create(graph));
+
+    std::vector<Tensor> unfused_tensors;
+    TF_ASSERT_OK(session->Run({}, {"conv"}, {}, &unfused_tensors));
+
+    std::vector<Tensor> fused_tensors;
+    TF_ASSERT_OK(session->Run({}, {"fused_conv"}, {}, &fused_tensors));
+
+    test::ExpectTensorNear<float>(unfused_tensors[0], fused_tensors[0], 1e-5);
+  }
+};
+
+TEST_F(FusedResizePadConvOpTest, HandwrittenConv) { HandwrittenConv(); }
+
+TEST_F(FusedResizePadConvOpTest, IdentityComparative) {
+  CompareFusedAndSeparate(10, 10, 1, 10, 10, 0, 0, 1, 1, false, "REFLECT", 1,
+                          "SAME");
+}
+
+TEST_F(FusedResizePadConvOpTest, ConvOnlyComparative) {
+  CompareFusedAndSeparate(10, 10, 3, 10, 10, 0, 0, 4, 4, false, "REFLECT", 1,
+                          "SAME");
+}
+
+TEST_F(FusedResizePadConvOpTest, ResizeOnlyComparative) {
+  CompareFusedAndSeparate(10, 10, 1, 20, 20, 0, 0, 1, 1, false, "REFLECT", 1,
+                          "SAME");
+}
+
+TEST_F(FusedResizePadConvOpTest, ResizeAndConvComparative) {
+  CompareFusedAndSeparate(2, 2, 4, 4, 2, 0, 0, 2, 2, false, "REFLECT", 1,
+                          "SAME");
+}
+
+TEST_F(FusedResizePadConvOpTest, ResizeAlignAndConvComparative) {
+  CompareFusedAndSeparate(2, 2, 4, 4, 2, 0, 0, 2, 2, true, "REFLECT", 1,
+                          "SAME");
+}
+
+TEST_F(FusedResizePadConvOpTest, ResizeAndConvStridedComparative) {
+  CompareFusedAndSeparate(2, 2, 4, 4, 2, 0, 0, 2, 2, false, "REFLECT", 2,
+                          "SAME");
+}
+
+TEST_F(FusedResizePadConvOpTest, ResizeAlignAndConvValidComparative) {
+  CompareFusedAndSeparate(2, 2, 4, 4, 2, 0, 0, 2, 2, true, "REFLECT", 1,
+                          "VALID");
+}
+
+TEST_F(FusedResizePadConvOpTest, PadOnlyComparative) {
+  CompareFusedAndSeparate(4, 4, 1, 4, 4, 2, 2, 1, 1, false, "REFLECT", 1,
+                          "SAME");
+}
+
+TEST_F(FusedResizePadConvOpTest, PadOnlyWithChannelsComparative) {
+  CompareFusedAndSeparate(4, 4, 3, 4, 4, 2, 2, 1, 1, false, "REFLECT", 1,
+                          "SAME");
+}
+
+TEST_F(FusedResizePadConvOpTest, ResizeAndPadComparative) {
+  CompareFusedAndSeparate(4, 4, 1, 6, 6, 2, 2, 1, 1, false, "REFLECT", 1,
+                          "SAME");
+}
+
+TEST_F(FusedResizePadConvOpTest, PadOnlySymmetricComparative) {
+  CompareFusedAndSeparate(4, 4, 1, 4, 4, 2, 2, 1, 1, false, "SYMMETRIC", 1,
+                          "SAME");
+}
+
+TEST_F(FusedResizePadConvOpTest, ResizeAndPadSymmetricComparative) {
+  CompareFusedAndSeparate(4, 4, 3, 6, 6, 2, 2, 1, 1, false, "SYMMETRIC", 1,
+                          "SAME");
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/core/kernels/conv_ops_using_gemm.cc b/tensorflow/core/kernels/conv_ops_using_gemm.cc
index c39510a11a2..6da6da846b4 100644
--- a/tensorflow/core/kernels/conv_ops_using_gemm.cc
+++ b/tensorflow/core/kernels/conv_ops_using_gemm.cc
@@ -56,14 +56,13 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/tensor_slice.h"
 #include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/conv_ops.h"
+#include "tensorflow/core/kernels/gemm_functors.h"
+#include "tensorflow/core/kernels/image_resizer_state.h"
+#include "tensorflow/core/util/mirror_pad_mode.h"
 #include "tensorflow/core/util/padding.h"
 #include "tensorflow/core/util/tensor_format.h"
 
-#if defined(__APPLE__)
-#include <Accelerate/Accelerate.h>
-#define USE_ACCELERATE_GEMM
-#endif  // __APPLE__
-
 namespace tensorflow {
 
 namespace {
@@ -189,87 +188,6 @@ class ReferenceConvFunctor {
   }
 };
 
-// A readable but slow implementation of matrix multiplication, useful for
-// debugging and understanding the algorithm. Use instead of FastGemmFunctor in
-// the Im2ColConvFunctor template definition inside the op registration to
-// enable. Assumes row-major ordering of the values in memory.
-template <class T1, class T2, class T3>
-class ReferenceGemmFunctor {
- public:
-  void operator()(size_t m, size_t n, size_t k, const T1* a, size_t lda,
-                  const T2* b, size_t ldb, T3* c, size_t ldc) {
-    const size_t a_i_stride = lda;
-    const size_t a_l_stride = 1;
-    const size_t b_j_stride = 1;
-    const size_t b_l_stride = ldb;
-    const size_t c_i_stride = ldc;
-    const size_t c_j_stride = 1;
-    size_t i, j, l;
-    for (j = 0; j < n; j++) {
-      for (i = 0; i < m; i++) {
-        T3 total(0);
-        for (l = 0; l < k; l++) {
-          const size_t a_index = ((i * a_i_stride) + (l * a_l_stride));
-          const T1 a_value = a[a_index];
-          const size_t b_index = ((j * b_j_stride) + (l * b_l_stride));
-          const T2 b_value = b[b_index];
-          total += (a_value * b_value);
-        }
-        const size_t c_index = ((i * c_i_stride) + (j * c_j_stride));
-        c[c_index] = total;
-      }
-    }
-  }
-};
-
-// Uses the optimized Eigen library to implement the matrix multiplication
-// required by the Im2ColConvFunctor class. We supply the two input and one
-// output types so that the accumulator can potentially be higher-precision than
-// the inputs, even though we don't currently take advantage of this.
-template <class T1, class T2, class T3>
-class FastGemmFunctor {
- public:
-  // Convenience wrappers for the Eigen matrix types we'll be using.
-  typedef Eigen::Map<
-      const Eigen::Matrix<T1, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
-      ConstMatrixT1;
-  typedef Eigen::Map<
-      const Eigen::Matrix<T2, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
-      ConstMatrixT2;
-  typedef Eigen::Map<
-      Eigen::Matrix<T3, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
-      MatrixT3;
-  void operator()(size_t m, size_t n, size_t k, const T1* a, size_t lda,
-                  const T2* b, size_t ldb, T3* c, size_t ldc) {
-    ConstMatrixT1 a_matrix(a, m, k);
-    ConstMatrixT2 b_matrix(b, k, n);
-    MatrixT3 c_matrix(c, m, n);
-    c_matrix.noalias() = a_matrix * b_matrix;
-  }
-};
-
-// If we have Apple's Accelerate framework, use their implementation of GEMM to
-// get a performance boost for float.
-#if defined(USE_ACCELERATE_GEMM)
-template <>
-class FastGemmFunctor<float, float, float> {
- public:
-  void operator()(size_t m, size_t n, size_t k, const float* a, size_t lda,
-                  const float* b, size_t ldb, float* c, size_t ldc) {
-    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, 1.0f, a,
-                lda, b, ldb, 0.0f, c, ldc);
-  }
-};
-#endif  // USE_ACCELERATE_GEMM
-
-// Used to keep track of persistent memory buffers used within the op.
-template <class T, size_t size>
-struct Im2ColBufferResource : public ResourceBase {
-  mutex mu;
-  T data[size];
-  string DebugString() { return "Im2ColBufferResource"; }
-};
-
 // Implements convolution as a two stage process, first packing the patches of
 // the input image into columns (im2col) and then running GEMM to produce the
 // final result.
@@ -344,7 +262,6 @@ class Im2ColConvFunctor {
                 errors::InvalidArgument("Im2Col patch too large for buffer"));
     const size_t patches_per_chunk =
         max_chunk_size / (filter_value_count * sizeof(T1));
-
     // Because memory allocation is very expensive on mobile platforms, try to
     // allocate a persistent buffer that will be kept around between calls. We
     // use TensorFlow's resource management to ensure that the memory will be
diff --git a/tensorflow/core/kernels/gemm_functors.h b/tensorflow/core/kernels/gemm_functors.h
new file mode 100644
index 00000000000..d37008d5cfb
--- /dev/null
+++ b/tensorflow/core/kernels/gemm_functors.h
@@ -0,0 +1,105 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This is a set of different implementations for the basic matrix by matrix
+// multiply function, commonly known as GEMM after the BLAS library's naming.
+// Having a standard interface enables us to swap out implementations on
+// different platforms, to make sure we're using the optimal version. They are
+// implemented as C++ template functors, so they're easy to swap into all of the
+// different kernels that use them.
+
+#include <string.h>
+#include <map>
+#include <vector>
+
+#include "tensorflow/core/framework/tensor.h"
+
+#if defined(__APPLE__) && defined(USE_GEMM_FOR_CONV)
+#include <Accelerate/Accelerate.h>
+#define USE_ACCELERATE_GEMM
+#endif  // __APPLE__
+
+// A readable but slow implementation of matrix multiplication, useful for
+// debugging and understanding the algorithm. Use instead of FastGemmFunctor in
+// the Im2ColConvFunctor template definition inside the op registration to
+// enable. Assumes row-major ordering of the values in memory.
+template <class T1, class T2, class T3>
+class ReferenceGemmFunctor {
+ public:
+  void operator()(size_t m, size_t n, size_t k, const T1* a, size_t lda,
+                  const T2* b, size_t ldb, T3* c, size_t ldc) {
+    const size_t a_i_stride = lda;
+    const size_t a_l_stride = 1;
+    const size_t b_j_stride = 1;
+    const size_t b_l_stride = ldb;
+    const size_t c_i_stride = ldc;
+    const size_t c_j_stride = 1;
+    size_t i, j, l;
+    for (j = 0; j < n; j++) {
+      for (i = 0; i < m; i++) {
+        T3 total(0);
+        for (l = 0; l < k; l++) {
+          const size_t a_index = ((i * a_i_stride) + (l * a_l_stride));
+          const T1 a_value = a[a_index];
+          const size_t b_index = ((j * b_j_stride) + (l * b_l_stride));
+          const T2 b_value = b[b_index];
+          total += (a_value * b_value);
+        }
+        const size_t c_index = ((i * c_i_stride) + (j * c_j_stride));
+        c[c_index] = total;
+      }
+    }
+  }
+};
+
+// Uses the optimized Eigen library to implement the matrix multiplication
+// required by the Im2ColConvFunctor class. We supply the two input and one
+// output types so that the accumulator can potentially be higher-precision than
+// the inputs, even though we don't currently take advantage of this.
+template <class T1, class T2, class T3>
+class FastGemmFunctor {
+ public:
+  // Convenience wrappers for the Eigen matrix types we'll be using.
+  typedef Eigen::Map<
+      const Eigen::Matrix<T1, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
+      ConstMatrixT1;
+  typedef Eigen::Map<
+      const Eigen::Matrix<T2, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
+      ConstMatrixT2;
+  typedef Eigen::Map<
+      Eigen::Matrix<T3, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
+      MatrixT3;
+  void operator()(size_t m, size_t n, size_t k, const T1* a, size_t lda,
+                  const T2* b, size_t ldb, T3* c, size_t ldc) {
+    ConstMatrixT1 a_matrix(a, m, k);
+    ConstMatrixT2 b_matrix(b, k, n);
+    MatrixT3 c_matrix(c, m, n);
+    c_matrix.noalias() = a_matrix * b_matrix;
+  }
+};
+
+// If we have Apple's Accelerate framework, use their implementation of GEMM to
+// get a performance boost for float.
+#if defined(USE_ACCELERATE_GEMM)
+template <>
+class FastGemmFunctor<float, float, float> {
+ public:
+  void operator()(size_t m, size_t n, size_t k, const float* a, size_t lda,
+                  const float* b, size_t ldb, float* c, size_t ldc) {
+    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, 1.0f, a,
+                lda, b, ldb, 0.0f, c, ldc);
+  }
+};
+#endif  // USE_ACCELERATE_GEMM
diff --git a/tensorflow/core/kernels/image_resizer_state.h b/tensorflow/core/kernels/image_resizer_state.h
index a7acb5e649b..8870937422a 100644
--- a/tensorflow/core/kernels/image_resizer_state.h
+++ b/tensorflow/core/kernels/image_resizer_state.h
@@ -49,12 +49,13 @@ struct ImageResizerState {
   explicit ImageResizerState(bool align_corners)
       : align_corners_(align_corners) {}
 
-  // ValidateAndCreateOutput checks the bounds on the input tensors
+  // ValidateAndCalculateOutputSize checks the bounds on the input tensors
   // and requested size, sets up some of the resizing state such as the
-  // height_scale and width_scale, and allocates the output.
+  // height_scale and width_scale, and calculates the output size.
   // If any of these operations fails, it sets an error status in
   // the context, which the caller must check.
-  void ValidateAndCreateOutput(OpKernelContext* context, const Tensor& input) {
+  void ValidateAndCalculateOutputSize(OpKernelContext* context,
+                                      const Tensor& input) {
     OP_REQUIRES(context, input.dims() == 4,
                 errors::InvalidArgument("input must be 4-dimensional",
                                         input.shape().DebugString()));
@@ -87,12 +88,18 @@ struct ImageResizerState {
     OP_REQUIRES(
         context, input.dim_size(1) > 0 && input.dim_size(2) > 0,
         errors::InvalidArgument("input image must be of non-zero size"));
+    height_scale = CalculateResizeScale(in_height, out_height, align_corners_);
+    width_scale = CalculateResizeScale(in_width, out_width, align_corners_);
+  }
+
+  // Calculates all the required variables, and allocates the output.
+  void ValidateAndCreateOutput(OpKernelContext* context, const Tensor& input) {
+    ValidateAndCalculateOutputSize(context, input);
+    if (!context->status().ok()) return;
     OP_REQUIRES_OK(context, context->allocate_output(
                                 0, TensorShape({input.dim_size(0), out_height,
                                                 out_width, input.dim_size(3)}),
                                 &output));
-    height_scale = CalculateResizeScale(in_height, out_height, align_corners_);
-    width_scale = CalculateResizeScale(in_width, out_width, align_corners_);
   }
 
   int64 batch_size;
diff --git a/tensorflow/core/kernels/ops_testutil.h b/tensorflow/core/kernels/ops_testutil.h
index eae5187896e..3baae914cbf 100644
--- a/tensorflow/core/kernels/ops_testutil.h
+++ b/tensorflow/core/kernels/ops_testutil.h
@@ -185,6 +185,7 @@ class OpsTestBase : public ::testing::Test {
     test::SetOutputAttrs(params_.get(), &attrs);
     checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_wrapper;
     params_.get()->slice_reader_cache = &slice_reader_cache_wrapper;
+    params_.get()->resource_manager = device_.get()->resource_manager();
 
     context_.reset(new OpKernelContext(params_.get()));
     device_->Compute(kernel_.get(), context_.get());
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index cc374278e7f..5daaf83133a 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -17,6 +17,7 @@ limitations under the License.
 #include "tensorflow/core/framework/numeric_op.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/util/mirror_pad_mode.h"
 #include "tensorflow/core/util/padding.h"
 #include "tensorflow/core/util/tensor_format.h"
 
@@ -425,6 +426,46 @@ data_format: Specify the data format of the input and output data. With the
         [batch, in_channels, in_height, in_width].
 )doc");
 
+REGISTER_OP("FusedResizeAndPadConv2D")
+    .Input("input: T")
+    .Input("size: int32")
+    .Input("paddings: int32")
+    .Input("filter: T")
+    .Output("output: T")
+    .Attr("T: {half, float, double}")
+    .Attr("resize_align_corners: bool = false")
+    .Attr(GetMirrorPadModeAttrString())
+    .Attr("strides: list(int)")
+    .Attr(GetPaddingAttrString())
+    .Doc(R"doc(
+Performs a resize and padding as a preprocess during a convolution.
+
+It's often possible to do spatial transformations more efficiently as part of
+the packing stage of a convolution, so this op allows for an optimized
+implementation where these stages are fused together. This prevents the need to
+write out the intermediate results as whole tensors, reducing memory pressure,
+and we can get some latency gains by merging the transformation calculations.
+The data_format attribute for Conv2D isn't supported by this op, and defaults to
+'NHWC' order.
+Internally this op uses a single per-graph scratch buffer, which means that it
+will block if multiple versions are being run in parallel. This is because this
+operator is primarily an optimization to minimize memory usage.
+
+input: 4-D with shape `[batch, in_height, in_width, in_channels]`.
+size: A 1-D int32 Tensor of 2 elements: `new_height, new_width`.  The
+  new size for the images.
+paddings: A two-column matrix specifying the padding sizes. The number of
+  rows must be the same as the rank of `input`.
+filter: 4-D with shape
+  `[filter_height, filter_width, in_channels, out_channels]`.
+resize_align_corners: If true, rescale input by (new_height - 1) / (height - 1),
+  which exactly aligns the 4 corners of images and resized images. If false, rescale
+  by new_height / height. Treat similarly the width dimension.
+strides: 1-D of length 4.  The stride of the sliding window for each dimension
+   of `input`. Must be in the same order as the dimension specified with format.
+padding: The type of padding algorithm to use.
+ )doc");
+
 // --------------------------------------------------------------------------
 
 REGISTER_OP("DepthwiseConv2dNative")
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index e14ca1a5593..7bc3ffe25b3 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -844,6 +844,71 @@ ops.RegisterShape("AvgPool")(common_shapes.avg_pool_shape)
 ops.RegisterShape("MaxPool")(common_shapes.max_pool_shape)
 
 
+@ops.RegisterShape("FusedResizeAndPadConv2D")
+def _FusedResizeAndPadConv2DShape(op):
+  """Shape function for FusedResizeAndPadConv2D op."""
+  # The bilinear resize shape calculation.
+  input_shape = op.inputs[0].get_shape().with_rank(4)
+  unused_size_shape = op.inputs[1].get_shape().merge_with([2])
+  size = tensor_util.constant_value(op.inputs[1])
+  if size is not None:
+    height = size[0]
+    width = size[1]
+  else:
+    height = None
+    width = None
+  resized_shape = tensor_shape.TensorShape(
+      [input_shape[0], height, width, input_shape[3]])
+
+  # Calculates the effect of the padding.
+  paddings_shape = op.inputs[2].get_shape().with_rank(2)
+  resized_shape = resized_shape.with_rank(paddings_shape[0].value)
+  paddings_shape = paddings_shape.merge_with(
+      tensor_shape.matrix(resized_shape.ndims, 2))
+  paddings = tensor_util.constant_value(op.inputs[2])
+  if paddings is None:
+    padded_shape = tensor_shape.unknown_shape(ndims=resized_shape.ndims)
+  else:
+    output_dims = []
+    for i, dim in enumerate(resized_shape.dims):
+      if paddings[i, 0] < 0 or paddings[i, 1] < 0:
+        raise ValueError("paddings must be non-negative")
+      output_dims.append(dim + paddings[i, 0] + paddings[i, 1])
+    padded_shape = tensor_shape.TensorShape(output_dims)
+
+  # Finally work out the convolution's effect.
+  filter_shape = op.inputs[3].get_shape().with_rank(4)
+
+  batch_size = padded_shape[0]
+  in_rows = padded_shape[1]
+  in_cols = padded_shape[2]
+
+  filter_rows = filter_shape[0]
+  filter_cols = filter_shape[1]
+  depth_out = filter_shape[3]
+  # Check that the input depths are compatible.
+  padded_shape[3].assert_is_compatible_with(filter_shape[2])
+
+  stride_b, stride_r, stride_c, stride_d = op.get_attr("strides")
+
+  if stride_b != 1 or stride_d != 1:
+    raise ValueError("Current implementation does not yet support "
+                     "strides in the batch and depth dimensions.")
+  # TODO(mrry,shlens): Raise an error if the stride would cause
+  # information in the input to be ignored. This will require a change
+  # in the kernel implementation.
+  padding = op.get_attr("padding")
+  out_rows, out_cols = common_shapes.get2d_conv_output_size(in_rows, in_cols,
+                                                            filter_rows,
+                                                            filter_cols,
+                                                            stride_r,
+                                                            stride_c,
+                                                            padding)
+
+  output_shape = [batch_size, out_rows, out_cols, depth_out]
+  return [tensor_shape.TensorShape(output_shape)]
+
+
 @ops.RegisterShape("MaxPoolWithArgmax")
 def _MaxPoolWithArgMaxShape(op):
   """Shape function for MaxPoolWithArgmax op."""
diff --git a/tensorflow/python/tools/optimize_for_inference.py b/tensorflow/python/tools/optimize_for_inference.py
index a330ff7c508..9c115f53bec 100644
--- a/tensorflow/python/tools/optimize_for_inference.py
+++ b/tensorflow/python/tools/optimize_for_inference.py
@@ -27,6 +27,8 @@ the network is used only for inference. These include:
 
  - Folding batch normalization ops into the pre-calculated weights.
 
+ - Fusing common operations into unified versions.
+
 This script takes a frozen GraphDef file (where the weight variables have been
 converted into constants by the freeze_graph script) and outputs a new GraphDef
 with the optimizations applied.
diff --git a/tensorflow/python/tools/optimize_for_inference_lib.py b/tensorflow/python/tools/optimize_for_inference_lib.py
index 4eb138d97d9..1cb5ba16256 100644
--- a/tensorflow/python/tools/optimize_for_inference_lib.py
+++ b/tensorflow/python/tools/optimize_for_inference_lib.py
@@ -27,6 +27,8 @@ the network is used only for inference. These include:
 
  - Folding batch normalization ops into the pre-calculated weights.
 
+ - Fusing common operations into unified versions.
+
 This script takes a frozen GraphDef file (where the weight variables have been
 converted into constants by the freeze_graph script) and outputs a new GraphDef
 with the optimizations applied.
@@ -37,8 +39,8 @@ bazel build tensorflow/python/tools:optimize_for_inference && \
 bazel-bin/tensorflow/python/tools/optimize_for_inference \
 --input_graph=some_graph_def.pb \
 --output_graph=/tmp/optimized_graph.pb \
---input_node_names=Mul
---output_node_names=softmax
+--input_names=Mul \
+--output_names=softmax
 
 """
 
@@ -74,13 +76,42 @@ def optimize_for_inference(input_graph_def, input_node_names,
   Returns:
     An optimized version of the input graph.
   """
-  stripped_graph_def = strip_unused_lib.strip_unused(input_graph_def,
-                                                     input_node_names,
-                                                     output_node_names,
-                                                     placeholder_type_enum)
-  detrained_graph_def = graph_util.remove_training_nodes(stripped_graph_def)
-  folded_graph_def = fold_batch_norms(detrained_graph_def)
-  return folded_graph_def
+  ensure_graph_is_valid(input_graph_def)
+  optimized_graph_def = input_graph_def
+  optimized_graph_def = strip_unused_lib.strip_unused(optimized_graph_def,
+                                                      input_node_names,
+                                                      output_node_names,
+                                                      placeholder_type_enum)
+  optimized_graph_def = graph_util.remove_training_nodes(optimized_graph_def)
+  optimized_graph_def = fold_batch_norms(optimized_graph_def)
+  optimized_graph_def = fuse_resize_and_conv(optimized_graph_def)
+  ensure_graph_is_valid(optimized_graph_def)
+  return optimized_graph_def
+
+
+def ensure_graph_is_valid(graph_def):
+  """Makes sure that the graph is internally consistent.
+
+  Checks basic properties of the graph def and raises an exception if there are
+  input references to missing nodes, duplicated names, or other logic errors.
+
+  Args:
+    graph_def: Definition of a graph to be checked.
+
+  Raises:
+    ValueError: If the graph is incorrectly constructed.
+  """
+  node_map = {}
+  for node in graph_def.node:
+    if node.name not in node_map.keys():
+      node_map[node.name] = node
+    else:
+      raise ValueError("Duplicate node names detected for ", node.name)
+  for node in graph_def.node:
+    for input_name in node.input:
+      input_node_name = node_name_from_input(input_name)
+      if input_node_name not in node_map.keys():
+        raise ValueError("Input for ", node.name, " not found: ", input_name)
 
 
 def node_name_from_input(node_name):
@@ -161,7 +192,7 @@ def fold_batch_norms(input_graph_def):
     if node.name not in input_node_map.keys():
       input_node_map[node.name] = node
     else:
-      raise ValueError("Duplicate node names detected.")
+      raise ValueError("Duplicate node names detected for ", node.name)
 
   nodes_to_skip = {}
   new_ops = []
@@ -303,3 +334,94 @@ def fold_batch_norms(input_graph_def):
 
   result_graph_def.node.extend(new_ops)
   return result_graph_def
+
+
+def fuse_resize_and_conv(input_graph_def):
+  """Merges preceding resize and mirror pad ops into a specialized convolution.
+
+  There's a common pattern of enlarging the input to a convolution using a
+  resize operation, and also using MirrorPad to extend the boundaries to that
+  zero edge pixels don't bleed inwards when convolving. This routine looks for
+  that pattern of operations, and fuses them together into a Conv2DWithResizeOp.
+
+  Args:
+    input_graph_def: A GraphDef containing a model.
+
+  Returns:
+    Modified graph with resize and pad ops merged.
+
+  Raises:
+    ValueError: If the graph is badly formed with duplicate node names.
+  """
+
+  input_node_map = {}
+  for node in input_graph_def.node:
+    if node.name not in input_node_map.keys():
+      input_node_map[node.name] = node
+    else:
+      raise ValueError("Duplicate node names detected for ", node.name)
+
+  nodes_to_skip = {}
+  new_ops = []
+  for node in input_graph_def.node:
+
+    if node.op != "Conv2D":
+      continue
+    conv_op = node
+
+    input_op = node_from_map(input_node_map, conv_op.input[0])
+    if input_op.op == "MirrorPad":
+      mirror_pad_op = input_op
+      resize_op = node_from_map(input_node_map, mirror_pad_op.input[0])
+    else:
+      mirror_pad_op = None
+      resize_op = input_op
+
+    if resize_op.op != "ResizeBilinear":
+      continue
+
+    nodes_to_skip[conv_op.name] = True
+    if mirror_pad_op:
+      nodes_to_skip[mirror_pad_op.name] = True
+    nodes_to_skip[resize_op.name] = True
+
+    fused_conv_op = tf.NodeDef()
+    fused_conv_op.op = "FusedResizeAndPadConv2D"
+    fused_conv_op.name = conv_op.name
+    if mirror_pad_op:
+      mirror_paddings_name = mirror_pad_op.input[1]
+      mirror_paddings_mode = mirror_pad_op.attr["mode"]
+    else:
+      # If there was no MirrorPad op, then create settings that make the padding
+      # stage of the fused operation a no-op.
+      paddings_op = tf.NodeDef()
+      paddings_op.op = "Const"
+      paddings_op.name = conv_op.name + "_dummy_paddings"
+      paddings_op.attr["dtype"].CopyFrom(tf.AttrValue(
+          type=tf.int32.as_datatype_enum))
+      paddings_op.attr["value"].CopyFrom(tf.AttrValue(
+          tensor=tensor_util.make_tensor_proto(
+              [0, 0, 0, 0, 0, 0, 0, 0], tf.int32, [4, 2])))
+      new_ops.extend([paddings_op])
+      mirror_paddings_name = paddings_op.name
+      mirror_paddings_mode = tf.AttrValue(s=b"REFLECT")
+    fused_conv_op.input.extend([resize_op.input[0], resize_op.input[1],
+                                mirror_paddings_name, conv_op.input[1]])
+    fused_conv_op.attr["T"].CopyFrom(conv_op.attr["T"])
+    fused_conv_op.attr["resize_align_corners"].CopyFrom(
+        resize_op.attr["align_corners"])
+    fused_conv_op.attr["mode"].CopyFrom(mirror_paddings_mode)
+    fused_conv_op.attr["strides"].CopyFrom(conv_op.attr["strides"])
+    fused_conv_op.attr["padding"].CopyFrom(conv_op.attr["padding"])
+    new_ops.extend([fused_conv_op])
+
+  result_graph_def = tf.GraphDef()
+  for node in input_graph_def.node:
+    if node.name in nodes_to_skip:
+      continue
+    new_node = tf.NodeDef()
+    new_node.CopyFrom(node)
+    result_graph_def.node.extend([new_node])
+
+  result_graph_def.node.extend(new_ops)
+  return result_graph_def
diff --git a/tensorflow/python/tools/optimize_for_inference_test.py b/tensorflow/python/tools/optimize_for_inference_test.py
index 61644fe9c91..d92d7ab8c7d 100644
--- a/tensorflow/python/tools/optimize_for_inference_test.py
+++ b/tensorflow/python/tools/optimize_for_inference_test.py
@@ -54,6 +54,7 @@ class OptimizeForInferenceTest(tf.test.TestCase):
                                              shape=shape)))
 
   def testOptimizeForInference(self):
+    unused_constant_name = "unused_constant"
     unconnected_add_name = "unconnected_add"
     a_constant_name = "a_constant"
     b_constant_name = "b_constant"
@@ -64,9 +65,14 @@ class OptimizeForInferenceTest(tf.test.TestCase):
     add_name = "add"
     unused_output_add_name = "unused_output_add"
     graph_def = tf.GraphDef()
+    unused_constant = self.create_constant_node_def(unused_constant_name,
+                                                    value=0,
+                                                    dtype=tf.float32,
+                                                    shape=[])
+    graph_def.node.extend([unused_constant])
     unconnected_add_node = self.create_node_def("Add", unconnected_add_name,
-                                                ["no_such_node",
-                                                 "no_such_node"])
+                                                [unused_constant_name,
+                                                 unused_constant_name])
     self.set_attr_dtype(unconnected_add_node, "T", tf.float32)
     graph_def.node.extend([unconnected_add_node])
     a_constant = self.create_constant_node_def(a_constant_name,
@@ -160,6 +166,65 @@ class OptimizeForInferenceTest(tf.test.TestCase):
     for node in optimized_graph_def.node:
       self.assertNotEqual("BatchNormWithGlobalNormalization", node.op)
 
+  def testFuseResizePadAndConv(self):
+    with self.test_session() as sess:
+      inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6]
+      input_op = tf.constant(np.array(inputs), shape=[1, 2, 3, 2],
+                             dtype=tf.float32)
+      resize_op = tf.image.resize_bilinear(input_op, [12, 4],
+                                           align_corners=False)
+      pad_op = tf.pad(resize_op, [[0, 0], [1, 1], [2, 2], [0, 0]],
+                      mode="REFLECT")
+      weights = [1, 2, 3, 4, 0.1, 0.2, 0.3, 0.4]
+      weights_op = tf.constant(np.array(weights), shape=[1, 2, 2, 2],
+                               dtype=tf.float32)
+      tf.nn.conv2d(pad_op, weights_op, [1, 1, 1, 1],
+                   padding="VALID", name="output")
+      original_graph_def = sess.graph_def
+      original_result = sess.run(["output:0"])
+    optimized_graph_def = optimize_for_inference_lib.fuse_resize_and_conv(
+        original_graph_def)
+
+    with self.test_session() as sess:
+      _ = tf.import_graph_def(optimized_graph_def, input_map={},
+                              name="optimized")
+      optimized_result = sess.run(["optimized/output:0"])
+
+    self.assertAllClose(original_result, optimized_result)
+
+    for node in optimized_graph_def.node:
+      self.assertNotEqual("Conv2D", node.op)
+      self.assertNotEqual("MirrorPad", node.op)
+      self.assertNotEqual("ResizeBilinear", node.op)
+
+  def testFuseResizeAndConv(self):
+    with self.test_session() as sess:
+      inputs = [1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6]
+      input_op = tf.constant(np.array(inputs), shape=[1, 2, 3, 2],
+                             dtype=tf.float32)
+      resize_op = tf.image.resize_bilinear(input_op, [12, 4],
+                                           align_corners=False)
+      weights = [1, 2, 3, 4, 0.1, 0.2, 0.3, 0.4]
+      weights_op = tf.constant(np.array(weights), shape=[1, 2, 2, 2],
+                               dtype=tf.float32)
+      tf.nn.conv2d(resize_op, weights_op, [1, 1, 1, 1],
+                   padding="VALID", name="output")
+      original_graph_def = sess.graph_def
+      original_result = sess.run(["output:0"])
+    optimized_graph_def = optimize_for_inference_lib.fuse_resize_and_conv(
+        original_graph_def)
+
+    with self.test_session() as sess:
+      _ = tf.import_graph_def(optimized_graph_def, input_map={},
+                              name="optimized")
+      optimized_result = sess.run(["optimized/output:0"])
+
+    self.assertAllClose(original_result, optimized_result)
+
+    for node in optimized_graph_def.node:
+      self.assertNotEqual("Conv2D", node.op)
+      self.assertNotEqual("ResizeBilinear", node.op)
+
 
 if __name__ == "__main__":
   tf.test.main()