Merge pull request #25765 from Intel-tensorflow:bhavanis/quantized-pad-fusion-latest
PiperOrigin-RevId: 235995048
This commit is contained in:
commit
9a19de7a0f
@ -2388,7 +2388,8 @@ void MklLayoutRewritePass::CopyAttrsQuantizedConv2D(const Node* orig_node,
|
||||
DataType Tinput, Tfilter, out_type;
|
||||
string padding;
|
||||
string data_format("NHWC");
|
||||
std::vector<int32> strides, dilations;
|
||||
std::vector<int32> strides, dilations, padding_list;
|
||||
bool has_padding_list = HasNodeAttr(orig_node->def(), "padding_list");
|
||||
|
||||
// Get all attributes from old node.
|
||||
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Tinput", &Tinput));
|
||||
@ -2397,6 +2398,9 @@ void MklLayoutRewritePass::CopyAttrsQuantizedConv2D(const Node* orig_node,
|
||||
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
|
||||
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "strides", &strides));
|
||||
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations));
|
||||
if (has_padding_list) {
|
||||
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding_list", &padding_list));
|
||||
}
|
||||
|
||||
Node* filter_node = nullptr;
|
||||
orig_node->input_node(1, &filter_node);
|
||||
@ -2411,6 +2415,9 @@ void MklLayoutRewritePass::CopyAttrsQuantizedConv2D(const Node* orig_node,
|
||||
nb->Attr("dilations", dilations);
|
||||
nb->Attr("T", out_type); // added "T" for facilitating MklToTf conversion.
|
||||
nb->Attr("data_format", data_format);
|
||||
if (has_padding_list) {
|
||||
nb->Attr("padding_list", padding_list);
|
||||
}
|
||||
|
||||
// Requantization attr Tbias.
|
||||
DataType Tbias;
|
||||
|
@ -6327,6 +6327,29 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test_mkl(
|
||||
name = "mkl_quantized_conv_ops_test",
|
||||
size = "small",
|
||||
srcs = ["mkl_quantized_conv_ops_test.cc"],
|
||||
tags = ["nomsan"], # http://b/32242946
|
||||
deps = [
|
||||
":mkl_conv_op",
|
||||
":mkl_input_conversion_op",
|
||||
":ops_testutil",
|
||||
":ops_util",
|
||||
":quantization_utils",
|
||||
":quantized_ops",
|
||||
"//tensorflow/core:array_ops_op_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:math_ops_op_lib",
|
||||
"//tensorflow/core:nn_ops_op_lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "quantize_op_test",
|
||||
size = "small",
|
||||
|
@ -466,7 +466,7 @@ class MklConvOp : public OpKernel {
|
||||
errors::InvalidArgument("filter must be 4-dimensional: ",
|
||||
filter.shape().DebugString()));
|
||||
|
||||
for (int i = 0; i < 3; i++) {
|
||||
for (int i = 0; i < 3; ++i) {
|
||||
OP_REQUIRES(
|
||||
context,
|
||||
FastBoundsCheck(filter.dim_size(i), std::numeric_limits<int>::max()),
|
||||
@ -860,6 +860,9 @@ class MklConvOp : public OpKernel {
|
||||
|
||||
explicit MklConvOp(OpKernelConstruction* context) : OpKernel(context) {
|
||||
OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
|
||||
if (context->HasAttr("padding_list")) {
|
||||
OP_REQUIRES_OK(context, context->GetAttr("padding_list", &padding_list_));
|
||||
}
|
||||
OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
|
||||
string data_format;
|
||||
OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
|
||||
@ -938,9 +941,19 @@ class MklConvOp : public OpKernel {
|
||||
dilations, strides;
|
||||
memory::dims dst_dims_tf_order, dst_dims_mkl_order;
|
||||
|
||||
// If pad with conv2d fusion is enabled
|
||||
if (fuse_pad_) {
|
||||
PadWithConvFusion(context, padding_left, padding_right);
|
||||
// For Quantized-Conv2D and Pad fusion, we get padding from the
|
||||
// `padding_list` attribute. Otherwise, we get it from one of the inputs.
|
||||
bool quantized_pad_enabled = false;
|
||||
for (auto const& padding_val : padding_list_) {
|
||||
if (padding_val) {
|
||||
quantized_pad_enabled = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (fuse_pad_ || quantized_pad_enabled) {
|
||||
PadWithConvFusion(context, padding_left, padding_right,
|
||||
quantized_pad_enabled);
|
||||
}
|
||||
|
||||
// Get shapes of input tensors in MKL-DNN order
|
||||
@ -951,7 +964,8 @@ class MklConvOp : public OpKernel {
|
||||
conv_utl.GetConvFwdSizesInMklOrder(
|
||||
src_tf_shape, filter_tf_shape, &src_dims, &filter_dims, &strides,
|
||||
&dilations, &dst_dims_tf_order, &dst_dims_mkl_order, &padding_left,
|
||||
&padding_right, fuse_pad_, is_depthwise);
|
||||
&padding_right, (fuse_pad_ || quantized_pad_enabled), is_depthwise);
|
||||
|
||||
if (!context->status().ok()) return;
|
||||
|
||||
// Check for corner case - if there is nothing to compute, return.
|
||||
@ -1151,16 +1165,20 @@ class MklConvOp : public OpKernel {
|
||||
}
|
||||
|
||||
void PadWithConvFusion(OpKernelContext* context, memory::dims& padding_left,
|
||||
memory::dims& padding_right) {
|
||||
memory::dims& padding_right,
|
||||
bool quantized_pad_enabled) {
|
||||
const Tensor& paddings_tf = MklGetInput(context, input_index_pad_);
|
||||
OP_REQUIRES(context, paddings_tf.dims() == 2,
|
||||
errors::InvalidArgument("paddings must be 2-dimensional: ",
|
||||
paddings_tf.shape().DebugString()));
|
||||
|
||||
// Flatten tensor to get individual paddings.
|
||||
Tpadding* paddings = static_cast<Tpadding*>(
|
||||
const_cast<Tpadding*>(paddings_tf.flat<Tpadding>().data()));
|
||||
|
||||
Tpadding* paddings = nullptr;
|
||||
if (quantized_pad_enabled) {
|
||||
paddings = padding_list_.data();
|
||||
} else {
|
||||
OP_REQUIRES(context, paddings_tf.dims() == 2,
|
||||
errors::InvalidArgument("paddings must be 2-dimensional: ",
|
||||
paddings_tf.shape().DebugString()));
|
||||
// Flatten tensor to get individual paddings.
|
||||
paddings = static_cast<Tpadding*>(
|
||||
const_cast<Tpadding*>(paddings_tf.flat<Tpadding>().data()));
|
||||
}
|
||||
// If the data format is NHWC, indices 0, 1, 6 and 7 of paddings(_tf)
|
||||
// will be zero.
|
||||
// Example:
|
||||
@ -1186,8 +1204,7 @@ class MklConvOp : public OpKernel {
|
||||
pad_left = paddings[6];
|
||||
pad_right = paddings[7];
|
||||
}
|
||||
|
||||
// Create padding arrays for MKL DNN convolutions.
|
||||
// Create padding arrays for MKL-DNN convolutions.
|
||||
// MKL-DNN uses asymetric padding.
|
||||
padding_left = {static_cast<int>(pad_top), static_cast<int>(pad_left)};
|
||||
padding_right = {static_cast<int>(pad_bottom), static_cast<int>(pad_right)};
|
||||
@ -1264,6 +1281,7 @@ class MklConvOp : public OpKernel {
|
||||
private:
|
||||
std::vector<int32> strides_;
|
||||
std::vector<int32> dilations_;
|
||||
std::vector<Tpadding> padding_list_;
|
||||
bool is_filter_const_;
|
||||
mutex mu_;
|
||||
Padding padding_;
|
||||
@ -1825,7 +1843,7 @@ class MklQuantizedConv2DSumReluOp
|
||||
};
|
||||
|
||||
// INT8 kernel registration
|
||||
// Register NoOp kernel for QunatizedConv2D for qint8 filter
|
||||
// Register NoOp kernel for QuantizedConv2D for qint8 filter
|
||||
REGISTER_KERNEL_BUILDER(Name("QuantizedConv2D")
|
||||
.Device(DEVICE_CPU)
|
||||
.TypeConstraint<quint8>("Tinput")
|
||||
@ -1840,7 +1858,7 @@ REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DAndRequantize")
|
||||
.TypeConstraint<qint8>("out_type"),
|
||||
NoOp);
|
||||
|
||||
// Register a templatized implementation of MklQuntizedConv2D.
|
||||
// Register a templatized implementation of MklQuantizedConv2D.
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("_MklQuantizedConv2D")
|
||||
.Device(DEVICE_CPU)
|
||||
@ -2029,17 +2047,40 @@ REGISTER_KERNEL_BUILDER(
|
||||
.Device(DEVICE_CPU)
|
||||
.TypeConstraint<quint8>("Tinput")
|
||||
.TypeConstraint<qint8>("Tfilter")
|
||||
.TypeConstraint<qint32>("Tbias")
|
||||
.TypeConstraint<quint8>("out_type")
|
||||
.Label(mkl_op_registry::kMklQuantizedOpLabel),
|
||||
MklQuantizedConv2DSumReluOp<CPUDevice, qint32, quint8, quint8, true>);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("_MklQuantizedConv2DWithBiasSignedSumAndReluAndRequantize")
|
||||
.Device(DEVICE_CPU)
|
||||
.TypeConstraint<quint8>("Tinput")
|
||||
.TypeConstraint<qint8>("Tfilter")
|
||||
.TypeConstraint<qint32>("Tbias")
|
||||
.TypeConstraint<quint8>("out_type")
|
||||
.Label(mkl_op_registry::kMklQuantizedOpLabel),
|
||||
MklQuantizedConv2DSumReluOp<CPUDevice, qint32, quint8, qint8, true>);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("_MklQuantizedConv2DWithBiasSumAndReluAndRequantize")
|
||||
.Device(DEVICE_CPU)
|
||||
.TypeConstraint<quint8>("Tinput")
|
||||
.TypeConstraint<qint8>("Tfilter")
|
||||
.TypeConstraint<float>("Tbias")
|
||||
.TypeConstraint<quint8>("out_type")
|
||||
.Label(mkl_op_registry::kMklQuantizedOpLabel),
|
||||
MklQuantizedConv2DSumReluOp<CPUDevice, float, quint8, quint8, true>);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("_MklQuantizedConv2DWithBiasSignedSumAndReluAndRequantize")
|
||||
.Device(DEVICE_CPU)
|
||||
.TypeConstraint<quint8>("Tinput")
|
||||
.TypeConstraint<qint8>("Tfilter")
|
||||
.TypeConstraint<float>("Tbias")
|
||||
.TypeConstraint<quint8>("out_type")
|
||||
.Label(mkl_op_registry::kMklQuantizedOpLabel),
|
||||
MklQuantizedConv2DSumReluOp<CPUDevice, float, quint8, qint8, true>);
|
||||
#endif // INTEL_MKL_ML
|
||||
|
||||
// Register 2D operations
|
||||
|
458
tensorflow/core/kernels/mkl_quantized_conv_ops_test.cc
Normal file
458
tensorflow/core/kernels/mkl_quantized_conv_ops_test.cc
Normal file
@ -0,0 +1,458 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifdef INTEL_MKL
|
||||
#define EIGEN_USE_THREADS
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/allocator.h"
|
||||
#include "tensorflow/core/framework/fake_input.h"
|
||||
#include "tensorflow/core/framework/node_def_builder.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/kernels/ops_testutil.h"
|
||||
#include "tensorflow/core/kernels/ops_util.h"
|
||||
#include "tensorflow/core/kernels/quantization_utils.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// TODO(bhavanis): Move ConvMklToTF to mkl_test_util.h as it is used by
|
||||
// most unit tests.
|
||||
|
||||
// Helper class for converting MKL tensors to TF tensors and comparing to
|
||||
// expected values
|
||||
|
||||
static const uint8 dummy_tensor[] = {0, 0, 0, 0, 0, 0, 0, 0};
|
||||
static const TensorShape dummy_shape({8});
|
||||
|
||||
class ConvMklToTF : public OpsTestBase {
|
||||
public:
|
||||
template <typename T>
|
||||
void ConvertMklToTF(DataType dtype, const Tensor& input,
|
||||
const Tensor& input_metadata_tensor, Tensor& output) {
|
||||
// Create an MKL to TF conversion node and execute it
|
||||
TF_EXPECT_OK(NodeDefBuilder("mkl_to_tf_op", "_MklToTf")
|
||||
.Input(FakeInput(dtype)) // Input
|
||||
.Input(FakeInput(DT_UINT8)) // MKL metadata tensor
|
||||
.Attr("T", dtype)
|
||||
.Attr("_kernel", "MklOp")
|
||||
.Finalize(node_def()));
|
||||
TF_EXPECT_OK(InitOp());
|
||||
AddInputFromArray<T>(input.shape(), input.flat<T>());
|
||||
AddInputFromArray<uint8>(input_metadata_tensor.shape(),
|
||||
input_metadata_tensor.flat<uint8>());
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
output = *GetOutput(0);
|
||||
}
|
||||
void TestBody() {}
|
||||
};
|
||||
|
||||
class QuantizedConv2DTest : public OpsTestBase {
|
||||
protected:
|
||||
void ConfigureQuantizedConv2D(const int& stride = 1) {
|
||||
TF_ASSERT_OK(NodeDefBuilder("quantized_conv_op", "_MklQuantizedConv2D")
|
||||
.Input(FakeInput(DT_QUINT8)) // Input
|
||||
.Input(FakeInput(DT_QINT8)) // Filter
|
||||
.Input(FakeInput(DT_FLOAT)) // Min input
|
||||
.Input(FakeInput(DT_FLOAT)) // Max input
|
||||
.Input(FakeInput(DT_FLOAT)) // Min filter
|
||||
.Input(FakeInput(DT_FLOAT)) // Max filter
|
||||
// MKL metadata tensors //
|
||||
.Input(FakeInput(DT_UINT8))
|
||||
.Input(FakeInput(DT_UINT8))
|
||||
.Input(FakeInput(DT_UINT8))
|
||||
.Input(FakeInput(DT_UINT8))
|
||||
.Input(FakeInput(DT_UINT8))
|
||||
.Input(FakeInput(DT_UINT8))
|
||||
///////////////////////////
|
||||
.Attr("Tinput", DataTypeToEnum<quint8>::v())
|
||||
.Attr("Tfilter", DataTypeToEnum<qint8>::v())
|
||||
.Attr("T", DataTypeToEnum<quint8>::v())
|
||||
.Attr("out_type", DataTypeToEnum<qint32>::v())
|
||||
.Attr("strides", {1, stride, stride, 1})
|
||||
.Attr("padding", "SAME")
|
||||
.Attr("_kernel", "QuantizedMklOp")
|
||||
.Finalize(node_def()));
|
||||
TF_ASSERT_OK(InitOp());
|
||||
}
|
||||
};
|
||||
|
||||
// Output -> float
|
||||
TEST_F(QuantizedConv2DTest, Small) {
|
||||
const int stride = 1;
|
||||
ConfigureQuantizedConv2D(stride);
|
||||
|
||||
const int depth = 1;
|
||||
const int image_width = 4;
|
||||
const int image_height = 3;
|
||||
const int image_batch_count = 1;
|
||||
|
||||
// Image -> uint8
|
||||
const float image_min = 0.0f;
|
||||
const float image_max = 255.0f;
|
||||
|
||||
// The image matrix is:
|
||||
// | 1 | 2 | 3 | 4 |
|
||||
// | 5 | 6 | 7 | 8 |
|
||||
// | 9 | 10 | 11 | 12 |
|
||||
Tensor image_float(DT_FLOAT,
|
||||
{image_batch_count, image_height, image_width, depth});
|
||||
test::FillValues<float>(&image_float,
|
||||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
|
||||
Tensor image_quantized =
|
||||
FloatTensorToQuantized<quint8>(image_float, image_min, image_max);
|
||||
|
||||
const int filter_size = 3;
|
||||
const int filter_count = 1;
|
||||
|
||||
// Filter -> int8 with symmetric range
|
||||
const float filter_min = -127.0f;
|
||||
const float filter_max = 127.0f;
|
||||
|
||||
// The filter matrix is:
|
||||
// | 1 | 4 | 7 |
|
||||
// | 2 | 5 | 8 |
|
||||
// | 3 | 6 | 9 |
|
||||
Tensor filter_float(DT_FLOAT,
|
||||
{filter_size, filter_size, depth, filter_count});
|
||||
test::FillValues<float>(&filter_float, {1, 4, 7, 2, 5, 8, 3, 6, 9});
|
||||
Tensor filter_quantized =
|
||||
FloatTensorToQuantized<qint8>(filter_float, filter_min, filter_max);
|
||||
|
||||
AddInputFromArray<quint8>(image_quantized.shape(),
|
||||
image_quantized.flat<quint8>());
|
||||
AddInputFromArray<qint8>(filter_quantized.shape(),
|
||||
filter_quantized.flat<qint8>());
|
||||
AddInputFromArray<float>(TensorShape({1}), {image_min});
|
||||
AddInputFromArray<float>(TensorShape({1}), {image_max});
|
||||
AddInputFromArray<float>(TensorShape({1}), {filter_min});
|
||||
AddInputFromArray<float>(TensorShape({1}), {filter_max});
|
||||
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
// We're sliding the 3x3 filter across the 3x4 image, with accesses outside
|
||||
// the input set to zero because we're using the 'SAME' padding mode.
|
||||
// The calculations behind the expected output are:
|
||||
// (1*0)+(4*0)+(7*0)+(2*0)+(5*1)+(8*2)+(3*0)+(6*5)+(9*6)=105
|
||||
// (1*0)+(4*0)+(7*0)+(2*1)+(5*2)+(8*3)+(3*5)+(6*6)+(9*7)=150
|
||||
// (1*0)+(4*0)+(7*0)+(2*2)+(5*3)+(8*4)+(3*6)+(6*7)+(9*8)=183
|
||||
// (1*0)+(4*0)+(7*0)+(2*3)+(5*4)+(8*0)+(3*7)+(6*8)+(9*0)=95
|
||||
// (1*0)+(4*1)+(7*2)+(2*0)+(5*5)+(8*6)+(3*0)+(6*9)+(9*10)=235
|
||||
// (1*1)+(4*2)+(7*3)+(2*5)+(5*6)+(8*7)+(3*9)+(6*10)+(9*11)=312
|
||||
// (1*2)+(4*3)+(7*4)+(2*6)+(5*7)+(8*8)+(3*10)+(6*11)+(9*12)=357
|
||||
// (1*3)+(4*4)+(7*0)+(2*7)+(5*8)+(8*0)+(3*11)+(6*12)+(9*0)=178
|
||||
// (1*0)+(4*5)+(7*6)+(2*0)+(5*9)+(8*10)+(3*0)+(6*0)+(9*0)=187
|
||||
// (1*5)+(4*6)+(7*7)+(2*9)+(5*10)+(8*11)+(3*0)+(6*0)+(9*0)=234
|
||||
// (1*6)+(4*7)+(7*8)+(2*10)+(5*11)+(8*12)+(3*0)+(6*0)+(9*0)=261
|
||||
// (1*7)+(4*8)+(7*0)+(2*11)+(5*12)+(8*0)+(3*0)+(6*0)+(9*0)=121
|
||||
// This means we should end up with this matrix:
|
||||
// | 105 | 150 | 183 | 95 |
|
||||
// | 235 | 312 | 357 | 178 |
|
||||
// | 187 | 234 | 261 | 121 |
|
||||
|
||||
// Output -> float
|
||||
const int expected_width = image_width;
|
||||
const int expected_height = image_height;
|
||||
Tensor expected_float(
|
||||
DT_FLOAT, TensorShape({image_batch_count, expected_height, expected_width,
|
||||
filter_count}));
|
||||
test::FillValues<float>(&expected_float, {105, 150, 183, 95, 235, 312, 357,
|
||||
178, 187, 234, 261, 121});
|
||||
|
||||
const Tensor& output = *GetOutput(0);
|
||||
const Tensor& output_mkl_metadata = *GetOutput(3);
|
||||
|
||||
ConvMklToTF conv_comp;
|
||||
Tensor output_quantized;
|
||||
conv_comp.ConvertMklToTF<qint32>(DT_QINT32, output, output_mkl_metadata,
|
||||
output_quantized);
|
||||
|
||||
const float output_min = GetOutput(1)->flat<float>()(0);
|
||||
const float output_max = GetOutput(2)->flat<float>()(0);
|
||||
Tensor output_float =
|
||||
QuantizedTensorToFloat<qint32>(output_quantized, output_min, output_max);
|
||||
|
||||
test::ExpectTensorNear<float>(expected_float, output_float, 1.0);
|
||||
}
|
||||
|
||||
// Output -> qint32
|
||||
TEST_F(QuantizedConv2DTest, Small32Bit) {
|
||||
const int stride = 1;
|
||||
ConfigureQuantizedConv2D(stride);
|
||||
|
||||
// The illustrations and details regarding inputs and outputs
|
||||
// are in TEST_F(QuantizedConv2DTest, Small)
|
||||
const int depth = 1;
|
||||
const int image_width = 4;
|
||||
const int image_height = 3;
|
||||
const int image_batch_count = 1;
|
||||
AddInputFromArray<quint8>(
|
||||
TensorShape({image_batch_count, image_height, image_width, depth}),
|
||||
{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120});
|
||||
|
||||
const int filter_size = 3;
|
||||
const int filter_count = 1;
|
||||
AddInputFromArray<qint8>(
|
||||
TensorShape({filter_size, filter_size, depth, filter_count}),
|
||||
{10, 40, 70, 20, 50, 80, 30, 60, 90});
|
||||
|
||||
// Image -> uint8
|
||||
AddInputFromArray<float>(TensorShape({1}), {0.0f});
|
||||
AddInputFromArray<float>(TensorShape({1}), {255.0f});
|
||||
|
||||
// Filter -> int8 with symmetric range
|
||||
AddInputFromArray<float>(TensorShape({1}), {-127.0f});
|
||||
AddInputFromArray<float>(TensorShape({1}), {127.0f});
|
||||
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
// Output -> qint32
|
||||
const int expected_width = image_width;
|
||||
const int expected_height = image_height;
|
||||
Tensor expected(DT_QINT32, TensorShape({image_batch_count, expected_height,
|
||||
expected_width, filter_count}));
|
||||
test::FillValues<qint32>(
|
||||
&expected, {10500, 15000, 18300, 9500, 23500, 31200, 35700, 17800, 18700,
|
||||
23400, 26100, 12100});
|
||||
|
||||
const Tensor& output = *GetOutput(0);
|
||||
const Tensor& output_mkl_metadata = *GetOutput(3);
|
||||
|
||||
ConvMklToTF conv_comp;
|
||||
Tensor output_quantized;
|
||||
conv_comp.ConvertMklToTF<qint32>(DT_QINT32, output, output_mkl_metadata,
|
||||
output_quantized);
|
||||
|
||||
test::ExpectTensorEqual<qint32>(expected, output_quantized);
|
||||
}
|
||||
|
||||
// Output -> qint32
|
||||
TEST_F(QuantizedConv2DTest, Small32BitWithPadding) {
|
||||
const int stride = 1;
|
||||
TF_ASSERT_OK(NodeDefBuilder("quantized_conv_op", "_MklQuantizedConv2D")
|
||||
.Input(FakeInput(DT_QUINT8)) // Input
|
||||
.Input(FakeInput(DT_QINT8)) // Filter
|
||||
.Input(FakeInput(DT_FLOAT)) // Min input
|
||||
.Input(FakeInput(DT_FLOAT)) // Max input
|
||||
.Input(FakeInput(DT_FLOAT)) // Min filter
|
||||
.Input(FakeInput(DT_FLOAT)) // Max filter
|
||||
// MKL metadata tensors //
|
||||
.Input(FakeInput(DT_UINT8))
|
||||
.Input(FakeInput(DT_UINT8))
|
||||
.Input(FakeInput(DT_UINT8))
|
||||
.Input(FakeInput(DT_UINT8))
|
||||
.Input(FakeInput(DT_UINT8))
|
||||
.Input(FakeInput(DT_UINT8))
|
||||
///////////////////////////
|
||||
.Attr("Tinput", DataTypeToEnum<quint8>::v())
|
||||
.Attr("Tfilter", DataTypeToEnum<qint8>::v())
|
||||
.Attr("T", DataTypeToEnum<quint8>::v())
|
||||
.Attr("out_type", DataTypeToEnum<qint32>::v())
|
||||
.Attr("strides", {1, stride, stride, 1})
|
||||
.Attr("padding", "SAME")
|
||||
.Attr("padding_list", {0, 0, 1, 1, 1, 1, 0, 0})
|
||||
.Attr("_kernel", "QuantizedMklOp")
|
||||
.Finalize(node_def()));
|
||||
TF_ASSERT_OK(InitOp());
|
||||
|
||||
// The illustrations and details regarding inputs and outputs
|
||||
// are in TEST_F(QuantizedConv2DTest, Small)
|
||||
const int depth = 1;
|
||||
const int image_width = 4;
|
||||
const int image_height = 3;
|
||||
const int image_batch_count = 1;
|
||||
AddInputFromArray<quint8>(
|
||||
TensorShape({image_batch_count, image_height, image_width, depth}),
|
||||
{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120});
|
||||
|
||||
const int filter_size = 3;
|
||||
const int filter_count = 1;
|
||||
AddInputFromArray<qint8>(
|
||||
TensorShape({filter_size, filter_size, depth, filter_count}),
|
||||
{10, 40, 70, 20, 50, 80, 30, 60, 90});
|
||||
|
||||
// Image -> uint8
|
||||
AddInputFromArray<float>(TensorShape({1}), {0.0f});
|
||||
AddInputFromArray<float>(TensorShape({1}), {255.0f});
|
||||
|
||||
// Filter -> int8 with symmetric range
|
||||
AddInputFromArray<float>(TensorShape({1}), {-127.0f});
|
||||
AddInputFromArray<float>(TensorShape({1}), {127.0f});
|
||||
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
// Output -> qint32
|
||||
const int expected_width = image_width;
|
||||
const int expected_height = image_height;
|
||||
Tensor expected(DT_QINT32, TensorShape({image_batch_count, expected_height,
|
||||
expected_width, filter_count}));
|
||||
test::FillValues<qint32>(
|
||||
&expected, {10500, 15000, 18300, 9500, 23500, 31200, 35700, 17800, 18700,
|
||||
23400, 26100, 12100});
|
||||
|
||||
const Tensor& output = *GetOutput(0);
|
||||
const Tensor& output_mkl_metadata = *GetOutput(3);
|
||||
|
||||
ConvMklToTF conv_comp;
|
||||
Tensor output_quantized;
|
||||
conv_comp.ConvertMklToTF<qint32>(DT_QINT32, output, output_mkl_metadata,
|
||||
output_quantized);
|
||||
|
||||
test::ExpectTensorEqual<qint32>(expected, output_quantized);
|
||||
}
|
||||
|
||||
// Output -> qint32
|
||||
TEST_F(QuantizedConv2DTest, OddPadding) {
|
||||
const int stride = 2;
|
||||
ConfigureQuantizedConv2D(stride);
|
||||
|
||||
const int depth = 1;
|
||||
const int image_width = 4;
|
||||
const int image_height = 4;
|
||||
const int image_batch_count = 1;
|
||||
AddInputFromArray<quint8>(
|
||||
TensorShape({image_batch_count, image_height, image_width, depth}),
|
||||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
|
||||
|
||||
const int filter_size = 3;
|
||||
const int filter_count = 1;
|
||||
AddInputFromArray<qint8>(
|
||||
TensorShape({filter_size, filter_size, depth, filter_count}),
|
||||
{1, 2, 3, 4, 5, 6, 7, 8, 9});
|
||||
|
||||
// Image -> uint8
|
||||
AddInputFromArray<float>(TensorShape({1}), {0.0f});
|
||||
AddInputFromArray<float>(TensorShape({1}), {255.0f});
|
||||
|
||||
// Filter -> int8 with symmetric range
|
||||
AddInputFromArray<float>(TensorShape({1}), {-127.0f});
|
||||
AddInputFromArray<float>(TensorShape({1}), {127.0f});
|
||||
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
// Output -> qint32
|
||||
const int expected_width = image_width / stride;
|
||||
const int expected_height = image_height / stride;
|
||||
Tensor expected(DT_QINT32, TensorShape({image_batch_count, expected_height,
|
||||
expected_width, filter_count}));
|
||||
test::FillValues<qint32>(&expected, {348, 252, 274, 175});
|
||||
|
||||
const Tensor& output = *GetOutput(0);
|
||||
const Tensor& output_mkl_metadata = *GetOutput(3);
|
||||
|
||||
ConvMklToTF conv_comp;
|
||||
Tensor output_quantized;
|
||||
conv_comp.ConvertMklToTF<qint32>(DT_QINT32, output, output_mkl_metadata,
|
||||
output_quantized);
|
||||
|
||||
test::ExpectTensorEqual<qint32>(expected, output_quantized);
|
||||
}
|
||||
|
||||
// Output -> qint32
|
||||
TEST_F(QuantizedConv2DTest, OddPaddingBatch) {
|
||||
const int stride = 2;
|
||||
ConfigureQuantizedConv2D(stride);
|
||||
|
||||
const int depth = 1;
|
||||
const int image_width = 4;
|
||||
const int image_height = 4;
|
||||
const int image_batch_count = 3;
|
||||
AddInputFromArray<quint8>(
|
||||
TensorShape({image_batch_count, image_height, image_width, depth}),
|
||||
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
|
||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
|
||||
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
|
||||
|
||||
const int filter_size = 3;
|
||||
const int filter_count = 1;
|
||||
AddInputFromArray<qint8>(
|
||||
TensorShape({filter_size, filter_size, depth, filter_count}),
|
||||
{1, 2, 3, 4, 5, 6, 7, 8, 9});
|
||||
|
||||
// Image -> uint8
|
||||
AddInputFromArray<float>(TensorShape({1}), {0.0f});
|
||||
AddInputFromArray<float>(TensorShape({1}), {255.0f});
|
||||
|
||||
// Filter -> int8 with symmetric range
|
||||
AddInputFromArray<float>(TensorShape({1}), {-127.0f});
|
||||
AddInputFromArray<float>(TensorShape({1}), {127.0f});
|
||||
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
|
||||
|
||||
TF_ASSERT_OK(RunOpKernel());
|
||||
|
||||
// Output -> qint32
|
||||
const int expected_width = image_width / stride;
|
||||
const int expected_height = image_height / stride;
|
||||
Tensor expected(DT_QINT32, TensorShape({image_batch_count, expected_height,
|
||||
expected_width, filter_count}));
|
||||
test::FillValues<qint32>(
|
||||
&expected, {348, 252, 274, 175, 348, 252, 274, 175, 348, 252, 274, 175});
|
||||
|
||||
const Tensor& output = *GetOutput(0);
|
||||
const Tensor& output_mkl_metadata = *GetOutput(3);
|
||||
|
||||
ConvMklToTF conv_comp;
|
||||
Tensor output_quantized;
|
||||
conv_comp.ConvertMklToTF<qint32>(DT_QINT32, output, output_mkl_metadata,
|
||||
output_quantized);
|
||||
|
||||
test::ExpectTensorEqual<qint32>(expected, output_quantized);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
#endif // INTEL_MKL
|
@ -206,6 +206,7 @@ REGISTER_OP("_MklQuantizedConv2D")
|
||||
.Attr("is_filter_const: bool = true")
|
||||
.Attr(GetPaddingAttrString())
|
||||
.Attr("dilations: list(int) = [1, 1, 1, 1]")
|
||||
.Attr("padding_list: list(int) = []")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
|
||||
ShapeHandle unused;
|
||||
@ -251,6 +252,7 @@ REGISTER_OP("_MklQuantizedConv2DAndRequantize")
|
||||
.Attr("is_filter_const: bool = true")
|
||||
.Attr(GetPaddingAttrString())
|
||||
.Attr("dilations: list(int) = [1, 1, 1, 1]")
|
||||
.Attr("padding_list: list(int) = []")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
|
||||
ShapeHandle unused;
|
||||
@ -296,6 +298,7 @@ REGISTER_OP("_MklQuantizedConv2DWithBias")
|
||||
.Attr("is_filter_const: bool = true")
|
||||
.Attr(GetPaddingAttrString())
|
||||
.Attr("dilations: list(int) = [1, 1, 1, 1]")
|
||||
.Attr("padding_list: list(int) = []")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
|
||||
ShapeHandle unused;
|
||||
@ -345,6 +348,7 @@ REGISTER_OP("_MklQuantizedConv2DWithBiasAndRequantize")
|
||||
.Attr("is_filter_const: bool = true")
|
||||
.Attr(GetPaddingAttrString())
|
||||
.Attr("dilations: list(int) = [1, 1, 1, 1]")
|
||||
.Attr("padding_list: list(int) = []")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
|
||||
ShapeHandle unused;
|
||||
@ -387,6 +391,7 @@ REGISTER_OP("_MklQuantizedConv2DAndRelu")
|
||||
.Attr("is_filter_const: bool = true")
|
||||
.Attr(GetPaddingAttrString())
|
||||
.Attr("dilations: list(int) = [1, 1, 1, 1]")
|
||||
.Attr("padding_list: list(int) = []")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
|
||||
ShapeHandle unused;
|
||||
@ -432,6 +437,7 @@ REGISTER_OP("_MklQuantizedConv2DAndReluAndRequantize")
|
||||
.Attr("is_filter_const: bool = true")
|
||||
.Attr(GetPaddingAttrString())
|
||||
.Attr("dilations: list(int) = [1, 1, 1, 1]")
|
||||
.Attr("padding_list: list(int) = []")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
|
||||
ShapeHandle unused;
|
||||
@ -477,6 +483,7 @@ REGISTER_OP("_MklQuantizedConv2DWithBiasAndRelu")
|
||||
.Attr("is_filter_const: bool = true")
|
||||
.Attr(GetPaddingAttrString())
|
||||
.Attr("dilations: list(int) = [1, 1, 1, 1]")
|
||||
.Attr("padding_list: list(int) = []")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
|
||||
ShapeHandle unused;
|
||||
@ -526,6 +533,7 @@ REGISTER_OP("_MklQuantizedConv2DWithBiasAndReluAndRequantize")
|
||||
.Attr("is_filter_const: bool = true")
|
||||
.Attr(GetPaddingAttrString())
|
||||
.Attr("dilations: list(int) = [1, 1, 1, 1]")
|
||||
.Attr("padding_list: list(int) = []")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
|
||||
ShapeHandle unused;
|
||||
@ -574,6 +582,7 @@ REGISTER_OP("_MklQuantizedConv2DWithBiasSumAndRelu")
|
||||
.Attr("is_filter_const: bool = true")
|
||||
.Attr(GetPaddingAttrString())
|
||||
.Attr("dilations: list(int) = [1, 1, 1, 1]")
|
||||
.Attr("padding_list: list(int) = []")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
|
||||
ShapeHandle unused;
|
||||
@ -630,6 +639,7 @@ REGISTER_OP("_MklQuantizedConv2DWithBiasSumAndReluAndRequantize")
|
||||
.Attr("is_filter_const: bool = true")
|
||||
.Attr(GetPaddingAttrString())
|
||||
.Attr("dilations: list(int) = [1, 1, 1, 1]")
|
||||
.Attr("padding_list: list(int) = []")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
|
||||
ShapeHandle unused;
|
||||
@ -688,6 +698,7 @@ REGISTER_OP("_MklQuantizedConv2DWithBiasSignedSumAndReluAndRequantize")
|
||||
.Attr("is_filter_const: bool = true")
|
||||
.Attr(GetPaddingAttrString())
|
||||
.Attr("dilations: list(int) = [1, 1, 1, 1]")
|
||||
.Attr("padding_list: list(int) = []")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
|
||||
ShapeHandle unused;
|
||||
|
@ -2560,6 +2560,7 @@ REGISTER_OP("QuantizedConv2DAndRequantize")
|
||||
.Attr("strides: list(int)")
|
||||
.Attr(GetPaddingAttrString())
|
||||
.Attr("dilations: list(int) = [1, 1, 1, 1]")
|
||||
.Attr("padding_list: list(int) = []")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
|
||||
ShapeHandle unused;
|
||||
@ -2592,6 +2593,7 @@ REGISTER_OP("QuantizedConv2DWithBias")
|
||||
.Attr("strides: list(int)")
|
||||
.Attr(GetPaddingAttrString())
|
||||
.Attr("dilations: list(int) = [1, 1, 1, 1]")
|
||||
.Attr("padding_list: list(int) = []")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
|
||||
ShapeHandle unused;
|
||||
@ -2625,6 +2627,7 @@ REGISTER_OP("QuantizedConv2DWithBiasAndRequantize")
|
||||
.Attr("strides: list(int)")
|
||||
.Attr(GetPaddingAttrString())
|
||||
.Attr("dilations: list(int) = [1, 1, 1, 1]")
|
||||
.Attr("padding_list: list(int) = []")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
|
||||
ShapeHandle unused;
|
||||
@ -2657,6 +2660,7 @@ REGISTER_OP("QuantizedConv2DAndRelu")
|
||||
.Attr("strides: list(int)")
|
||||
.Attr(GetPaddingAttrString())
|
||||
.Attr("dilations: list(int) = [1, 1, 1, 1]")
|
||||
.Attr("padding_list: list(int) = []")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
|
||||
ShapeHandle unused;
|
||||
@ -2687,6 +2691,7 @@ REGISTER_OP("QuantizedConv2DAndReluAndRequantize")
|
||||
.Attr("strides: list(int)")
|
||||
.Attr(GetPaddingAttrString())
|
||||
.Attr("dilations: list(int) = [1, 1, 1, 1]")
|
||||
.Attr("padding_list: list(int) = []")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
|
||||
ShapeHandle unused;
|
||||
@ -2719,6 +2724,7 @@ REGISTER_OP("QuantizedConv2DWithBiasAndRelu")
|
||||
.Attr("strides: list(int)")
|
||||
.Attr(GetPaddingAttrString())
|
||||
.Attr("dilations: list(int) = [1, 1, 1, 1]")
|
||||
.Attr("padding_list: list(int) = []")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
|
||||
ShapeHandle unused;
|
||||
@ -2753,6 +2759,7 @@ REGISTER_OP("QuantizedConv2DWithBiasAndReluAndRequantize")
|
||||
.Attr("strides: list(int)")
|
||||
.Attr(GetPaddingAttrString())
|
||||
.Attr("dilations: list(int) = [1, 1, 1, 1]")
|
||||
.Attr("padding_list: list(int) = []")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
|
||||
ShapeHandle unused;
|
||||
@ -2787,6 +2794,7 @@ REGISTER_OP("QuantizedConv2DWithBiasSumAndRelu")
|
||||
.Attr("strides: list(int)")
|
||||
.Attr(GetPaddingAttrString())
|
||||
.Attr("dilations: list(int) = [1, 1, 1, 1]")
|
||||
.Attr("padding_list: list(int) = []")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
|
||||
ShapeHandle unused;
|
||||
@ -2824,6 +2832,7 @@ REGISTER_OP("QuantizedConv2DWithBiasSumAndReluAndRequantize")
|
||||
.Attr("strides: list(int)")
|
||||
.Attr(GetPaddingAttrString())
|
||||
.Attr("dilations: list(int) = [1, 1, 1, 1]")
|
||||
.Attr("padding_list: list(int) = []")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
|
||||
ShapeHandle unused;
|
||||
@ -2863,6 +2872,7 @@ REGISTER_OP("QuantizedConv2DWithBiasSignedSumAndReluAndRequantize")
|
||||
.Attr("strides: list(int)")
|
||||
.Attr(GetPaddingAttrString())
|
||||
.Attr("dilations: list(int) = [1, 1, 1, 1]")
|
||||
.Attr("padding_list: list(int) = []")
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
|
||||
ShapeHandle unused;
|
||||
@ -2878,5 +2888,4 @@ REGISTER_OP("QuantizedConv2DWithBiasSignedSumAndReluAndRequantize")
|
||||
return Status::OK();
|
||||
});
|
||||
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -2410,43 +2410,43 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "QuantizedConv2DAndRelu"
|
||||
argspec: "args=[\'input\', \'filter\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None"
|
||||
argspec: "args=[\'input\', \'filter\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "QuantizedConv2DAndReluAndRequantize"
|
||||
argspec: "args=[\'input\', \'filter\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None"
|
||||
argspec: "args=[\'input\', \'filter\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "QuantizedConv2DAndRequantize"
|
||||
argspec: "args=[\'input\', \'filter\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None"
|
||||
argspec: "args=[\'input\', \'filter\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "QuantizedConv2DWithBias"
|
||||
argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None"
|
||||
argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "QuantizedConv2DWithBiasAndRelu"
|
||||
argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None"
|
||||
argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "QuantizedConv2DWithBiasAndReluAndRequantize"
|
||||
argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None"
|
||||
argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "QuantizedConv2DWithBiasAndRequantize"
|
||||
argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None"
|
||||
argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize"
|
||||
argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'summand\', \'min_summand\', \'max_summand\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None"
|
||||
argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'summand\', \'min_summand\', \'max_summand\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "QuantizedConv2DWithBiasSumAndRelu"
|
||||
argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'summand\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None"
|
||||
argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'summand\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "QuantizedConv2DWithBiasSumAndReluAndRequantize"
|
||||
argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'summand\', \'min_summand\', \'max_summand\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None"
|
||||
argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'summand\', \'min_summand\', \'max_summand\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "QuantizedInstanceNorm"
|
||||
|
@ -2410,43 +2410,43 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "QuantizedConv2DAndRelu"
|
||||
argspec: "args=[\'input\', \'filter\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None"
|
||||
argspec: "args=[\'input\', \'filter\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "QuantizedConv2DAndReluAndRequantize"
|
||||
argspec: "args=[\'input\', \'filter\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None"
|
||||
argspec: "args=[\'input\', \'filter\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "QuantizedConv2DAndRequantize"
|
||||
argspec: "args=[\'input\', \'filter\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None"
|
||||
argspec: "args=[\'input\', \'filter\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "QuantizedConv2DWithBias"
|
||||
argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None"
|
||||
argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "QuantizedConv2DWithBiasAndRelu"
|
||||
argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None"
|
||||
argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "QuantizedConv2DWithBiasAndReluAndRequantize"
|
||||
argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None"
|
||||
argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "QuantizedConv2DWithBiasAndRequantize"
|
||||
argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None"
|
||||
argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "QuantizedConv2DWithBiasSignedSumAndReluAndRequantize"
|
||||
argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'summand\', \'min_summand\', \'max_summand\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None"
|
||||
argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'summand\', \'min_summand\', \'max_summand\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "QuantizedConv2DWithBiasSumAndRelu"
|
||||
argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'summand\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None"
|
||||
argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'summand\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "QuantizedConv2DWithBiasSumAndReluAndRequantize"
|
||||
argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'summand\', \'min_summand\', \'max_summand\', \'out_type\', \'strides\', \'padding\', \'dilations\'], varargs=None, keywords=None, defaults=None"
|
||||
argspec: "args=[\'input\', \'filter\', \'bias\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'min_freezed_output\', \'max_freezed_output\', \'summand\', \'min_summand\', \'max_summand\', \'out_type\', \'strides\', \'padding\', \'dilations\', \'padding_list\'], varargs=None, keywords=None, defaults=None"
|
||||
}
|
||||
member_method {
|
||||
name: "QuantizedInstanceNorm"
|
||||
|
Loading…
Reference in New Issue
Block a user