From 618269e6f4ce5d5f47305e06c24576bf5aaaf56d Mon Sep 17 00:00:00 2001 From: Bhavani Subramanian Date: Thu, 14 Feb 2019 11:17:57 -0800 Subject: [PATCH 1/5] Enabled quantized pad fusion. --- tensorflow/core/graph/mkl_layout_pass.cc | 9 +- tensorflow/core/kernels/BUILD | 23 + tensorflow/core/kernels/mkl_conv_ops.cc | 111 +++-- tensorflow/core/kernels/mkl_conv_ops.h | 5 +- .../kernels/mkl_quantized_conv_ops_test.cc | 454 ++++++++++++++++++ tensorflow/core/ops/mkl_nn_ops.cc | 11 + tensorflow/core/ops/nn_ops.cc | 23 +- 7 files changed, 577 insertions(+), 59 deletions(-) create mode 100644 tensorflow/core/kernels/mkl_quantized_conv_ops_test.cc diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc index e934978e76a..59cdc4afe4b 100644 --- a/tensorflow/core/graph/mkl_layout_pass.cc +++ b/tensorflow/core/graph/mkl_layout_pass.cc @@ -2400,7 +2400,8 @@ void MklLayoutRewritePass::CopyAttrsQuantizedConv2D(const Node* orig_node, DataType Tinput, Tfilter, out_type; string padding; string data_format("NHWC"); - std::vector strides, dilations; + std::vector strides, dilations, padding_list; + bool has_padding_list = HasNodeAttr(orig_node->def(), "padding_list"); // Get all attributes from old node. TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tinput", &Tinput)); @@ -2409,6 +2410,9 @@ void MklLayoutRewritePass::CopyAttrsQuantizedConv2D(const Node* orig_node, TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding)); TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides)); TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations)); + if (has_padding_list) { + TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding_list", &padding_list)); + } Node* filter_node = nullptr; orig_node->input_node(1, &filter_node); @@ -2423,6 +2427,9 @@ void MklLayoutRewritePass::CopyAttrsQuantizedConv2D(const Node* orig_node, nb->Attr("dilations", dilations); nb->Attr("T", out_type); // added "T" for facilitating MklToTf conversion. nb->Attr("data_format", data_format); + if (has_padding_list) { + nb->Attr("padding_list", padding_list); + } // Requantization attr Tbias. DataType Tbias; diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 445fde84f1f..cf59fb005ce 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -6388,6 +6388,29 @@ tf_cc_test( ], ) +tf_cc_test_mkl( + name = "mkl_quantized_conv_ops_test", + size = "small", + srcs = ["mkl_quantized_conv_ops_test.cc"], + tags = ["nomsan"], # http://b/32242946 + deps = [ + ":mkl_conv_op", + ":mkl_input_conversion_op", + ":ops_testutil", + ":ops_util", + ":quantization_utils", + ":quantized_ops", + "//tensorflow/core:array_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:math_ops_op_lib", + "//tensorflow/core:nn_ops_op_lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + tf_cc_test( name = "quantize_op_test", size = "small", diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc index 0134cc22356..aeb198c145b 100644 --- a/tensorflow/core/kernels/mkl_conv_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_ops.cc @@ -467,19 +467,18 @@ class MklConvOp : public OpKernel { filter.shape().DebugString())); for (int i = 0; i < 3; i++) { - OP_REQUIRES( - context, - FastBoundsCheck(filter.dim_size(i), std::numeric_limits::max()), - errors::InvalidArgument("filter too large")); + OP_REQUIRES(context, FastBoundsCheck(filter.dim_size(i), + std::numeric_limits::max()), + errors::InvalidArgument("filter too large")); } const int64 input_depth = input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'C') : GetTensorDim(input, data_format_, 'C'); - OP_REQUIRES(context, input_depth == filter.dim_size(2), - errors::InvalidArgument( - "input and filter must have the same depth: ", input_depth, - " vs ", filter.dim_size(2))); + OP_REQUIRES( + context, input_depth == filter.dim_size(2), + errors::InvalidArgument("input and filter must have the same depth: ", + input_depth, " vs ", filter.dim_size(2))); // The last dimension for filter is out_depth. const int out_depth = static_cast(filter.dim_size(3)); @@ -488,10 +487,9 @@ class MklConvOp : public OpKernel { const int64 input_rows_raw = input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'H') : GetTensorDim(input, data_format_, 'H'); - OP_REQUIRES( - context, - FastBoundsCheck(input_rows_raw, std::numeric_limits::max()), - errors::InvalidArgument("Input rows too large")); + OP_REQUIRES(context, FastBoundsCheck(input_rows_raw, + std::numeric_limits::max()), + errors::InvalidArgument("Input rows too large")); const int input_rows = static_cast(input_rows_raw); const int filter_rows = static_cast(filter.dim_size(0)); @@ -500,10 +498,9 @@ class MklConvOp : public OpKernel { const int64 input_cols_raw = input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'W') : GetTensorDim(input, data_format_, 'W'); - OP_REQUIRES( - context, - FastBoundsCheck(input_cols_raw, std::numeric_limits::max()), - errors::InvalidArgument("Input cols too large")); + OP_REQUIRES(context, FastBoundsCheck(input_cols_raw, + std::numeric_limits::max()), + errors::InvalidArgument("Input cols too large")); const int input_cols = static_cast(input_cols_raw); const int filter_cols = static_cast(filter.dim_size(1)); @@ -511,10 +508,9 @@ class MklConvOp : public OpKernel { const int64 input_batch_raw = input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'N') : GetTensorDim(input, data_format_, 'N'); - OP_REQUIRES( - context, - FastBoundsCheck(input_batch_raw, std::numeric_limits::max()), - errors::InvalidArgument("batch is too large")); + OP_REQUIRES(context, FastBoundsCheck(input_batch_raw, + std::numeric_limits::max()), + errors::InvalidArgument("batch is too large")); const int batch = static_cast(input_batch_raw); // For now we take the stride from the second and third dimensions only (we @@ -860,6 +856,9 @@ class MklConvOp : public OpKernel { explicit MklConvOp(OpKernelConstruction* context) : OpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_)); + if (context->HasAttr("padding_list")) { + OP_REQUIRES_OK(context, context->GetAttr("padding_list", &padding_list_)); + } OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_)); string data_format; OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format)); @@ -899,17 +898,15 @@ class MklConvOp : public OpKernel { OP_REQUIRES(context, dilations_.size() == 5, errors::InvalidArgument("Dilation rates field must " "specify 5 dimensions")); - OP_REQUIRES(context, - (GetTensorDim(dilations_, data_format_, 'N') == 1 && - GetTensorDim(dilations_, data_format_, 'C') == 1), + OP_REQUIRES(context, (GetTensorDim(dilations_, data_format_, 'N') == 1 && + GetTensorDim(dilations_, data_format_, 'C') == 1), errors::InvalidArgument( "Current implementation does not yet support " "dilations rates in the batch and depth dimensions.")); OP_REQUIRES( - context, - (GetTensorDim(dilations_, data_format_, '0') > 0 && - GetTensorDim(dilations_, data_format_, '1') > 0 && - GetTensorDim(dilations_, data_format_, '2') > 0), + context, (GetTensorDim(dilations_, data_format_, '0') > 0 && + GetTensorDim(dilations_, data_format_, '1') > 0 && + GetTensorDim(dilations_, data_format_, '2') > 0), errors::InvalidArgument("Dilated rates should be larger than 0.")); } } @@ -938,9 +935,18 @@ class MklConvOp : public OpKernel { dilations, strides; memory::dims dst_dims_tf_order, dst_dims_mkl_order; - // If pad with conv2d fusion is enabled - if (fuse_pad_) { - PadWithConvFusion(context, padding_left, padding_right); + // For Quantized-Conv2D and Pad fusion, we get padding from the + // `padding_list` attribute. Otherwise, we get it from one of the inputs. + bool quantized_pad_enabled = false; + for (auto const& padding_val : padding_list_) { + if (!padding_val) continue; + quantized_pad_enabled = true; + break; + } + + if (fuse_pad_ || quantized_pad_enabled) { + PadWithConvFusion(context, padding_left, padding_right, + quantized_pad_enabled); } // Get shapes of input tensors in MKL-DNN order @@ -951,7 +957,8 @@ class MklConvOp : public OpKernel { conv_utl.GetConvFwdSizesInMklOrder( src_tf_shape, filter_tf_shape, &src_dims, &filter_dims, &strides, &dilations, &dst_dims_tf_order, &dst_dims_mkl_order, &padding_left, - &padding_right, fuse_pad_, is_depthwise); + &padding_right, (fuse_pad_ || quantized_pad_enabled), is_depthwise); + if (!context->status().ok()) return; // Check for corner case - if there is nothing to compute, return. @@ -1151,16 +1158,20 @@ class MklConvOp : public OpKernel { } void PadWithConvFusion(OpKernelContext* context, memory::dims& padding_left, - memory::dims& padding_right) { + memory::dims& padding_right, + bool quantized_pad_enabled) { const Tensor& paddings_tf = MklGetInput(context, input_index_pad_); - OP_REQUIRES(context, paddings_tf.dims() == 2, - errors::InvalidArgument("paddings must be 2-dimensional: ", - paddings_tf.shape().DebugString())); - - // Flatten tensor to get individual paddings. - Tpadding* paddings = static_cast( - const_cast(paddings_tf.flat().data())); - + Tpadding* paddings = nullptr; + if (quantized_pad_enabled) { + paddings = padding_list_.data(); + } else { + OP_REQUIRES(context, paddings_tf.dims() == 2, + errors::InvalidArgument("paddings must be 2-dimensional: ", + paddings_tf.shape().DebugString())); + // Flatten tensor to get individual paddings. + paddings = static_cast( + const_cast(paddings_tf.flat().data())); + } // If the data format is NHWC, indices 0, 1, 6 and 7 of paddings(_tf) // will be zero. // Example: @@ -1186,8 +1197,7 @@ class MklConvOp : public OpKernel { pad_left = paddings[6]; pad_right = paddings[7]; } - - // Create padding arrays for MKL DNN convolutions. + // Create padding arrays for MKL-DNN convolutions. // MKL-DNN uses asymetric padding. padding_left = {static_cast(pad_top), static_cast(pad_left)}; padding_right = {static_cast(pad_bottom), static_cast(pad_right)}; @@ -1264,6 +1274,7 @@ class MklConvOp : public OpKernel { private: std::vector strides_; std::vector dilations_; + std::vector padding_list_; bool is_filter_const_; mutex mu_; Padding padding_; @@ -1792,8 +1803,8 @@ class MklQuantizedConv2DSumReluOp const float max_filter = context->input(5 + bias_index_offset).flat()(0); - reorder_sum_scale = 255.0 * 127.0 / - (std::max(std::abs(max_input), std::abs(min_input)) * + reorder_sum_scale = + 255.0 * 127.0 / (std::max(std::abs(max_input), std::abs(min_input)) * std::max(std::abs(max_filter), std::abs(min_filter))); std::vector scales; scales.push_back(reorder_sum_scale); @@ -1825,7 +1836,7 @@ class MklQuantizedConv2DSumReluOp }; // INT8 kernel registration -// Register NoOp kernel for QunatizedConv2D for qint8 filter +// Register NoOp kernel for QuantizedConv2D for qint8 filter REGISTER_KERNEL_BUILDER(Name("QuantizedConv2D") .Device(DEVICE_CPU) .TypeConstraint("Tinput") @@ -1840,7 +1851,7 @@ REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DAndRequantize") .TypeConstraint("out_type"), NoOp); -// Register a templatized implementation of MklQuntizedConv2D. +// Register a templatized implementation of MklQuantizedConv2D. REGISTER_KERNEL_BUILDER( Name("_MklQuantizedConv2D") .Device(DEVICE_CPU) @@ -2029,17 +2040,21 @@ REGISTER_KERNEL_BUILDER( .Device(DEVICE_CPU) .TypeConstraint("Tinput") .TypeConstraint("Tfilter") + .TypeConstraint("Tbias") .TypeConstraint("out_type") .Label(mkl_op_registry::kMklQuantizedOpLabel), MklQuantizedConv2DSumReluOp); + REGISTER_KERNEL_BUILDER( Name("_MklQuantizedConv2DWithBiasSignedSumAndReluAndRequantize") .Device(DEVICE_CPU) .TypeConstraint("Tinput") .TypeConstraint("Tfilter") + .TypeConstraint("Tbias") .TypeConstraint("out_type") .Label(mkl_op_registry::kMklQuantizedOpLabel), MklQuantizedConv2DSumReluOp); + #endif // INTEL_MKL_ML // Register 2D operations @@ -2074,7 +2089,7 @@ REGISTER_KERNEL_BUILDER( .TypeConstraint("Tpaddings") \ .Label(mkl_op_registry::kMklOpLabel), \ MklConvOp); \ + float, int32, false, true, false>); \ REGISTER_KERNEL_BUILDER(Name("__MklDummyPadWithConv2D") \ .Device(DEVICE_CPU) \ .TypeConstraint("T") \ @@ -2116,7 +2131,7 @@ TF_CALL_float(REGISTER_MKL_CPU_2D_DEPTHWISE); .TypeConstraint("T") \ .TypeConstraint("Tpaddings") \ .Label(mkl_op_registry::kMklOpLabel), \ - MklFusedConvOp); \ + MklFusedConvOp); \ REGISTER_KERNEL_BUILDER(Name("__MklDummyPadWithFusedConv2D") \ .Device(DEVICE_CPU) \ .TypeConstraint("T") \ diff --git a/tensorflow/core/kernels/mkl_conv_ops.h b/tensorflow/core/kernels/mkl_conv_ops.h index c12a4ff0f0c..228397c7120 100644 --- a/tensorflow/core/kernels/mkl_conv_ops.h +++ b/tensorflow/core/kernels/mkl_conv_ops.h @@ -196,9 +196,8 @@ class MklDnnConvUtil { filter_shape.DebugString())); for (int i = 0; i < ((strides_.size() == 4) ? 3 : 5); i++) { - OP_REQUIRES(context_, - FastBoundsCheck(filter_shape.dim_size(i), - std::numeric_limits::max()), + OP_REQUIRES(context_, FastBoundsCheck(filter_shape.dim_size(i), + std::numeric_limits::max()), errors::InvalidArgument("filter too large")); } diff --git a/tensorflow/core/kernels/mkl_quantized_conv_ops_test.cc b/tensorflow/core/kernels/mkl_quantized_conv_ops_test.cc new file mode 100644 index 00000000000..1d41b251930 --- /dev/null +++ b/tensorflow/core/kernels/mkl_quantized_conv_ops_test.cc @@ -0,0 +1,454 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifdef INTEL_MKL +#define EIGEN_USE_THREADS + +#include +#include +#include + +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/kernels/quantization_utils.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +// TODO(bhavanis): Move ConvMklToTF to mkl_test_util.h as it is used by +// most unit tests. + +// Helper class for converting MKL tensors to TF tensors and comparing to +// expected values + +static const uint8 dummy_tensor[] = {0, 0, 0, 0, 0, 0, 0, 0}; +static const TensorShape dummy_shape({8}); + +class ConvMklToTF : public OpsTestBase { + public: + template + void ConvertMklToTF(DataType dtype, const Tensor& input, + const Tensor& input_metadata_tensor, Tensor& output) { + // Create an MKL to TF conversion node and execute it + TF_EXPECT_OK(NodeDefBuilder("mkl_to_tf_op", "_MklToTf") + .Input(FakeInput(dtype)) // Input + .Input(FakeInput(DT_UINT8)) // MKL metadata tensor + .Attr("T", dtype) + .Attr("_kernel", "MklOp") + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + AddInputFromArray(input.shape(), input.flat()); + AddInputFromArray(input_metadata_tensor.shape(), + input_metadata_tensor.flat()); + TF_ASSERT_OK(RunOpKernel()); + + output = *GetOutput(0); + } + void TestBody() {} +}; + +class QuantizedConv2DTest : public OpsTestBase { + protected: + void ConfigureQuantizedConv2D(const int& stride = 1) { + TF_ASSERT_OK(NodeDefBuilder("quantized_conv_op", "_MklQuantizedConv2D") + .Input(FakeInput(DT_QUINT8)) // Input + .Input(FakeInput(DT_QINT8)) // Filter + .Input(FakeInput(DT_FLOAT)) // Min input + .Input(FakeInput(DT_FLOAT)) // Max input + .Input(FakeInput(DT_FLOAT)) // Min filter + .Input(FakeInput(DT_FLOAT)) // Max filter + // MKL metadata tensors // + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + /////////////////////////// + .Attr("Tinput", DataTypeToEnum::v()) + .Attr("Tfilter", DataTypeToEnum::v()) + .Attr("T", DataTypeToEnum::v()) + .Attr("out_type", DataTypeToEnum::v()) + .Attr("strides", {1, stride, stride, 1}) + .Attr("padding", "SAME") + .Attr("_kernel", "QuantizedMklOp") + .Finalize(node_def())); + TF_ASSERT_OK(InitOp()); + } +}; + +// Output -> float +TEST_F(QuantizedConv2DTest, Small) { + const int stride = 1; + ConfigureQuantizedConv2D(stride); + + const int depth = 1; + const int image_width = 4; + const int image_height = 3; + const int image_batch_count = 1; + + // Image -> uint8 + const float image_min = 0.0f; + const float image_max = 255.0f; + + // The image matrix is: + // | 1 | 2 | 3 | 4 | + // | 5 | 6 | 7 | 8 | + // | 9 | 10 | 11 | 12 | + Tensor image_float(DT_FLOAT, + {image_batch_count, image_height, image_width, depth}); + test::FillValues(&image_float, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + Tensor image_quantized = + FloatTensorToQuantized(image_float, image_min, image_max); + + const int filter_size = 3; + const int filter_count = 1; + + // Filter -> int8 with symmetric range + const float filter_min = -127.0f; + const float filter_max = 127.0f; + + // The filter matrix is: + // | 1 | 4 | 7 | + // | 2 | 5 | 8 | + // | 3 | 6 | 9 | + Tensor filter_float(DT_FLOAT, + {filter_size, filter_size, depth, filter_count}); + test::FillValues(&filter_float, {1, 4, 7, 2, 5, 8, 3, 6, 9}); + Tensor filter_quantized = + FloatTensorToQuantized(filter_float, filter_min, filter_max); + + AddInputFromArray(image_quantized.shape(), + image_quantized.flat()); + AddInputFromArray(filter_quantized.shape(), + filter_quantized.flat()); + AddInputFromArray(TensorShape({1}), {image_min}); + AddInputFromArray(TensorShape({1}), {image_max}); + AddInputFromArray(TensorShape({1}), {filter_min}); + AddInputFromArray(TensorShape({1}), {filter_max}); + + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + + TF_ASSERT_OK(RunOpKernel()); + + // We're sliding the 3x3 filter across the 3x4 image, with accesses outside + // the input set to zero because we're using the 'SAME' padding mode. + // The calculations behind the expected output are: + // (1*0)+(4*0)+(7*0)+(2*0)+(5*1)+(8*2)+(3*0)+(6*5)+(9*6)=105 + // (1*0)+(4*0)+(7*0)+(2*1)+(5*2)+(8*3)+(3*5)+(6*6)+(9*7)=150 + // (1*0)+(4*0)+(7*0)+(2*2)+(5*3)+(8*4)+(3*6)+(6*7)+(9*8)=183 + // (1*0)+(4*0)+(7*0)+(2*3)+(5*4)+(8*0)+(3*7)+(6*8)+(9*0)=95 + // (1*0)+(4*1)+(7*2)+(2*0)+(5*5)+(8*6)+(3*0)+(6*9)+(9*10)=235 + // (1*1)+(4*2)+(7*3)+(2*5)+(5*6)+(8*7)+(3*9)+(6*10)+(9*11)=312 + // (1*2)+(4*3)+(7*4)+(2*6)+(5*7)+(8*8)+(3*10)+(6*11)+(9*12)=357 + // (1*3)+(4*4)+(7*0)+(2*7)+(5*8)+(8*0)+(3*11)+(6*12)+(9*0)=178 + // (1*0)+(4*5)+(7*6)+(2*0)+(5*9)+(8*10)+(3*0)+(6*0)+(9*0)=187 + // (1*5)+(4*6)+(7*7)+(2*9)+(5*10)+(8*11)+(3*0)+(6*0)+(9*0)=234 + // (1*6)+(4*7)+(7*8)+(2*10)+(5*11)+(8*12)+(3*0)+(6*0)+(9*0)=261 + // (1*7)+(4*8)+(7*0)+(2*11)+(5*12)+(8*0)+(3*0)+(6*0)+(9*0)=121 + // This means we should end up with this matrix: + // | 105 | 150 | 183 | 95 | + // | 235 | 312 | 357 | 178 | + // | 187 | 234 | 261 | 121 | + + // Output -> float + const int expected_width = image_width; + const int expected_height = image_height * filter_count; + Tensor expected_float( + DT_FLOAT, TensorShape({image_batch_count, expected_height, expected_width, + filter_count})); + test::FillValues(&expected_float, {105, 150, 183, 95, 235, 312, 357, + 178, 187, 234, 261, 121}); + + const Tensor& output = *GetOutput(0); + const Tensor& output_mkl_metadata = *GetOutput(3); + + ConvMklToTF conv_comp; + Tensor output_quantized; + conv_comp.ConvertMklToTF(DT_QINT32, output, output_mkl_metadata, + output_quantized); + + const float output_min = GetOutput(1)->flat()(0); + const float output_max = GetOutput(2)->flat()(0); + Tensor output_float = + QuantizedTensorToFloat(output_quantized, output_min, output_max); + + test::ExpectTensorNear(expected_float, output_float, 1.0); +} + +// Output -> int32 +TEST_F(QuantizedConv2DTest, Small32Bit) { + const int stride = 1; + ConfigureQuantizedConv2D(stride); + + const int depth = 1; + const int image_width = 4; + const int image_height = 3; + const int image_batch_count = 1; + AddInputFromArray( + TensorShape({image_batch_count, image_height, image_width, depth}), + {10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120}); + + const int filter_size = 3; + const int filter_count = 1; + AddInputFromArray( + TensorShape({filter_size, filter_size, depth, filter_count}), + {10, 40, 70, 20, 50, 80, 30, 60, 90}); + + // Image -> uint8 + AddInputFromArray(TensorShape({1}), {0.0f}); + AddInputFromArray(TensorShape({1}), {255.0f}); + + // Filter -> int8 with symmetric range + AddInputFromArray(TensorShape({1}), {-127.0f}); + AddInputFromArray(TensorShape({1}), {127.0f}); + + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + + TF_ASSERT_OK(RunOpKernel()); + + // Output -> int32 + const int expected_width = image_width; + const int expected_height = image_height * filter_count; + Tensor expected(DT_QINT32, TensorShape({image_batch_count, expected_height, + expected_width, filter_count})); + test::FillValues( + &expected, {10500, 15000, 18300, 9500, 23500, 31200, 35700, 17800, 18700, + 23400, 26100, 12100}); + + const Tensor& output = *GetOutput(0); + const Tensor& output_mkl_metadata = *GetOutput(3); + + ConvMklToTF conv_comp; + Tensor output_quantized; + conv_comp.ConvertMklToTF(DT_QINT32, output, output_mkl_metadata, + output_quantized); + + test::ExpectTensorEqual(expected, output_quantized); +} + +// Output -> int32 +TEST_F(QuantizedConv2DTest, Small32BitWithPadding) { + const int stride = 1; + TF_ASSERT_OK(NodeDefBuilder("quantized_conv_op", "_MklQuantizedConv2D") + .Input(FakeInput(DT_QUINT8)) // Input + .Input(FakeInput(DT_QINT8)) // Filter + .Input(FakeInput(DT_FLOAT)) // Min input + .Input(FakeInput(DT_FLOAT)) // Max input + .Input(FakeInput(DT_FLOAT)) // Min filter + .Input(FakeInput(DT_FLOAT)) // Max filter + // MKL metadata tensors // + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + /////////////////////////// + .Attr("Tinput", DataTypeToEnum::v()) + .Attr("Tfilter", DataTypeToEnum::v()) + .Attr("T", DataTypeToEnum::v()) + .Attr("out_type", DataTypeToEnum::v()) + .Attr("strides", {1, stride, stride, 1}) + .Attr("padding", "SAME") + .Attr("padding_list", {0, 0, 1, 1, 1, 1, 0, 0}) + .Attr("_kernel", "QuantizedMklOp") + .Finalize(node_def())); + TF_ASSERT_OK(InitOp()); + + const int depth = 1; + const int image_width = 4; + const int image_height = 3; + const int image_batch_count = 1; + AddInputFromArray( + TensorShape({image_batch_count, image_height, image_width, depth}), + {10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120}); + + const int filter_size = 3; + const int filter_count = 1; + AddInputFromArray( + TensorShape({filter_size, filter_size, depth, filter_count}), + {10, 40, 70, 20, 50, 80, 30, 60, 90}); + + // Image -> uint8 + AddInputFromArray(TensorShape({1}), {0.0f}); + AddInputFromArray(TensorShape({1}), {255.0f}); + + // Filter -> int8 with symmetric range + AddInputFromArray(TensorShape({1}), {-127.0f}); + AddInputFromArray(TensorShape({1}), {127.0f}); + + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + + TF_ASSERT_OK(RunOpKernel()); + + // Output -> int32 + const int expected_width = image_width; + const int expected_height = image_height * filter_count; + Tensor expected(DT_QINT32, TensorShape({image_batch_count, expected_height, + expected_width, filter_count})); + test::FillValues( + &expected, {10500, 15000, 18300, 9500, 23500, 31200, 35700, 17800, 18700, + 23400, 26100, 12100}); + + const Tensor& output = *GetOutput(0); + const Tensor& output_mkl_metadata = *GetOutput(3); + + ConvMklToTF conv_comp; + Tensor output_quantized; + conv_comp.ConvertMklToTF(DT_QINT32, output, output_mkl_metadata, + output_quantized); + + test::ExpectTensorEqual(expected, output_quantized); +} + +// Output -> int32 +TEST_F(QuantizedConv2DTest, OddPadding) { + const int stride = 2; + ConfigureQuantizedConv2D(stride); + + const int depth = 1; + const int image_width = 4; + const int image_height = 4; + const int image_batch_count = 1; + AddInputFromArray( + TensorShape({image_batch_count, image_height, image_width, depth}), + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + + const int filter_size = 3; + const int filter_count = 1; + AddInputFromArray( + TensorShape({filter_size, filter_size, depth, filter_count}), + {1, 2, 3, 4, 5, 6, 7, 8, 9}); + + // Image -> uint8 + AddInputFromArray(TensorShape({1}), {0.0f}); + AddInputFromArray(TensorShape({1}), {255.0f}); + + // Filter -> int8 with symmetric range + AddInputFromArray(TensorShape({1}), {-127.0f}); + AddInputFromArray(TensorShape({1}), {127.0f}); + + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + + TF_ASSERT_OK(RunOpKernel()); + + // Output -> int32 + const int expected_width = image_width / stride; + const int expected_height = (image_height * filter_count) / stride; + Tensor expected(DT_QINT32, TensorShape({image_batch_count, expected_height, + expected_width, filter_count})); + test::FillValues(&expected, {348, 252, 274, 175}); + + const Tensor& output = *GetOutput(0); + const Tensor& output_mkl_metadata = *GetOutput(3); + + ConvMklToTF conv_comp; + Tensor output_quantized; + conv_comp.ConvertMklToTF(DT_QINT32, output, output_mkl_metadata, + output_quantized); + + test::ExpectTensorEqual(expected, output_quantized); +} + +// Output -> int32 +TEST_F(QuantizedConv2DTest, OddPaddingBatch) { + const int stride = 2; + ConfigureQuantizedConv2D(stride); + + const int depth = 1; + const int image_width = 4; + const int image_height = 4; + const int image_batch_count = 3; + AddInputFromArray( + TensorShape({image_batch_count, image_height, image_width, depth}), + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + + const int filter_size = 3; + const int filter_count = 1; + AddInputFromArray( + TensorShape({filter_size, filter_size, depth, filter_count}), + {1, 2, 3, 4, 5, 6, 7, 8, 9}); + + // Image -> uint8 + AddInputFromArray(TensorShape({1}), {0.0f}); + AddInputFromArray(TensorShape({1}), {255.0f}); + + // Filter -> int8 with symmetric range + AddInputFromArray(TensorShape({1}), {-127.0f}); + AddInputFromArray(TensorShape({1}), {127.0f}); + + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + + TF_ASSERT_OK(RunOpKernel()); + + // Output -> int32 + const int expected_width = image_width / stride; + const int expected_height = (image_height * filter_count) / stride; + Tensor expected(DT_QINT32, TensorShape({image_batch_count, expected_height, + expected_width, filter_count})); + test::FillValues( + &expected, {348, 252, 274, 175, 348, 252, 274, 175, 348, 252, 274, 175}); + + const Tensor& output = *GetOutput(0); + const Tensor& output_mkl_metadata = *GetOutput(3); + + ConvMklToTF conv_comp; + Tensor output_quantized; + conv_comp.ConvertMklToTF(DT_QINT32, output, output_mkl_metadata, + output_quantized); + + test::ExpectTensorEqual(expected, output_quantized); +} + +} // namespace tensorflow +#endif // INTEL_MKL diff --git a/tensorflow/core/ops/mkl_nn_ops.cc b/tensorflow/core/ops/mkl_nn_ops.cc index b23c3735665..0e6ad9162a5 100644 --- a/tensorflow/core/ops/mkl_nn_ops.cc +++ b/tensorflow/core/ops/mkl_nn_ops.cc @@ -206,6 +206,7 @@ REGISTER_OP("_MklQuantizedConv2D") .Attr("is_filter_const: bool = true") .Attr(GetPaddingAttrString()) .Attr("dilations: list(int) = [1, 1, 1, 1]") + .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); ShapeHandle unused; @@ -251,6 +252,7 @@ REGISTER_OP("_MklQuantizedConv2DAndRequantize") .Attr("is_filter_const: bool = true") .Attr(GetPaddingAttrString()) .Attr("dilations: list(int) = [1, 1, 1, 1]") + .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); ShapeHandle unused; @@ -296,6 +298,7 @@ REGISTER_OP("_MklQuantizedConv2DWithBias") .Attr("is_filter_const: bool = true") .Attr(GetPaddingAttrString()) .Attr("dilations: list(int) = [1, 1, 1, 1]") + .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); ShapeHandle unused; @@ -345,6 +348,7 @@ REGISTER_OP("_MklQuantizedConv2DWithBiasAndRequantize") .Attr("is_filter_const: bool = true") .Attr(GetPaddingAttrString()) .Attr("dilations: list(int) = [1, 1, 1, 1]") + .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); ShapeHandle unused; @@ -387,6 +391,7 @@ REGISTER_OP("_MklQuantizedConv2DAndRelu") .Attr("is_filter_const: bool = true") .Attr(GetPaddingAttrString()) .Attr("dilations: list(int) = [1, 1, 1, 1]") + .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); ShapeHandle unused; @@ -432,6 +437,7 @@ REGISTER_OP("_MklQuantizedConv2DAndReluAndRequantize") .Attr("is_filter_const: bool = true") .Attr(GetPaddingAttrString()) .Attr("dilations: list(int) = [1, 1, 1, 1]") + .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); ShapeHandle unused; @@ -477,6 +483,7 @@ REGISTER_OP("_MklQuantizedConv2DWithBiasAndRelu") .Attr("is_filter_const: bool = true") .Attr(GetPaddingAttrString()) .Attr("dilations: list(int) = [1, 1, 1, 1]") + .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); ShapeHandle unused; @@ -526,6 +533,7 @@ REGISTER_OP("_MklQuantizedConv2DWithBiasAndReluAndRequantize") .Attr("is_filter_const: bool = true") .Attr(GetPaddingAttrString()) .Attr("dilations: list(int) = [1, 1, 1, 1]") + .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); ShapeHandle unused; @@ -574,6 +582,7 @@ REGISTER_OP("_MklQuantizedConv2DWithBiasSumAndRelu") .Attr("is_filter_const: bool = true") .Attr(GetPaddingAttrString()) .Attr("dilations: list(int) = [1, 1, 1, 1]") + .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); ShapeHandle unused; @@ -630,6 +639,7 @@ REGISTER_OP("_MklQuantizedConv2DWithBiasSumAndReluAndRequantize") .Attr("is_filter_const: bool = true") .Attr(GetPaddingAttrString()) .Attr("dilations: list(int) = [1, 1, 1, 1]") + .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); ShapeHandle unused; @@ -688,6 +698,7 @@ REGISTER_OP("_MklQuantizedConv2DWithBiasSignedSumAndReluAndRequantize") .Attr("is_filter_const: bool = true") .Attr(GetPaddingAttrString()) .Attr("dilations: list(int) = [1, 1, 1, 1]") + .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); ShapeHandle unused; diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index ef7a65c0113..7546e96e0f8 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -1213,9 +1213,9 @@ Status TopKShapeFn(InferenceContext* c) { DimensionHandle last_dim = c->Dim(input, -1); if (c->ValueKnown(last_dim) && c->ValueKnown(k_dim) && c->Value(last_dim) < c->Value(k_dim)) { - return errors::InvalidArgument( - "input must have last dimension >= k = ", c->Value(k_dim), " but is ", - c->Value(last_dim)); + return errors::InvalidArgument("input must have last dimension >= k = ", + c->Value(k_dim), " but is ", + c->Value(last_dim)); } // Replace last_dim with k_dim. @@ -1269,9 +1269,9 @@ REGISTER_OP("NthElement") DimensionHandle last_dim = c->Dim(input, -1); if (c->ValueKnown(last_dim) && c->ValueKnown(n_dim) && c->Value(last_dim) <= c->Value(n_dim)) { - return errors::InvalidArgument( - "Input must have last dimension > n = ", c->Value(n_dim), - " but is ", c->Value(last_dim)); + return errors::InvalidArgument("Input must have last dimension > n = ", + c->Value(n_dim), " but is ", + c->Value(last_dim)); } // Reduce last_dim for output tensor @@ -2560,6 +2560,7 @@ REGISTER_OP("QuantizedConv2DAndRequantize") .Attr("strides: list(int)") .Attr(GetPaddingAttrString()) .Attr("dilations: list(int) = [1, 1, 1, 1]") + .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); ShapeHandle unused; @@ -2592,6 +2593,7 @@ REGISTER_OP("QuantizedConv2DWithBias") .Attr("strides: list(int)") .Attr(GetPaddingAttrString()) .Attr("dilations: list(int) = [1, 1, 1, 1]") + .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); ShapeHandle unused; @@ -2625,6 +2627,7 @@ REGISTER_OP("QuantizedConv2DWithBiasAndRequantize") .Attr("strides: list(int)") .Attr(GetPaddingAttrString()) .Attr("dilations: list(int) = [1, 1, 1, 1]") + .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); ShapeHandle unused; @@ -2657,6 +2660,7 @@ REGISTER_OP("QuantizedConv2DAndRelu") .Attr("strides: list(int)") .Attr(GetPaddingAttrString()) .Attr("dilations: list(int) = [1, 1, 1, 1]") + .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); ShapeHandle unused; @@ -2687,6 +2691,7 @@ REGISTER_OP("QuantizedConv2DAndReluAndRequantize") .Attr("strides: list(int)") .Attr(GetPaddingAttrString()) .Attr("dilations: list(int) = [1, 1, 1, 1]") + .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); ShapeHandle unused; @@ -2719,6 +2724,7 @@ REGISTER_OP("QuantizedConv2DWithBiasAndRelu") .Attr("strides: list(int)") .Attr(GetPaddingAttrString()) .Attr("dilations: list(int) = [1, 1, 1, 1]") + .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); ShapeHandle unused; @@ -2753,6 +2759,7 @@ REGISTER_OP("QuantizedConv2DWithBiasAndReluAndRequantize") .Attr("strides: list(int)") .Attr(GetPaddingAttrString()) .Attr("dilations: list(int) = [1, 1, 1, 1]") + .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); ShapeHandle unused; @@ -2787,6 +2794,7 @@ REGISTER_OP("QuantizedConv2DWithBiasSumAndRelu") .Attr("strides: list(int)") .Attr(GetPaddingAttrString()) .Attr("dilations: list(int) = [1, 1, 1, 1]") + .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); ShapeHandle unused; @@ -2824,6 +2832,7 @@ REGISTER_OP("QuantizedConv2DWithBiasSumAndReluAndRequantize") .Attr("strides: list(int)") .Attr(GetPaddingAttrString()) .Attr("dilations: list(int) = [1, 1, 1, 1]") + .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); ShapeHandle unused; @@ -2863,6 +2872,7 @@ REGISTER_OP("QuantizedConv2DWithBiasSignedSumAndReluAndRequantize") .Attr("strides: list(int)") .Attr(GetPaddingAttrString()) .Attr("dilations: list(int) = [1, 1, 1, 1]") + .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); ShapeHandle unused; @@ -2878,5 +2888,4 @@ REGISTER_OP("QuantizedConv2DWithBiasSignedSumAndReluAndRequantize") return Status::OK(); }); - } // namespace tensorflow From 0a507dc55b51e4f5651b57473ff5e2b3ea50d95e Mon Sep 17 00:00:00 2001 From: Bhavani Subramanian Date: Fri, 15 Feb 2019 16:19:25 -0800 Subject: [PATCH 2/5] Addressed review comments. --- tensorflow/core/kernels/mkl_conv_ops.cc | 9 +++--- .../kernels/mkl_quantized_conv_ops_test.cc | 30 +++++++++++-------- 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc index aeb198c145b..51b37480a78 100644 --- a/tensorflow/core/kernels/mkl_conv_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_ops.cc @@ -466,7 +466,7 @@ class MklConvOp : public OpKernel { errors::InvalidArgument("filter must be 4-dimensional: ", filter.shape().DebugString())); - for (int i = 0; i < 3; i++) { + for (int i = 0; i < 3; ++i) { OP_REQUIRES(context, FastBoundsCheck(filter.dim_size(i), std::numeric_limits::max()), errors::InvalidArgument("filter too large")); @@ -939,9 +939,10 @@ class MklConvOp : public OpKernel { // `padding_list` attribute. Otherwise, we get it from one of the inputs. bool quantized_pad_enabled = false; for (auto const& padding_val : padding_list_) { - if (!padding_val) continue; - quantized_pad_enabled = true; - break; + if (padding_val) { + quantized_pad_enabled = true; + break; + } } if (fuse_pad_ || quantized_pad_enabled) { diff --git a/tensorflow/core/kernels/mkl_quantized_conv_ops_test.cc b/tensorflow/core/kernels/mkl_quantized_conv_ops_test.cc index 1d41b251930..2e599d3d9f8 100644 --- a/tensorflow/core/kernels/mkl_quantized_conv_ops_test.cc +++ b/tensorflow/core/kernels/mkl_quantized_conv_ops_test.cc @@ -180,7 +180,7 @@ TEST_F(QuantizedConv2DTest, Small) { // Output -> float const int expected_width = image_width; - const int expected_height = image_height * filter_count; + const int expected_height = image_height; Tensor expected_float( DT_FLOAT, TensorShape({image_batch_count, expected_height, expected_width, filter_count})); @@ -203,11 +203,13 @@ TEST_F(QuantizedConv2DTest, Small) { test::ExpectTensorNear(expected_float, output_float, 1.0); } -// Output -> int32 +// Output -> qint32 TEST_F(QuantizedConv2DTest, Small32Bit) { const int stride = 1; ConfigureQuantizedConv2D(stride); + // The illustrations and details regarding inputs and outputs + // are in TEST_F(QuantizedConv2DTest, Small) const int depth = 1; const int image_width = 4; const int image_height = 3; @@ -239,9 +241,9 @@ TEST_F(QuantizedConv2DTest, Small32Bit) { TF_ASSERT_OK(RunOpKernel()); - // Output -> int32 + // Output -> qint32 const int expected_width = image_width; - const int expected_height = image_height * filter_count; + const int expected_height = image_height; Tensor expected(DT_QINT32, TensorShape({image_batch_count, expected_height, expected_width, filter_count})); test::FillValues( @@ -259,7 +261,7 @@ TEST_F(QuantizedConv2DTest, Small32Bit) { test::ExpectTensorEqual(expected, output_quantized); } -// Output -> int32 +// Output -> qint32 TEST_F(QuantizedConv2DTest, Small32BitWithPadding) { const int stride = 1; TF_ASSERT_OK(NodeDefBuilder("quantized_conv_op", "_MklQuantizedConv2D") @@ -288,6 +290,8 @@ TEST_F(QuantizedConv2DTest, Small32BitWithPadding) { .Finalize(node_def())); TF_ASSERT_OK(InitOp()); + // The illustrations and details regarding inputs and outputs + // are in TEST_F(QuantizedConv2DTest, Small) const int depth = 1; const int image_width = 4; const int image_height = 3; @@ -319,9 +323,9 @@ TEST_F(QuantizedConv2DTest, Small32BitWithPadding) { TF_ASSERT_OK(RunOpKernel()); - // Output -> int32 + // Output -> qint32 const int expected_width = image_width; - const int expected_height = image_height * filter_count; + const int expected_height = image_height; Tensor expected(DT_QINT32, TensorShape({image_batch_count, expected_height, expected_width, filter_count})); test::FillValues( @@ -339,7 +343,7 @@ TEST_F(QuantizedConv2DTest, Small32BitWithPadding) { test::ExpectTensorEqual(expected, output_quantized); } -// Output -> int32 +// Output -> qint32 TEST_F(QuantizedConv2DTest, OddPadding) { const int stride = 2; ConfigureQuantizedConv2D(stride); @@ -375,9 +379,9 @@ TEST_F(QuantizedConv2DTest, OddPadding) { TF_ASSERT_OK(RunOpKernel()); - // Output -> int32 + // Output -> qint32 const int expected_width = image_width / stride; - const int expected_height = (image_height * filter_count) / stride; + const int expected_height = image_height / stride; Tensor expected(DT_QINT32, TensorShape({image_batch_count, expected_height, expected_width, filter_count})); test::FillValues(&expected, {348, 252, 274, 175}); @@ -393,7 +397,7 @@ TEST_F(QuantizedConv2DTest, OddPadding) { test::ExpectTensorEqual(expected, output_quantized); } -// Output -> int32 +// Output -> qint32 TEST_F(QuantizedConv2DTest, OddPaddingBatch) { const int stride = 2; ConfigureQuantizedConv2D(stride); @@ -431,9 +435,9 @@ TEST_F(QuantizedConv2DTest, OddPaddingBatch) { TF_ASSERT_OK(RunOpKernel()); - // Output -> int32 + // Output -> qint32 const int expected_width = image_width / stride; - const int expected_height = (image_height * filter_count) / stride; + const int expected_height = image_height / stride; Tensor expected(DT_QINT32, TensorShape({image_batch_count, expected_height, expected_width, filter_count})); test::FillValues( From b3c29b2638605a15bdb49ed6eab7ce36c16a51b2 Mon Sep 17 00:00:00 2001 From: Bhavani Subramanian Date: Tue, 19 Feb 2019 16:31:28 -0800 Subject: [PATCH 3/5] Changed type of 'padding_list_' to 'Tpadding'. Added additional registrations for QuantizedConv2D + bias + sum (and signed sum) + relu & requantize fusion for when Tbias is a float. --- tensorflow/core/kernels/mkl_conv_ops.cc | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc index 51b37480a78..5932237b944 100644 --- a/tensorflow/core/kernels/mkl_conv_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_ops.cc @@ -1275,7 +1275,7 @@ class MklConvOp : public OpKernel { private: std::vector strides_; std::vector dilations_; - std::vector padding_list_; + std::vector padding_list_; bool is_filter_const_; mutex mu_; Padding padding_; @@ -2056,6 +2056,25 @@ REGISTER_KERNEL_BUILDER( .Label(mkl_op_registry::kMklQuantizedOpLabel), MklQuantizedConv2DSumReluOp); +REGISTER_KERNEL_BUILDER( + Name("_MklQuantizedConv2DWithBiasSumAndReluAndRequantize") + .Device(DEVICE_CPU) + .TypeConstraint("Tinput") + .TypeConstraint("Tfilter") + .TypeConstraint("Tbias") + .TypeConstraint("out_type") + .Label(mkl_op_registry::kMklQuantizedOpLabel), + MklQuantizedConv2DSumReluOp); + +REGISTER_KERNEL_BUILDER( + Name("_MklQuantizedConv2DWithBiasSignedSumAndReluAndRequantize") + .Device(DEVICE_CPU) + .TypeConstraint("Tinput") + .TypeConstraint("Tfilter") + .TypeConstraint("Tbias") + .TypeConstraint("out_type") + .Label(mkl_op_registry::kMklQuantizedOpLabel), + MklQuantizedConv2DSumReluOp); #endif // INTEL_MKL_ML // Register 2D operations @@ -2090,7 +2109,7 @@ REGISTER_KERNEL_BUILDER( .TypeConstraint("Tpaddings") \ .Label(mkl_op_registry::kMklOpLabel), \ MklConvOp); \ + float, int64, false, true, false>); \ REGISTER_KERNEL_BUILDER(Name("__MklDummyPadWithConv2D") \ .Device(DEVICE_CPU) \ .TypeConstraint("T") \ @@ -2132,7 +2151,7 @@ TF_CALL_float(REGISTER_MKL_CPU_2D_DEPTHWISE); .TypeConstraint("T") \ .TypeConstraint("Tpaddings") \ .Label(mkl_op_registry::kMklOpLabel), \ - MklFusedConvOp); \ + MklFusedConvOp); \ REGISTER_KERNEL_BUILDER(Name("__MklDummyPadWithFusedConv2D") \ .Device(DEVICE_CPU) \ .TypeConstraint("T") \ From cf87abf85dc6f45ec91a858e197b3615fede665b Mon Sep 17 00:00:00 2001 From: Bhavani Subramanian Date: Tue, 19 Feb 2019 22:26:41 -0800 Subject: [PATCH 4/5] Updated the API golden files since 'padding_list' attribute was added to quantized Conv2D ops. --- .../api/golden/v2/tensorflow.raw_ops.pbtxt | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index 144d5644728..f340c9a43ae 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -2270,43 +2270,43 @@ tf_module { } member_method { name: "QuantizedConv2DAndRelu" - argspec: "args=[\'input\', \'filter\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'input\', \'filter\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None" } member_method { name: "QuantizedConv2DAndReluAndRequantize" - argspec: "args=[\'input\', \'filter\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'input\', \'filter\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None" } member_method { name: "QuantizedConv2DAndRequantize" - argspec: "args=[\'input\', \'filter\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'input\', \'filter\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None" } member_method { name: "QuantizedConv2DWithBias" - argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None" } member_method { name: "QuantizedConv2DWithBiasAndRelu" - argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None" } member_method { name: "QuantizedConv2DWithBiasAndReluAndRequantize" - argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None" } member_method { name: "QuantizedConv2DWithBiasAndRequantize" - argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None" } member_method { name: "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize" - argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'summand\', \'min_summand\', \'max_summand\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'summand\', \'min_summand\', \'max_summand\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None" } member_method { name: "QuantizedConv2DWithBiasSumAndRelu" - argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'summand\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'summand\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None" } member_method { name: "QuantizedConv2DWithBiasSumAndReluAndRequantize" - argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'summand\', \'min_summand\', \'max_summand\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'summand\', \'min_summand\', \'max_summand\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None" } member_method { name: "QuantizedInstanceNorm" @@ -2655,7 +2655,7 @@ tf_module { member_method { name: "Requantize" argspec: "args=[\'input\', \'input_min\', \'input_max\', \'requested_output_min\', \'requested_output_max\', \'out_type\'], varargs=None, keywords=None, defaults=None" - } + } member_method { name: "RequantizePerChannel" argspec: "args=[\'input\', \'input_min\', \'input_max\', \'requested_output_min\', \'requested_output_max\', \'out_type\'], varargs=None, keywords=None, defaults=None" From dcbbefa8c703c4091951adc6fec7bba18bae29cb Mon Sep 17 00:00:00 2001 From: Bhavani Subramanian Date: Tue, 19 Feb 2019 23:36:06 -0800 Subject: [PATCH 5/5] Added v1 golden file for 'api_compatibility_test' --- .../api/golden/v1/tensorflow.raw_ops.pbtxt | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index 58d3482e5fc..e0760e5bbfd 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -2278,43 +2278,43 @@ tf_module { } member_method { name: "QuantizedConv2DAndRelu" - argspec: "args=[\'input\', \'filter\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'input\', \'filter\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None" } member_method { name: "QuantizedConv2DAndReluAndRequantize" - argspec: "args=[\'input\', \'filter\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'input\', \'filter\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None" } member_method { name: "QuantizedConv2DAndRequantize" - argspec: "args=[\'input\', \'filter\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'input\', \'filter\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None" } member_method { name: "QuantizedConv2DWithBias" - argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None" } member_method { name: "QuantizedConv2DWithBiasAndRelu" - argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None" } member_method { name: "QuantizedConv2DWithBiasAndReluAndRequantize" - argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None" } member_method { name: "QuantizedConv2DWithBiasAndRequantize" - argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None" } member_method { name: "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize" - argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'summand\', \'min_summand\', \'max_summand\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'summand\', \'min_summand\', \'max_summand\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None" } member_method { name: "QuantizedConv2DWithBiasSumAndRelu" - argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'summand\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'summand\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None" } member_method { name: "QuantizedConv2DWithBiasSumAndReluAndRequantize" - argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'summand\', \'min_summand\', \'max_summand\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'summand\', \'min_summand\', \'max_summand\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None" } member_method { name: "QuantizedInstanceNorm"