diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc index c66dfe10f55..660cfc7960d 100644 --- a/tensorflow/core/graph/mkl_layout_pass.cc +++ b/tensorflow/core/graph/mkl_layout_pass.cc @@ -2388,7 +2388,8 @@ void MklLayoutRewritePass::CopyAttrsQuantizedConv2D(const Node* orig_node, DataType Tinput, Tfilter, out_type; string padding; string data_format("NHWC"); - std::vector strides, dilations; + std::vector strides, dilations, padding_list; + bool has_padding_list = HasNodeAttr(orig_node->def(), "padding_list"); // Get all attributes from old node. TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tinput", &Tinput)); @@ -2397,6 +2398,9 @@ void MklLayoutRewritePass::CopyAttrsQuantizedConv2D(const Node* orig_node, TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding)); TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides)); TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations)); + if (has_padding_list) { + TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding_list", &padding_list)); + } Node* filter_node = nullptr; orig_node->input_node(1, &filter_node); @@ -2411,6 +2415,9 @@ void MklLayoutRewritePass::CopyAttrsQuantizedConv2D(const Node* orig_node, nb->Attr("dilations", dilations); nb->Attr("T", out_type); // added "T" for facilitating MklToTf conversion. nb->Attr("data_format", data_format); + if (has_padding_list) { + nb->Attr("padding_list", padding_list); + } // Requantization attr Tbias. DataType Tbias; diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 3b3ce6cd22c..59bec44401d 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -6327,6 +6327,29 @@ tf_cc_test( ], ) +tf_cc_test_mkl( + name = "mkl_quantized_conv_ops_test", + size = "small", + srcs = ["mkl_quantized_conv_ops_test.cc"], + tags = ["nomsan"], # http://b/32242946 + deps = [ + ":mkl_conv_op", + ":mkl_input_conversion_op", + ":ops_testutil", + ":ops_util", + ":quantization_utils", + ":quantized_ops", + "//tensorflow/core:array_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:math_ops_op_lib", + "//tensorflow/core:nn_ops_op_lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + tf_cc_test( name = "quantize_op_test", size = "small", diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc index 0134cc22356..da999d28b1f 100644 --- a/tensorflow/core/kernels/mkl_conv_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_ops.cc @@ -466,7 +466,7 @@ class MklConvOp : public OpKernel { errors::InvalidArgument("filter must be 4-dimensional: ", filter.shape().DebugString())); - for (int i = 0; i < 3; i++) { + for (int i = 0; i < 3; ++i) { OP_REQUIRES( context, FastBoundsCheck(filter.dim_size(i), std::numeric_limits::max()), @@ -860,6 +860,9 @@ class MklConvOp : public OpKernel { explicit MklConvOp(OpKernelConstruction* context) : OpKernel(context) { OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_)); + if (context->HasAttr("padding_list")) { + OP_REQUIRES_OK(context, context->GetAttr("padding_list", &padding_list_)); + } OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_)); string data_format; OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format)); @@ -938,9 +941,19 @@ class MklConvOp : public OpKernel { dilations, strides; memory::dims dst_dims_tf_order, dst_dims_mkl_order; - // If pad with conv2d fusion is enabled - if (fuse_pad_) { - PadWithConvFusion(context, padding_left, padding_right); + // For Quantized-Conv2D and Pad fusion, we get padding from the + // `padding_list` attribute. Otherwise, we get it from one of the inputs. + bool quantized_pad_enabled = false; + for (auto const& padding_val : padding_list_) { + if (padding_val) { + quantized_pad_enabled = true; + break; + } + } + + if (fuse_pad_ || quantized_pad_enabled) { + PadWithConvFusion(context, padding_left, padding_right, + quantized_pad_enabled); } // Get shapes of input tensors in MKL-DNN order @@ -951,7 +964,8 @@ class MklConvOp : public OpKernel { conv_utl.GetConvFwdSizesInMklOrder( src_tf_shape, filter_tf_shape, &src_dims, &filter_dims, &strides, &dilations, &dst_dims_tf_order, &dst_dims_mkl_order, &padding_left, - &padding_right, fuse_pad_, is_depthwise); + &padding_right, (fuse_pad_ || quantized_pad_enabled), is_depthwise); + if (!context->status().ok()) return; // Check for corner case - if there is nothing to compute, return. @@ -1151,16 +1165,20 @@ class MklConvOp : public OpKernel { } void PadWithConvFusion(OpKernelContext* context, memory::dims& padding_left, - memory::dims& padding_right) { + memory::dims& padding_right, + bool quantized_pad_enabled) { const Tensor& paddings_tf = MklGetInput(context, input_index_pad_); - OP_REQUIRES(context, paddings_tf.dims() == 2, - errors::InvalidArgument("paddings must be 2-dimensional: ", - paddings_tf.shape().DebugString())); - - // Flatten tensor to get individual paddings. - Tpadding* paddings = static_cast( - const_cast(paddings_tf.flat().data())); - + Tpadding* paddings = nullptr; + if (quantized_pad_enabled) { + paddings = padding_list_.data(); + } else { + OP_REQUIRES(context, paddings_tf.dims() == 2, + errors::InvalidArgument("paddings must be 2-dimensional: ", + paddings_tf.shape().DebugString())); + // Flatten tensor to get individual paddings. + paddings = static_cast( + const_cast(paddings_tf.flat().data())); + } // If the data format is NHWC, indices 0, 1, 6 and 7 of paddings(_tf) // will be zero. // Example: @@ -1186,8 +1204,7 @@ class MklConvOp : public OpKernel { pad_left = paddings[6]; pad_right = paddings[7]; } - - // Create padding arrays for MKL DNN convolutions. + // Create padding arrays for MKL-DNN convolutions. // MKL-DNN uses asymetric padding. padding_left = {static_cast(pad_top), static_cast(pad_left)}; padding_right = {static_cast(pad_bottom), static_cast(pad_right)}; @@ -1264,6 +1281,7 @@ class MklConvOp : public OpKernel { private: std::vector strides_; std::vector dilations_; + std::vector padding_list_; bool is_filter_const_; mutex mu_; Padding padding_; @@ -1825,7 +1843,7 @@ class MklQuantizedConv2DSumReluOp }; // INT8 kernel registration -// Register NoOp kernel for QunatizedConv2D for qint8 filter +// Register NoOp kernel for QuantizedConv2D for qint8 filter REGISTER_KERNEL_BUILDER(Name("QuantizedConv2D") .Device(DEVICE_CPU) .TypeConstraint("Tinput") @@ -1840,7 +1858,7 @@ REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DAndRequantize") .TypeConstraint("out_type"), NoOp); -// Register a templatized implementation of MklQuntizedConv2D. +// Register a templatized implementation of MklQuantizedConv2D. REGISTER_KERNEL_BUILDER( Name("_MklQuantizedConv2D") .Device(DEVICE_CPU) @@ -2029,17 +2047,40 @@ REGISTER_KERNEL_BUILDER( .Device(DEVICE_CPU) .TypeConstraint("Tinput") .TypeConstraint("Tfilter") + .TypeConstraint("Tbias") .TypeConstraint("out_type") .Label(mkl_op_registry::kMklQuantizedOpLabel), MklQuantizedConv2DSumReluOp); + REGISTER_KERNEL_BUILDER( Name("_MklQuantizedConv2DWithBiasSignedSumAndReluAndRequantize") .Device(DEVICE_CPU) .TypeConstraint("Tinput") .TypeConstraint("Tfilter") + .TypeConstraint("Tbias") .TypeConstraint("out_type") .Label(mkl_op_registry::kMklQuantizedOpLabel), MklQuantizedConv2DSumReluOp); + +REGISTER_KERNEL_BUILDER( + Name("_MklQuantizedConv2DWithBiasSumAndReluAndRequantize") + .Device(DEVICE_CPU) + .TypeConstraint("Tinput") + .TypeConstraint("Tfilter") + .TypeConstraint("Tbias") + .TypeConstraint("out_type") + .Label(mkl_op_registry::kMklQuantizedOpLabel), + MklQuantizedConv2DSumReluOp); + +REGISTER_KERNEL_BUILDER( + Name("_MklQuantizedConv2DWithBiasSignedSumAndReluAndRequantize") + .Device(DEVICE_CPU) + .TypeConstraint("Tinput") + .TypeConstraint("Tfilter") + .TypeConstraint("Tbias") + .TypeConstraint("out_type") + .Label(mkl_op_registry::kMklQuantizedOpLabel), + MklQuantizedConv2DSumReluOp); #endif // INTEL_MKL_ML // Register 2D operations diff --git a/tensorflow/core/kernels/mkl_quantized_conv_ops_test.cc b/tensorflow/core/kernels/mkl_quantized_conv_ops_test.cc new file mode 100644 index 00000000000..2e599d3d9f8 --- /dev/null +++ b/tensorflow/core/kernels/mkl_quantized_conv_ops_test.cc @@ -0,0 +1,458 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifdef INTEL_MKL +#define EIGEN_USE_THREADS + +#include +#include +#include + +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/kernels/quantization_utils.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +// TODO(bhavanis): Move ConvMklToTF to mkl_test_util.h as it is used by +// most unit tests. + +// Helper class for converting MKL tensors to TF tensors and comparing to +// expected values + +static const uint8 dummy_tensor[] = {0, 0, 0, 0, 0, 0, 0, 0}; +static const TensorShape dummy_shape({8}); + +class ConvMklToTF : public OpsTestBase { + public: + template + void ConvertMklToTF(DataType dtype, const Tensor& input, + const Tensor& input_metadata_tensor, Tensor& output) { + // Create an MKL to TF conversion node and execute it + TF_EXPECT_OK(NodeDefBuilder("mkl_to_tf_op", "_MklToTf") + .Input(FakeInput(dtype)) // Input + .Input(FakeInput(DT_UINT8)) // MKL metadata tensor + .Attr("T", dtype) + .Attr("_kernel", "MklOp") + .Finalize(node_def())); + TF_EXPECT_OK(InitOp()); + AddInputFromArray(input.shape(), input.flat()); + AddInputFromArray(input_metadata_tensor.shape(), + input_metadata_tensor.flat()); + TF_ASSERT_OK(RunOpKernel()); + + output = *GetOutput(0); + } + void TestBody() {} +}; + +class QuantizedConv2DTest : public OpsTestBase { + protected: + void ConfigureQuantizedConv2D(const int& stride = 1) { + TF_ASSERT_OK(NodeDefBuilder("quantized_conv_op", "_MklQuantizedConv2D") + .Input(FakeInput(DT_QUINT8)) // Input + .Input(FakeInput(DT_QINT8)) // Filter + .Input(FakeInput(DT_FLOAT)) // Min input + .Input(FakeInput(DT_FLOAT)) // Max input + .Input(FakeInput(DT_FLOAT)) // Min filter + .Input(FakeInput(DT_FLOAT)) // Max filter + // MKL metadata tensors // + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + /////////////////////////// + .Attr("Tinput", DataTypeToEnum::v()) + .Attr("Tfilter", DataTypeToEnum::v()) + .Attr("T", DataTypeToEnum::v()) + .Attr("out_type", DataTypeToEnum::v()) + .Attr("strides", {1, stride, stride, 1}) + .Attr("padding", "SAME") + .Attr("_kernel", "QuantizedMklOp") + .Finalize(node_def())); + TF_ASSERT_OK(InitOp()); + } +}; + +// Output -> float +TEST_F(QuantizedConv2DTest, Small) { + const int stride = 1; + ConfigureQuantizedConv2D(stride); + + const int depth = 1; + const int image_width = 4; + const int image_height = 3; + const int image_batch_count = 1; + + // Image -> uint8 + const float image_min = 0.0f; + const float image_max = 255.0f; + + // The image matrix is: + // | 1 | 2 | 3 | 4 | + // | 5 | 6 | 7 | 8 | + // | 9 | 10 | 11 | 12 | + Tensor image_float(DT_FLOAT, + {image_batch_count, image_height, image_width, depth}); + test::FillValues(&image_float, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + Tensor image_quantized = + FloatTensorToQuantized(image_float, image_min, image_max); + + const int filter_size = 3; + const int filter_count = 1; + + // Filter -> int8 with symmetric range + const float filter_min = -127.0f; + const float filter_max = 127.0f; + + // The filter matrix is: + // | 1 | 4 | 7 | + // | 2 | 5 | 8 | + // | 3 | 6 | 9 | + Tensor filter_float(DT_FLOAT, + {filter_size, filter_size, depth, filter_count}); + test::FillValues(&filter_float, {1, 4, 7, 2, 5, 8, 3, 6, 9}); + Tensor filter_quantized = + FloatTensorToQuantized(filter_float, filter_min, filter_max); + + AddInputFromArray(image_quantized.shape(), + image_quantized.flat()); + AddInputFromArray(filter_quantized.shape(), + filter_quantized.flat()); + AddInputFromArray(TensorShape({1}), {image_min}); + AddInputFromArray(TensorShape({1}), {image_max}); + AddInputFromArray(TensorShape({1}), {filter_min}); + AddInputFromArray(TensorShape({1}), {filter_max}); + + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + + TF_ASSERT_OK(RunOpKernel()); + + // We're sliding the 3x3 filter across the 3x4 image, with accesses outside + // the input set to zero because we're using the 'SAME' padding mode. + // The calculations behind the expected output are: + // (1*0)+(4*0)+(7*0)+(2*0)+(5*1)+(8*2)+(3*0)+(6*5)+(9*6)=105 + // (1*0)+(4*0)+(7*0)+(2*1)+(5*2)+(8*3)+(3*5)+(6*6)+(9*7)=150 + // (1*0)+(4*0)+(7*0)+(2*2)+(5*3)+(8*4)+(3*6)+(6*7)+(9*8)=183 + // (1*0)+(4*0)+(7*0)+(2*3)+(5*4)+(8*0)+(3*7)+(6*8)+(9*0)=95 + // (1*0)+(4*1)+(7*2)+(2*0)+(5*5)+(8*6)+(3*0)+(6*9)+(9*10)=235 + // (1*1)+(4*2)+(7*3)+(2*5)+(5*6)+(8*7)+(3*9)+(6*10)+(9*11)=312 + // (1*2)+(4*3)+(7*4)+(2*6)+(5*7)+(8*8)+(3*10)+(6*11)+(9*12)=357 + // (1*3)+(4*4)+(7*0)+(2*7)+(5*8)+(8*0)+(3*11)+(6*12)+(9*0)=178 + // (1*0)+(4*5)+(7*6)+(2*0)+(5*9)+(8*10)+(3*0)+(6*0)+(9*0)=187 + // (1*5)+(4*6)+(7*7)+(2*9)+(5*10)+(8*11)+(3*0)+(6*0)+(9*0)=234 + // (1*6)+(4*7)+(7*8)+(2*10)+(5*11)+(8*12)+(3*0)+(6*0)+(9*0)=261 + // (1*7)+(4*8)+(7*0)+(2*11)+(5*12)+(8*0)+(3*0)+(6*0)+(9*0)=121 + // This means we should end up with this matrix: + // | 105 | 150 | 183 | 95 | + // | 235 | 312 | 357 | 178 | + // | 187 | 234 | 261 | 121 | + + // Output -> float + const int expected_width = image_width; + const int expected_height = image_height; + Tensor expected_float( + DT_FLOAT, TensorShape({image_batch_count, expected_height, expected_width, + filter_count})); + test::FillValues(&expected_float, {105, 150, 183, 95, 235, 312, 357, + 178, 187, 234, 261, 121}); + + const Tensor& output = *GetOutput(0); + const Tensor& output_mkl_metadata = *GetOutput(3); + + ConvMklToTF conv_comp; + Tensor output_quantized; + conv_comp.ConvertMklToTF(DT_QINT32, output, output_mkl_metadata, + output_quantized); + + const float output_min = GetOutput(1)->flat()(0); + const float output_max = GetOutput(2)->flat()(0); + Tensor output_float = + QuantizedTensorToFloat(output_quantized, output_min, output_max); + + test::ExpectTensorNear(expected_float, output_float, 1.0); +} + +// Output -> qint32 +TEST_F(QuantizedConv2DTest, Small32Bit) { + const int stride = 1; + ConfigureQuantizedConv2D(stride); + + // The illustrations and details regarding inputs and outputs + // are in TEST_F(QuantizedConv2DTest, Small) + const int depth = 1; + const int image_width = 4; + const int image_height = 3; + const int image_batch_count = 1; + AddInputFromArray( + TensorShape({image_batch_count, image_height, image_width, depth}), + {10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120}); + + const int filter_size = 3; + const int filter_count = 1; + AddInputFromArray( + TensorShape({filter_size, filter_size, depth, filter_count}), + {10, 40, 70, 20, 50, 80, 30, 60, 90}); + + // Image -> uint8 + AddInputFromArray(TensorShape({1}), {0.0f}); + AddInputFromArray(TensorShape({1}), {255.0f}); + + // Filter -> int8 with symmetric range + AddInputFromArray(TensorShape({1}), {-127.0f}); + AddInputFromArray(TensorShape({1}), {127.0f}); + + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + + TF_ASSERT_OK(RunOpKernel()); + + // Output -> qint32 + const int expected_width = image_width; + const int expected_height = image_height; + Tensor expected(DT_QINT32, TensorShape({image_batch_count, expected_height, + expected_width, filter_count})); + test::FillValues( + &expected, {10500, 15000, 18300, 9500, 23500, 31200, 35700, 17800, 18700, + 23400, 26100, 12100}); + + const Tensor& output = *GetOutput(0); + const Tensor& output_mkl_metadata = *GetOutput(3); + + ConvMklToTF conv_comp; + Tensor output_quantized; + conv_comp.ConvertMklToTF(DT_QINT32, output, output_mkl_metadata, + output_quantized); + + test::ExpectTensorEqual(expected, output_quantized); +} + +// Output -> qint32 +TEST_F(QuantizedConv2DTest, Small32BitWithPadding) { + const int stride = 1; + TF_ASSERT_OK(NodeDefBuilder("quantized_conv_op", "_MklQuantizedConv2D") + .Input(FakeInput(DT_QUINT8)) // Input + .Input(FakeInput(DT_QINT8)) // Filter + .Input(FakeInput(DT_FLOAT)) // Min input + .Input(FakeInput(DT_FLOAT)) // Max input + .Input(FakeInput(DT_FLOAT)) // Min filter + .Input(FakeInput(DT_FLOAT)) // Max filter + // MKL metadata tensors // + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + /////////////////////////// + .Attr("Tinput", DataTypeToEnum::v()) + .Attr("Tfilter", DataTypeToEnum::v()) + .Attr("T", DataTypeToEnum::v()) + .Attr("out_type", DataTypeToEnum::v()) + .Attr("strides", {1, stride, stride, 1}) + .Attr("padding", "SAME") + .Attr("padding_list", {0, 0, 1, 1, 1, 1, 0, 0}) + .Attr("_kernel", "QuantizedMklOp") + .Finalize(node_def())); + TF_ASSERT_OK(InitOp()); + + // The illustrations and details regarding inputs and outputs + // are in TEST_F(QuantizedConv2DTest, Small) + const int depth = 1; + const int image_width = 4; + const int image_height = 3; + const int image_batch_count = 1; + AddInputFromArray( + TensorShape({image_batch_count, image_height, image_width, depth}), + {10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120}); + + const int filter_size = 3; + const int filter_count = 1; + AddInputFromArray( + TensorShape({filter_size, filter_size, depth, filter_count}), + {10, 40, 70, 20, 50, 80, 30, 60, 90}); + + // Image -> uint8 + AddInputFromArray(TensorShape({1}), {0.0f}); + AddInputFromArray(TensorShape({1}), {255.0f}); + + // Filter -> int8 with symmetric range + AddInputFromArray(TensorShape({1}), {-127.0f}); + AddInputFromArray(TensorShape({1}), {127.0f}); + + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + + TF_ASSERT_OK(RunOpKernel()); + + // Output -> qint32 + const int expected_width = image_width; + const int expected_height = image_height; + Tensor expected(DT_QINT32, TensorShape({image_batch_count, expected_height, + expected_width, filter_count})); + test::FillValues( + &expected, {10500, 15000, 18300, 9500, 23500, 31200, 35700, 17800, 18700, + 23400, 26100, 12100}); + + const Tensor& output = *GetOutput(0); + const Tensor& output_mkl_metadata = *GetOutput(3); + + ConvMklToTF conv_comp; + Tensor output_quantized; + conv_comp.ConvertMklToTF(DT_QINT32, output, output_mkl_metadata, + output_quantized); + + test::ExpectTensorEqual(expected, output_quantized); +} + +// Output -> qint32 +TEST_F(QuantizedConv2DTest, OddPadding) { + const int stride = 2; + ConfigureQuantizedConv2D(stride); + + const int depth = 1; + const int image_width = 4; + const int image_height = 4; + const int image_batch_count = 1; + AddInputFromArray( + TensorShape({image_batch_count, image_height, image_width, depth}), + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + + const int filter_size = 3; + const int filter_count = 1; + AddInputFromArray( + TensorShape({filter_size, filter_size, depth, filter_count}), + {1, 2, 3, 4, 5, 6, 7, 8, 9}); + + // Image -> uint8 + AddInputFromArray(TensorShape({1}), {0.0f}); + AddInputFromArray(TensorShape({1}), {255.0f}); + + // Filter -> int8 with symmetric range + AddInputFromArray(TensorShape({1}), {-127.0f}); + AddInputFromArray(TensorShape({1}), {127.0f}); + + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + + TF_ASSERT_OK(RunOpKernel()); + + // Output -> qint32 + const int expected_width = image_width / stride; + const int expected_height = image_height / stride; + Tensor expected(DT_QINT32, TensorShape({image_batch_count, expected_height, + expected_width, filter_count})); + test::FillValues(&expected, {348, 252, 274, 175}); + + const Tensor& output = *GetOutput(0); + const Tensor& output_mkl_metadata = *GetOutput(3); + + ConvMklToTF conv_comp; + Tensor output_quantized; + conv_comp.ConvertMklToTF(DT_QINT32, output, output_mkl_metadata, + output_quantized); + + test::ExpectTensorEqual(expected, output_quantized); +} + +// Output -> qint32 +TEST_F(QuantizedConv2DTest, OddPaddingBatch) { + const int stride = 2; + ConfigureQuantizedConv2D(stride); + + const int depth = 1; + const int image_width = 4; + const int image_height = 4; + const int image_batch_count = 3; + AddInputFromArray( + TensorShape({image_batch_count, image_height, image_width, depth}), + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + + const int filter_size = 3; + const int filter_count = 1; + AddInputFromArray( + TensorShape({filter_size, filter_size, depth, filter_count}), + {1, 2, 3, 4, 5, 6, 7, 8, 9}); + + // Image -> uint8 + AddInputFromArray(TensorShape({1}), {0.0f}); + AddInputFromArray(TensorShape({1}), {255.0f}); + + // Filter -> int8 with symmetric range + AddInputFromArray(TensorShape({1}), {-127.0f}); + AddInputFromArray(TensorShape({1}), {127.0f}); + + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + + TF_ASSERT_OK(RunOpKernel()); + + // Output -> qint32 + const int expected_width = image_width / stride; + const int expected_height = image_height / stride; + Tensor expected(DT_QINT32, TensorShape({image_batch_count, expected_height, + expected_width, filter_count})); + test::FillValues( + &expected, {348, 252, 274, 175, 348, 252, 274, 175, 348, 252, 274, 175}); + + const Tensor& output = *GetOutput(0); + const Tensor& output_mkl_metadata = *GetOutput(3); + + ConvMklToTF conv_comp; + Tensor output_quantized; + conv_comp.ConvertMklToTF(DT_QINT32, output, output_mkl_metadata, + output_quantized); + + test::ExpectTensorEqual(expected, output_quantized); +} + +} // namespace tensorflow +#endif // INTEL_MKL diff --git a/tensorflow/core/ops/mkl_nn_ops.cc b/tensorflow/core/ops/mkl_nn_ops.cc index b23c3735665..0e6ad9162a5 100644 --- a/tensorflow/core/ops/mkl_nn_ops.cc +++ b/tensorflow/core/ops/mkl_nn_ops.cc @@ -206,6 +206,7 @@ REGISTER_OP("_MklQuantizedConv2D") .Attr("is_filter_const: bool = true") .Attr(GetPaddingAttrString()) .Attr("dilations: list(int) = [1, 1, 1, 1]") + .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); ShapeHandle unused; @@ -251,6 +252,7 @@ REGISTER_OP("_MklQuantizedConv2DAndRequantize") .Attr("is_filter_const: bool = true") .Attr(GetPaddingAttrString()) .Attr("dilations: list(int) = [1, 1, 1, 1]") + .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); ShapeHandle unused; @@ -296,6 +298,7 @@ REGISTER_OP("_MklQuantizedConv2DWithBias") .Attr("is_filter_const: bool = true") .Attr(GetPaddingAttrString()) .Attr("dilations: list(int) = [1, 1, 1, 1]") + .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); ShapeHandle unused; @@ -345,6 +348,7 @@ REGISTER_OP("_MklQuantizedConv2DWithBiasAndRequantize") .Attr("is_filter_const: bool = true") .Attr(GetPaddingAttrString()) .Attr("dilations: list(int) = [1, 1, 1, 1]") + .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); ShapeHandle unused; @@ -387,6 +391,7 @@ REGISTER_OP("_MklQuantizedConv2DAndRelu") .Attr("is_filter_const: bool = true") .Attr(GetPaddingAttrString()) .Attr("dilations: list(int) = [1, 1, 1, 1]") + .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); ShapeHandle unused; @@ -432,6 +437,7 @@ REGISTER_OP("_MklQuantizedConv2DAndReluAndRequantize") .Attr("is_filter_const: bool = true") .Attr(GetPaddingAttrString()) .Attr("dilations: list(int) = [1, 1, 1, 1]") + .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); ShapeHandle unused; @@ -477,6 +483,7 @@ REGISTER_OP("_MklQuantizedConv2DWithBiasAndRelu") .Attr("is_filter_const: bool = true") .Attr(GetPaddingAttrString()) .Attr("dilations: list(int) = [1, 1, 1, 1]") + .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); ShapeHandle unused; @@ -526,6 +533,7 @@ REGISTER_OP("_MklQuantizedConv2DWithBiasAndReluAndRequantize") .Attr("is_filter_const: bool = true") .Attr(GetPaddingAttrString()) .Attr("dilations: list(int) = [1, 1, 1, 1]") + .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); ShapeHandle unused; @@ -574,6 +582,7 @@ REGISTER_OP("_MklQuantizedConv2DWithBiasSumAndRelu") .Attr("is_filter_const: bool = true") .Attr(GetPaddingAttrString()) .Attr("dilations: list(int) = [1, 1, 1, 1]") + .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); ShapeHandle unused; @@ -630,6 +639,7 @@ REGISTER_OP("_MklQuantizedConv2DWithBiasSumAndReluAndRequantize") .Attr("is_filter_const: bool = true") .Attr(GetPaddingAttrString()) .Attr("dilations: list(int) = [1, 1, 1, 1]") + .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); ShapeHandle unused; @@ -688,6 +698,7 @@ REGISTER_OP("_MklQuantizedConv2DWithBiasSignedSumAndReluAndRequantize") .Attr("is_filter_const: bool = true") .Attr(GetPaddingAttrString()) .Attr("dilations: list(int) = [1, 1, 1, 1]") + .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); ShapeHandle unused; diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index ef7a65c0113..f290fb44565 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -2560,6 +2560,7 @@ REGISTER_OP("QuantizedConv2DAndRequantize") .Attr("strides: list(int)") .Attr(GetPaddingAttrString()) .Attr("dilations: list(int) = [1, 1, 1, 1]") + .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); ShapeHandle unused; @@ -2592,6 +2593,7 @@ REGISTER_OP("QuantizedConv2DWithBias") .Attr("strides: list(int)") .Attr(GetPaddingAttrString()) .Attr("dilations: list(int) = [1, 1, 1, 1]") + .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); ShapeHandle unused; @@ -2625,6 +2627,7 @@ REGISTER_OP("QuantizedConv2DWithBiasAndRequantize") .Attr("strides: list(int)") .Attr(GetPaddingAttrString()) .Attr("dilations: list(int) = [1, 1, 1, 1]") + .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); ShapeHandle unused; @@ -2657,6 +2660,7 @@ REGISTER_OP("QuantizedConv2DAndRelu") .Attr("strides: list(int)") .Attr(GetPaddingAttrString()) .Attr("dilations: list(int) = [1, 1, 1, 1]") + .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); ShapeHandle unused; @@ -2687,6 +2691,7 @@ REGISTER_OP("QuantizedConv2DAndReluAndRequantize") .Attr("strides: list(int)") .Attr(GetPaddingAttrString()) .Attr("dilations: list(int) = [1, 1, 1, 1]") + .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); ShapeHandle unused; @@ -2719,6 +2724,7 @@ REGISTER_OP("QuantizedConv2DWithBiasAndRelu") .Attr("strides: list(int)") .Attr(GetPaddingAttrString()) .Attr("dilations: list(int) = [1, 1, 1, 1]") + .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); ShapeHandle unused; @@ -2753,6 +2759,7 @@ REGISTER_OP("QuantizedConv2DWithBiasAndReluAndRequantize") .Attr("strides: list(int)") .Attr(GetPaddingAttrString()) .Attr("dilations: list(int) = [1, 1, 1, 1]") + .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); ShapeHandle unused; @@ -2787,6 +2794,7 @@ REGISTER_OP("QuantizedConv2DWithBiasSumAndRelu") .Attr("strides: list(int)") .Attr(GetPaddingAttrString()) .Attr("dilations: list(int) = [1, 1, 1, 1]") + .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); ShapeHandle unused; @@ -2824,6 +2832,7 @@ REGISTER_OP("QuantizedConv2DWithBiasSumAndReluAndRequantize") .Attr("strides: list(int)") .Attr(GetPaddingAttrString()) .Attr("dilations: list(int) = [1, 1, 1, 1]") + .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); ShapeHandle unused; @@ -2863,6 +2872,7 @@ REGISTER_OP("QuantizedConv2DWithBiasSignedSumAndReluAndRequantize") .Attr("strides: list(int)") .Attr(GetPaddingAttrString()) .Attr("dilations: list(int) = [1, 1, 1, 1]") + .Attr("padding_list: list(int) = []") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); ShapeHandle unused; @@ -2878,5 +2888,4 @@ REGISTER_OP("QuantizedConv2DWithBiasSignedSumAndReluAndRequantize") return Status::OK(); }); - } // namespace tensorflow diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index 36552c8ccdf..99c226ba9b4 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -2410,43 +2410,43 @@ tf_module { } member_method { name: "QuantizedConv2DAndRelu" - argspec: "args=[\'input\', \'filter\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'input\', \'filter\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None" } member_method { name: "QuantizedConv2DAndReluAndRequantize" - argspec: "args=[\'input\', \'filter\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'input\', \'filter\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None" } member_method { name: "QuantizedConv2DAndRequantize" - argspec: "args=[\'input\', \'filter\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'input\', \'filter\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None" } member_method { name: "QuantizedConv2DWithBias" - argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None" } member_method { name: "QuantizedConv2DWithBiasAndRelu" - argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None" } member_method { name: "QuantizedConv2DWithBiasAndReluAndRequantize" - argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None" } member_method { name: "QuantizedConv2DWithBiasAndRequantize" - argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None" } member_method { name: "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize" - argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'summand\', \'min_summand\', \'max_summand\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'summand\', \'min_summand\', \'max_summand\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None" } member_method { name: "QuantizedConv2DWithBiasSumAndRelu" - argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'summand\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'summand\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None" } member_method { name: "QuantizedConv2DWithBiasSumAndReluAndRequantize" - argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'summand\', \'min_summand\', \'max_summand\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'summand\', \'min_summand\', \'max_summand\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None" } member_method { name: "QuantizedInstanceNorm" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index 36552c8ccdf..99c226ba9b4 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -2410,43 +2410,43 @@ tf_module { } member_method { name: "QuantizedConv2DAndRelu" - argspec: "args=[\'input\', \'filter\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'input\', \'filter\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None" } member_method { name: "QuantizedConv2DAndReluAndRequantize" - argspec: "args=[\'input\', \'filter\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'input\', \'filter\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None" } member_method { name: "QuantizedConv2DAndRequantize" - argspec: "args=[\'input\', \'filter\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'input\', \'filter\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None" } member_method { name: "QuantizedConv2DWithBias" - argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None" } member_method { name: "QuantizedConv2DWithBiasAndRelu" - argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None" } member_method { name: "QuantizedConv2DWithBiasAndReluAndRequantize" - argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None" } member_method { name: "QuantizedConv2DWithBiasAndRequantize" - argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None" } member_method { name: "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize" - argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'summand\', \'min_summand\', \'max_summand\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'summand\', \'min_summand\', \'max_summand\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None" } member_method { name: "QuantizedConv2DWithBiasSumAndRelu" - argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'summand\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'summand\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None" } member_method { name: "QuantizedConv2DWithBiasSumAndReluAndRequantize" - argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'summand\', \'min_summand\', \'max_summand\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None" + argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'summand\', \'min_summand\', \'max_summand\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None" } member_method { name: "QuantizedInstanceNorm"