support for {int8,int8} convolutions and fusions

srinivasan.narayanamoorthy 2019-10-07 14:20:14 -07:00
parent c85e67d2d5
commit e4e8f0c0ff
3 changed files with 345 additions and 139 deletions


@@ -152,7 +152,7 @@ static inline bool IsMklLayoutDependentOp(const string& op_name,
   // Restrict quantized ops to QUINT8 and QINT8 for now
   if (kernel.find(kMklQuantizedOpLabelPattern) != string::npos) {
-    return (Tinput == DT_QUINT8 && Tfilter == DT_QINT8);
+    return (Tfilter == DT_QINT8);
   }
   return false;
 }
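Taken together with the kernel registrations later in this commit, the relaxed check means a quantized MKL op is now accepted when Tinput is either quint8 or qint8, as long as Tfilter is qint8. A minimal self-contained sketch of the new predicate (stand-in enum values, not TensorFlow's DataType):

#include <iostream>

// Stand-ins for the two tensorflow::DataType values this check cares about.
enum DataType { DT_QUINT8, DT_QINT8 };

// Mirrors the relaxed guard above: only the filter type is constrained now.
bool AcceptsQuantizedMklKernel(DataType Tinput, DataType Tfilter) {
  (void)Tinput;  // no longer inspected after this commit
  return Tfilter == DT_QINT8;
}

int main() {
  std::cout << AcceptsQuantizedMklKernel(DT_QUINT8, DT_QINT8) << "\n";  // 1: {uint8,int8}, as before
  std::cout << AcceptsQuantizedMklKernel(DT_QINT8, DT_QINT8) << "\n";   // 1: the new {int8,int8} path
}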


@@ -24,8 +24,8 @@ limitations under the License.
 #include <map>
 #include <vector>
 
-#include "mkldnn.hpp"
 #include "absl/strings/str_join.h"
+#include "mkldnn.hpp"
 #include "tensorflow/core/framework/bounds_check.h"
 #include "tensorflow/core/framework/numeric_op.h"
 #include "tensorflow/core/framework/op_kernel.h"
@@ -570,17 +570,15 @@ class MklConvOp : public OpKernel {
       OP_REQUIRES(context, dilations_.size() == 5,
                   errors::InvalidArgument("Dilation rates field must "
                                           "specify 5 dimensions"));
-      OP_REQUIRES(context,
-                  (GetTensorDim(dilations_, data_format_, 'N') == 1 &&
-                   GetTensorDim(dilations_, data_format_, 'C') == 1),
+      OP_REQUIRES(context, (GetTensorDim(dilations_, data_format_, 'N') == 1 &&
+                            GetTensorDim(dilations_, data_format_, 'C') == 1),
                   errors::InvalidArgument(
                       "Current implementation does not yet support "
                       "dilations rates in the batch and depth dimensions."));
       OP_REQUIRES(
-          context,
-          (GetTensorDim(dilations_, data_format_, '0') > 0 &&
-           GetTensorDim(dilations_, data_format_, '1') > 0 &&
-           GetTensorDim(dilations_, data_format_, '2') > 0),
+          context, (GetTensorDim(dilations_, data_format_, '0') > 0 &&
+                    GetTensorDim(dilations_, data_format_, '1') > 0 &&
+                    GetTensorDim(dilations_, data_format_, '2') > 0),
           errors::InvalidArgument("Dilated rates should be larger than 0."));
     }
   }
@@ -590,7 +588,6 @@ class MklConvOp : public OpKernel {
     // Input tensors
     const Tensor& src_tensor = MklGetInput(context, kInputIndex_Src);
     const Tensor& filter_tensor = MklGetInput(context, kInputIndex_Filter);
-
     MklDnnShape src_mkl_shape, filter_mkl_shape;
     GetMklShape(context, kInputIndex_Src, &src_mkl_shape, eager_mode);
     GetMklShape(context, kInputIndex_Filter, &filter_mkl_shape, eager_mode);
@ -786,7 +783,6 @@ class MklConvOp : public OpKernel {
src_data = static_cast<Tinput*>( src_data = static_cast<Tinput*>(
const_cast<Tinput*>(src_tensor.flat<Tinput>().data())); const_cast<Tinput*>(src_tensor.flat<Tinput>().data()));
} }
Tfilter* filter_data = nullptr; Tfilter* filter_data = nullptr;
if (IS_FILTER_REORDER_NEEDED(filter_md, conv_fwd_pd, conv_fwd)) { if (IS_FILTER_REORDER_NEEDED(filter_md, conv_fwd_pd, conv_fwd)) {
bool is_filter_cached = false; bool is_filter_cached = false;
@ -859,7 +855,6 @@ class MklConvOp : public OpKernel {
cpu_engine_); cpu_engine_);
} }
} }
// Delete primitive since it is not cached. // Delete primitive since it is not cached.
if (do_not_cache) delete conv_fwd; if (do_not_cache) delete conv_fwd;
} catch (mkldnn::error& e) { } catch (mkldnn::error& e) {
@@ -1417,10 +1412,10 @@ class MklFusedConvOp
 // We create new class for each version of Quantized Convolution and inherit
 // from the FP32 version of the base class
-template <typename Device, typename Tbias, typename Toutput,
+template <typename Device, typename Tinput, typename Tbias, typename Toutput,
           typename Ttemp_output, bool bias_enabled, bool is_depthwise>
 class MklQuantizedConv2DOp
-    : public MklConvOp<Device, quint8, qint8, Tbias, Toutput, Ttemp_output,
+    : public MklConvOp<Device, Tinput, qint8, Tbias, Toutput, Ttemp_output,
                        int32, bias_enabled, false, is_depthwise, false> {
  public:
   virtual ~MklQuantizedConv2DOp() {
@ -1436,7 +1431,7 @@ class MklQuantizedConv2DOp
} }
explicit MklQuantizedConv2DOp(OpKernelConstruction* context) explicit MklQuantizedConv2DOp(OpKernelConstruction* context)
: MklConvOp<Device, quint8, qint8, Tbias, Toutput, Ttemp_output, int32, : MklConvOp<Device, Tinput, qint8, Tbias, Toutput, Ttemp_output, int32,
bias_enabled, false, is_depthwise, false>(context) { bias_enabled, false, is_depthwise, false>(context) {
bool is_filter_const; bool is_filter_const;
OP_REQUIRES_OK(context, OP_REQUIRES_OK(context,
@@ -1453,7 +1448,7 @@ class MklQuantizedConv2DOp
   void Compute(OpKernelContext* context) override {
     // Compute int32 output tensor
-    MklConvOp<Device, quint8, qint8, Tbias, Toutput, Ttemp_output, int32,
+    MklConvOp<Device, Tinput, qint8, Tbias, Toutput, Ttemp_output, int32,
               bias_enabled, false, is_depthwise, false>::Compute(context);
 
     // Compute additional outputs: min/max scalars.
@@ -1488,7 +1483,7 @@ class MklQuantizedConv2DOp
     if (min_filter.dims() == 0) {
       float min_output_value;
       float max_output_value;
-      MklQuantizationRangeForMultiplication<quint8, qint8, qint32>(
+      MklQuantizationRangeForMultiplication<Tinput, qint8, qint32>(
           min_input, max_input, min_filter.flat<float>()(0),
           max_filter.flat<float>()(0), &min_output_value, &max_output_value);
       AllocateOutputSetMklShape(context, 1, &output_min, {},
@@ -1505,7 +1500,7 @@ class MklQuantizedConv2DOp
       AllocateOutputSetMklShape(context, 2, &output_max,
                                 {static_cast<ptrdiff_t>(depth)},
                                 output_max_mkl_shape);
-      MklQuantizationRangeForMultiplication<quint8, qint8, qint32>(
+      MklQuantizationRangeForMultiplication<Tinput, qint8, qint32>(
           min_input, max_input, min_filter, max_filter, &output_min,
           &output_max);
     }
@@ -1515,10 +1510,9 @@ class MklQuantizedConv2DOp
  protected:
   void ExtendConvFwdParams(OpKernelContext* context,
                            MklConvFwdParams& params) override {
-    MklConvOp<Device, quint8, qint8, Tbias, Toutput, Ttemp_output, int32,
+    MklConvOp<Device, Tinput, qint8, Tbias, Toutput, Ttemp_output, int32,
               bias_enabled, false, is_depthwise,
               false>::ExtendConvFwdParams(context, params);
-
     // When the output type is quint8, the output data id requantized
     // into quint8. A post_op "output_scale" is added to do the conversion.
     if (std::is_same<Toutput, quint8>::value ||
@@ -1540,22 +1534,26 @@ class MklQuantizedConv2DOp
       const float max_freezed_output =
           context->input(7 + bias_index_offset).flat<float>()(0);
 
-      float factor = std::is_same<Toutput, quint8>::value ? 255.0f : 127.0f;
+      float int_output_limit =
+          std::is_same<Toutput, quint8>::value ? 255.0f : 127.0f;
       size_t depth = min_filter_vector.NumElements();
       const float* min_filter = min_filter_vector.flat<float>().data();
       const float* max_filter = max_filter_vector.flat<float>().data();
       std::vector<float> scales(depth);
-      float input_range = std::max(std::abs(min_input), std::abs(max_input));
-      float output_range =
+      float float_input_range =
+          std::max(std::abs(min_input), std::abs(max_input));
+      float float_output_range =
           std::max(std::abs(min_freezed_output), std::abs(max_freezed_output));
+      const float int_const_scale_limit =
+          (std::is_same<Tinput, quint8>::value) ? 255.0 * 127.0 : 127.0 * 127.0;
       for (size_t i = 0; i < depth; ++i) {
         // For simplicity and symmetry, we set filter range to be outer
         // bounds of min_filter and max_filter.
-        float filter_range =
+        float float_filter_range =
             std::max(std::abs(min_filter[i]), std::abs(max_filter[i]));
         // To understand the scaling, please see mkl_requantize_ops_test.
-        scales[i] = factor * input_range * filter_range /
-                    (255.0f * 127.0f * output_range);
+        scales[i] = int_output_limit * float_input_range * float_filter_range /
+                    (int_const_scale_limit * float_output_range);
       }
       params.post_op_params.push_back(
          {"output_scale", ALGORITHM_UNDEF, scales});
@@ -1584,6 +1582,8 @@ class MklQuantizedConv2DOp
           const_cast<Tbias*>(bias_tensor.flat<Tbias>().data()));
     }
 
+    const float int_const_scale_limit =
+        (std::is_same<Tinput, quint8>::value) ? 255.0 * 127.0 : 127.0 * 127.0;
     // Re-scale bias if either of following 2 conditions are met:
     // 1. Bias is not const;
     // 2. Bias is const, but bias cache is empty (first iteration).
@@ -1597,7 +1597,7 @@ class MklQuantizedConv2DOp
       std::vector<float> scales(depth);
       for (size_t i = 0; i < depth; ++i) {
         scales[i] =
-            255.0 * 127.0 /
+            int_const_scale_limit /
             (std::max(std::abs(max_input), std::abs(min_input)) *
              std::max(std::abs(max_filter[i]), std::abs(min_filter[i])));
       }
@@ -1700,32 +1700,33 @@ class MklQuantizedConv2DOp
   }
 };
 
-template <typename Device, typename Tbias, typename Toutput,
+template <typename Device, typename Tinput, typename Tbias, typename Toutput,
           typename Ttemp_output, bool bias_enabled, bool is_depthwise>
 class MklQuantizedConv2DReluOp
-    : public MklQuantizedConv2DOp<Device, Tbias, Toutput, Ttemp_output,
+    : public MklQuantizedConv2DOp<Device, Tinput, Tbias, Toutput, Ttemp_output,
                                   bias_enabled, is_depthwise> {
  public:
   virtual ~MklQuantizedConv2DReluOp() {}
 
   explicit MklQuantizedConv2DReluOp(OpKernelConstruction* context)
-      : MklQuantizedConv2DOp<Device, Tbias, Toutput, Ttemp_output, bias_enabled,
-                             is_depthwise>(context) {}
+      : MklQuantizedConv2DOp<Device, Tinput, Tbias, Toutput, Ttemp_output,
+                             bias_enabled, is_depthwise>(context) {}
 
  protected:
   void ExtendConvFwdParams(OpKernelContext* context,
                            MklConvFwdParams& params) override {
-    MklQuantizedConv2DOp<Device, Tbias, Toutput, Ttemp_output, bias_enabled,
+    MklQuantizedConv2DOp<Device, Tinput, Tbias, Toutput, Ttemp_output,
+                         bias_enabled,
                          is_depthwise>::ExtendConvFwdParams(context, params);
     params.post_op_params.push_back(
         {"activation", ALGORITHM::eltwise_relu, {1.0, 0.0, 0.0}});
   }
 };
 
-template <typename Device, typename Tbias, typename Toutput,
+template <typename Device, typename Tinput, typename Tbias, typename Toutput,
           typename Ttemp_output, bool bias_enabled, bool is_depthwise>
 class MklQuantizedConv2DSumReluOp
-    : public MklQuantizedConv2DOp<Device, Tbias, Toutput, Ttemp_output,
+    : public MklQuantizedConv2DOp<Device, Tinput, Tbias, Toutput, Ttemp_output,
                                   bias_enabled, is_depthwise> {
  public:
   virtual ~MklQuantizedConv2DSumReluOp() {
@ -1741,13 +1742,14 @@ class MklQuantizedConv2DSumReluOp
} }
explicit MklQuantizedConv2DSumReluOp(OpKernelConstruction* context) explicit MklQuantizedConv2DSumReluOp(OpKernelConstruction* context)
: MklQuantizedConv2DOp<Device, Tbias, Toutput, Ttemp_output, bias_enabled, : MklQuantizedConv2DOp<Device, Tinput, Tbias, Toutput, Ttemp_output,
is_depthwise>(context) {} bias_enabled, is_depthwise>(context) {}
protected: protected:
void ExtendConvFwdParams(OpKernelContext* context, void ExtendConvFwdParams(OpKernelContext* context,
MklConvFwdParams& params) override { MklConvFwdParams& params) override {
MklQuantizedConv2DOp<Device, Tbias, Toutput, Ttemp_output, bias_enabled, MklQuantizedConv2DOp<Device, Tinput, Tbias, Toutput, Ttemp_output,
bias_enabled,
is_depthwise>::ExtendConvFwdParams(context, params); is_depthwise>::ExtendConvFwdParams(context, params);
// Calculate the scale (beta in mkldnn api term) for sum // Calculate the scale (beta in mkldnn api term) for sum
if (std::is_same<Toutput, quint8>::value) { if (std::is_same<Toutput, quint8>::value) {
@@ -1821,7 +1823,7 @@ class MklQuantizedConv2DSumReluOp
       *output_tensor = const_cast<Tensor*>(&summand);
       return;
     }
-    MklConvOp<Device, quint8, qint8, Tbias, Toutput, Ttemp_output, int32,
+    MklConvOp<Device, Tinput, qint8, Tbias, Toutput, Ttemp_output, int32,
               bias_enabled, false, false,
               false>::AllocateOutputTensor(context, conv_prim_desc,
                                            output_dims_mkl_order,
@@ -1844,6 +1846,8 @@ class MklQuantizedConv2DSumReluOp
     const float* min_filter = min_filter_vector.flat<float>().data();
     const float* max_filter = max_filter_vector.flat<float>().data();
 
+    const float int_const_scale_limit =
+        (std::is_same<Tinput, quint8>::value) ? 255.0 * 127.0 : 127.0 * 127.0;
     size_t depth = min_filter_vector.NumElements();
     std::vector<float> scales(depth);
     for (size_t i = 0; i < depth; ++i) {
@@ -1851,7 +1855,7 @@ class MklQuantizedConv2DSumReluOp
       // done regularly. A Cleaner design to address all mapping in one
       // function needs to be implemented in future which also supports other
       // quantized type mapping in future.
-      scales[i] = 255.0 * 127.0 /
+      scales[i] = int_const_scale_limit /
                   (std::max(std::abs(max_input), std::abs(min_input)) *
                    std::max(std::abs(max_filter[i]), std::abs(min_filter[i])));
     }
@@ -1911,33 +1915,41 @@ REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DPerChannel")
                             .TypeConstraint<qint32>("out_type"),
                         NoOp);
 
 // Register a templatized implementation of MklQuantizedConv2DPerChannel.
-REGISTER_KERNEL_BUILDER(
-    Name("_MklQuantizedConv2DPerChannel")
-        .Device(DEVICE_CPU)
-        .TypeConstraint<quint8>("Tinput")
-        .TypeConstraint<qint8>("Tfilter")
-        .TypeConstraint<qint32>("out_type")
-        .Label(mkl_op_registry::kMklQuantizedOpLabel),
-    MklQuantizedConv2DOp<CPUDevice, float, qint32, qint32, false, false>);
+REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConv2DPerChannel")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<quint8>("Tinput")
+                            .TypeConstraint<qint8>("Tfilter")
+                            .TypeConstraint<qint32>("out_type")
+                            .Label(mkl_op_registry::kMklQuantizedOpLabel),
+                        MklQuantizedConv2DOp<CPUDevice, quint8, float, qint32,
+                                             qint32, false, false>);
 
 // Register a templatized implementation of MklQuantizedConv2D.
-REGISTER_KERNEL_BUILDER(
-    Name("_MklQuantizedConv2D")
-        .Device(DEVICE_CPU)
-        .TypeConstraint<quint8>("Tinput")
-        .TypeConstraint<qint8>("Tfilter")
-        .TypeConstraint<qint32>("out_type")
-        .Label(mkl_op_registry::kMklQuantizedOpLabel),
-    MklQuantizedConv2DOp<CPUDevice, float, qint32, qint32, false, false>);
-
-REGISTER_KERNEL_BUILDER(
-    Name("_MklQuantizedConv2DAndRequantize")
-        .Device(DEVICE_CPU)
-        .TypeConstraint<quint8>("Tinput")
-        .TypeConstraint<qint8>("Tfilter")
-        .TypeConstraint<qint8>("out_type")
-        .Label(mkl_op_registry::kMklQuantizedOpLabel),
-    MklQuantizedConv2DOp<CPUDevice, qint32, qint8, qint8, false, false>);
+REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConv2D")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<quint8>("Tinput")
+                            .TypeConstraint<qint8>("Tfilter")
+                            .TypeConstraint<qint32>("out_type")
+                            .Label(mkl_op_registry::kMklQuantizedOpLabel),
+                        MklQuantizedConv2DOp<CPUDevice, quint8, float, qint32,
+                                             qint32, false, false>);
+
+REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConv2D")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<qint8>("Tinput")
+                            .TypeConstraint<qint8>("Tfilter")
+                            .TypeConstraint<qint32>("out_type")
+                            .Label(mkl_op_registry::kMklQuantizedOpLabel),
+                        MklQuantizedConv2DOp<CPUDevice, qint8, float, qint32,
+                                             qint32, false, false>);
+
+REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConv2DAndRequantize")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<quint8>("Tinput")
+                            .TypeConstraint<qint8>("Tfilter")
+                            .TypeConstraint<qint8>("out_type")
+                            .Label(mkl_op_registry::kMklQuantizedOpLabel),
+                        MklQuantizedConv2DOp<CPUDevice, quint8, qint32, qint8,
+                                             qint8, false, false>);
 
 // Register NoOp kernel for QuantizedConv2DWithBias to get a python interface.
 // This kernel will be replaced by an MKL kernel during graph
@@ -1956,15 +1968,28 @@ REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DWithBiasAndRequantize")
                             .TypeConstraint<qint8>("out_type"),
                         NoOp);
 
+REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DWithBias")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<qint8>("Tinput")
+                            .TypeConstraint<qint8>("Tfilter")
+                            .TypeConstraint<qint32>("out_type"),
+                        NoOp);
+
+REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DWithBiasAndRequantize")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<qint8>("Tinput")
+                            .TypeConstraint<qint8>("Tfilter")
+                            .TypeConstraint<qint8>("out_type"),
+                        NoOp);
+
 // Register a templatized implementation MklQuantizedConv2DWithBias.
-REGISTER_KERNEL_BUILDER(
-    Name("_MklQuantizedConv2DWithBias")
-        .Device(DEVICE_CPU)
-        .TypeConstraint<quint8>("Tinput")
-        .TypeConstraint<qint8>("Tfilter")
-        .TypeConstraint<qint32>("out_type")
-        .Label(mkl_op_registry::kMklQuantizedOpLabel),
-    MklQuantizedConv2DOp<CPUDevice, float, qint32, qint32, true, false>);
+REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConv2DWithBias")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<quint8>("Tinput")
+                            .TypeConstraint<qint8>("Tfilter")
+                            .TypeConstraint<qint32>("out_type")
+                            .Label(mkl_op_registry::kMklQuantizedOpLabel),
+                        MklQuantizedConv2DOp<CPUDevice, quint8, float, qint32,
+                                             qint32, true, false>);
 
 REGISTER_KERNEL_BUILDER(
     Name("_MklQuantizedConv2DWithBiasAndRequantize")
@@ -1974,7 +1999,7 @@ REGISTER_KERNEL_BUILDER(
         .TypeConstraint<qint32>("Tbias")
         .TypeConstraint<qint8>("out_type")
         .Label(mkl_op_registry::kMklQuantizedOpLabel),
-    MklQuantizedConv2DOp<CPUDevice, qint32, qint8, qint8, true, false>);
+    MklQuantizedConv2DOp<CPUDevice, quint8, qint32, qint8, qint8, true, false>);
 
 REGISTER_KERNEL_BUILDER(
     Name("_MklQuantizedConv2DWithBiasAndRequantize")
@@ -1984,7 +2009,36 @@ REGISTER_KERNEL_BUILDER(
         .TypeConstraint<float>("Tbias")
         .TypeConstraint<qint8>("out_type")
         .Label(mkl_op_registry::kMklQuantizedOpLabel),
-    MklQuantizedConv2DOp<CPUDevice, float, qint8, qint8, true, false>);
+    MklQuantizedConv2DOp<CPUDevice, quint8, float, qint8, qint8, true, false>);
+
+REGISTER_KERNEL_BUILDER(
+    Name("_MklQuantizedConv2DWithBias")
+        .Device(DEVICE_CPU)
+        .TypeConstraint<qint8>("Tinput")
+        .TypeConstraint<qint8>("Tfilter")
+        .TypeConstraint<qint32>("out_type")
+        .Label(mkl_op_registry::kMklQuantizedOpLabel),
+    MklQuantizedConv2DOp<CPUDevice, qint8, float, qint32, qint32, true, false>);
+
+REGISTER_KERNEL_BUILDER(
+    Name("_MklQuantizedConv2DWithBiasAndRequantize")
+        .Device(DEVICE_CPU)
+        .TypeConstraint<qint8>("Tinput")
+        .TypeConstraint<qint8>("Tfilter")
+        .TypeConstraint<qint32>("Tbias")
+        .TypeConstraint<qint8>("out_type")
+        .Label(mkl_op_registry::kMklQuantizedOpLabel),
+    MklQuantizedConv2DOp<CPUDevice, qint8, qint32, qint8, qint8, true, false>);
+
+REGISTER_KERNEL_BUILDER(
+    Name("_MklQuantizedConv2DWithBiasAndRequantize")
+        .Device(DEVICE_CPU)
+        .TypeConstraint<qint8>("Tinput")
+        .TypeConstraint<qint8>("Tfilter")
+        .TypeConstraint<float>("Tbias")
+        .TypeConstraint<qint8>("out_type")
+        .Label(mkl_op_registry::kMklQuantizedOpLabel),
+    MklQuantizedConv2DOp<CPUDevice, qint8, float, qint8, qint8, true, false>);
 
 // Register NoOp kernel for QuantizedConv2DAndRelu to get a python interface.
 // This kernel will be replaced by an MKL kernel during graph-optimization pass.
@@ -2003,23 +2057,23 @@ REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DAndReluAndRequantize")
                         NoOp);
 
 // Register a templatized implementation of MklQuantizedConv2DAndRelu.
-REGISTER_KERNEL_BUILDER(
-    Name("_MklQuantizedConv2DAndRelu")
-        .Device(DEVICE_CPU)
-        .TypeConstraint<quint8>("Tinput")
-        .TypeConstraint<qint8>("Tfilter")
-        .TypeConstraint<qint32>("out_type")
-        .Label(mkl_op_registry::kMklQuantizedOpLabel),
-    MklQuantizedConv2DReluOp<CPUDevice, float, qint32, qint32, false, false>);
+REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConv2DAndRelu")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<quint8>("Tinput")
+                            .TypeConstraint<qint8>("Tfilter")
+                            .TypeConstraint<qint32>("out_type")
+                            .Label(mkl_op_registry::kMklQuantizedOpLabel),
+                        MklQuantizedConv2DReluOp<CPUDevice, quint8, float,
+                                                 qint32, qint32, false, false>);
 
-REGISTER_KERNEL_BUILDER(
-    Name("_MklQuantizedConv2DAndReluAndRequantize")
-        .Device(DEVICE_CPU)
-        .TypeConstraint<quint8>("Tinput")
-        .TypeConstraint<qint8>("Tfilter")
-        .TypeConstraint<quint8>("out_type")
-        .Label(mkl_op_registry::kMklQuantizedOpLabel),
-    MklQuantizedConv2DReluOp<CPUDevice, qint32, quint8, quint8, false, false>);
+REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConv2DAndReluAndRequantize")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<quint8>("Tinput")
+                            .TypeConstraint<qint8>("Tfilter")
+                            .TypeConstraint<quint8>("out_type")
+                            .Label(mkl_op_registry::kMklQuantizedOpLabel),
+                        MklQuantizedConv2DReluOp<CPUDevice, quint8, qint32,
+                                                 quint8, quint8, false, false>);
 
 // Register NoOp kernel for QuantizedConv2DWithBiasAndRelu to get a python
 // interface.
@@ -2031,6 +2085,13 @@ REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DWithBiasAndRelu")
                             .TypeConstraint<qint32>("out_type"),
                         NoOp);
 
+REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DWithBiasAndRelu")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<qint8>("Tinput")
+                            .TypeConstraint<qint8>("Tfilter")
+                            .TypeConstraint<qint32>("out_type"),
+                        NoOp);
+
 // Register NoOp kernel for QuantizedConv2DWithBiasAndReluAndRequantize
 // to get a python interface.
 // This kernel will be replaced by an MKL kernel during graph-optimization pass.
@@ -2041,37 +2102,71 @@ REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DWithBiasAndReluAndRequantize")
                             .TypeConstraint<quint8>("out_type"),
                         NoOp);
 
+REGISTER_KERNEL_BUILDER(Name("QuantizedConv2DWithBiasAndReluAndRequantize")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<qint8>("Tinput")
+                            .TypeConstraint<qint8>("Tfilter")
+                            .TypeConstraint<quint8>("out_type"),
+                        NoOp);
+
 // Register a templatized implementation of MklQuantizedConv2DWithBiasAndRelu.
-REGISTER_KERNEL_BUILDER(
-    Name("_MklQuantizedConv2DWithBiasAndRelu")
-        .Device(DEVICE_CPU)
-        .TypeConstraint<quint8>("Tinput")
-        .TypeConstraint<qint8>("Tfilter")
-        .TypeConstraint<qint32>("out_type")
-        .Label(mkl_op_registry::kMklQuantizedOpLabel),
-    MklQuantizedConv2DReluOp<CPUDevice, float, qint32, qint32, true, false>);
+REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConv2DWithBiasAndRelu")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<quint8>("Tinput")
+                            .TypeConstraint<qint8>("Tfilter")
+                            .TypeConstraint<qint32>("out_type")
+                            .Label(mkl_op_registry::kMklQuantizedOpLabel),
+                        MklQuantizedConv2DReluOp<CPUDevice, quint8, float,
+                                                 qint32, qint32, true, false>);
+
+REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConv2DWithBiasAndRelu")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<qint8>("Tinput")
+                            .TypeConstraint<qint8>("Tfilter")
+                            .TypeConstraint<qint32>("out_type")
+                            .Label(mkl_op_registry::kMklQuantizedOpLabel),
+                        MklQuantizedConv2DReluOp<CPUDevice, qint8, float,
+                                                 qint32, qint32, true, false>);
 
 // Register a templatized implementation of
 // MklQuantizedConv2DWithBiasAndReluAndRequantize.
-REGISTER_KERNEL_BUILDER(
-    Name("_MklQuantizedConv2DWithBiasAndReluAndRequantize")
-        .Device(DEVICE_CPU)
-        .TypeConstraint<quint8>("Tinput")
-        .TypeConstraint<qint8>("Tfilter")
-        .TypeConstraint<float>("Tbias")
-        .TypeConstraint<quint8>("out_type")
-        .Label(mkl_op_registry::kMklQuantizedOpLabel),
-    MklQuantizedConv2DReluOp<CPUDevice, float, quint8, quint8, true, false>);
+REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConv2DWithBiasAndReluAndRequantize")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<quint8>("Tinput")
+                            .TypeConstraint<qint8>("Tfilter")
+                            .TypeConstraint<float>("Tbias")
+                            .TypeConstraint<quint8>("out_type")
+                            .Label(mkl_op_registry::kMklQuantizedOpLabel),
+                        MklQuantizedConv2DReluOp<CPUDevice, quint8, float,
+                                                 quint8, quint8, true, false>);
 
-REGISTER_KERNEL_BUILDER(
-    Name("_MklQuantizedConv2DWithBiasAndReluAndRequantize")
-        .Device(DEVICE_CPU)
-        .TypeConstraint<quint8>("Tinput")
-        .TypeConstraint<qint8>("Tfilter")
-        .TypeConstraint<qint32>("Tbias")
-        .TypeConstraint<quint8>("out_type")
-        .Label(mkl_op_registry::kMklQuantizedOpLabel),
-    MklQuantizedConv2DReluOp<CPUDevice, qint32, quint8, quint8, true, false>);
+REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConv2DWithBiasAndReluAndRequantize")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<quint8>("Tinput")
+                            .TypeConstraint<qint8>("Tfilter")
+                            .TypeConstraint<qint32>("Tbias")
+                            .TypeConstraint<quint8>("out_type")
+                            .Label(mkl_op_registry::kMklQuantizedOpLabel),
+                        MklQuantizedConv2DReluOp<CPUDevice, quint8, qint32,
+                                                 quint8, quint8, true, false>);
+
+REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConv2DWithBiasAndReluAndRequantize")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<qint8>("Tinput")
+                            .TypeConstraint<qint8>("Tfilter")
+                            .TypeConstraint<float>("Tbias")
+                            .TypeConstraint<quint8>("out_type")
+                            .Label(mkl_op_registry::kMklQuantizedOpLabel),
+                        MklQuantizedConv2DReluOp<CPUDevice, qint8, float,
+                                                 quint8, quint8, true, false>);
+
+REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConv2DWithBiasAndReluAndRequantize")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<qint8>("Tinput")
+                            .TypeConstraint<qint8>("Tfilter")
+                            .TypeConstraint<qint32>("Tbias")
+                            .TypeConstraint<quint8>("out_type")
+                            .Label(mkl_op_registry::kMklQuantizedOpLabel),
+                        MklQuantizedConv2DReluOp<CPUDevice, qint8, qint32,
+                                                 quint8, quint8, true, false>);
 
 // Register NoOp kernel for QuantizedConv2DWithBiasSumAndRelu to get a python
 // interface.
@@ -2107,7 +2202,8 @@ REGISTER_KERNEL_BUILDER(
         .TypeConstraint<qint8>("Tfilter")
        .TypeConstraint<qint32>("out_type")
        .Label(mkl_op_registry::kMklQuantizedOpLabel),
-    MklQuantizedConv2DSumReluOp<CPUDevice, float, qint32, qint32, true, false>);
+    MklQuantizedConv2DSumReluOp<CPUDevice, quint8, float, qint32, qint32, true,
+                                false>);
 
 REGISTER_KERNEL_BUILDER(
     Name("_MklQuantizedConv2DWithBiasSumAndReluAndRequantize")
@@ -2117,7 +2213,7 @@ REGISTER_KERNEL_BUILDER(
         .TypeConstraint<qint32>("Tbias")
         .TypeConstraint<quint8>("out_type")
         .Label(mkl_op_registry::kMklQuantizedOpLabel),
-    MklQuantizedConv2DSumReluOp<CPUDevice, qint32, quint8, quint8, true,
+    MklQuantizedConv2DSumReluOp<CPUDevice, quint8, qint32, quint8, quint8, true,
                                 false>);
 
 REGISTER_KERNEL_BUILDER(
@@ -2128,7 +2224,8 @@ REGISTER_KERNEL_BUILDER(
         .TypeConstraint<qint32>("Tbias")
         .TypeConstraint<quint8>("out_type")
         .Label(mkl_op_registry::kMklQuantizedOpLabel),
-    MklQuantizedConv2DSumReluOp<CPUDevice, qint32, quint8, qint8, true, false>);
+    MklQuantizedConv2DSumReluOp<CPUDevice, quint8, qint32, quint8, qint8, true,
+                                false>);
 
 REGISTER_KERNEL_BUILDER(
     Name("_MklQuantizedConv2DWithBiasSumAndReluAndRequantize")
@@ -2138,7 +2235,8 @@ REGISTER_KERNEL_BUILDER(
         .TypeConstraint<float>("Tbias")
         .TypeConstraint<quint8>("out_type")
         .Label(mkl_op_registry::kMklQuantizedOpLabel),
-    MklQuantizedConv2DSumReluOp<CPUDevice, float, quint8, quint8, true, false>);
+    MklQuantizedConv2DSumReluOp<CPUDevice, quint8, float, quint8, quint8, true,
+                                false>);
 
 REGISTER_KERNEL_BUILDER(
     Name("_MklQuantizedConv2DWithBiasSignedSumAndReluAndRequantize")
@@ -2148,7 +2246,8 @@ REGISTER_KERNEL_BUILDER(
         .TypeConstraint<float>("Tbias")
         .TypeConstraint<quint8>("out_type")
         .Label(mkl_op_registry::kMklQuantizedOpLabel),
-    MklQuantizedConv2DSumReluOp<CPUDevice, float, quint8, qint8, true, false>);
+    MklQuantizedConv2DSumReluOp<CPUDevice, quint8, float, quint8, qint8, true,
+                                false>);
 
 // Register NoOp kernels for non-fused and fused versions of
 // QuantizedDepthwiseConv2D to get a Python interface. These kernels will be
@@ -2184,14 +2283,14 @@ REGISTER_KERNEL_BUILDER(
 // Register templatized MKL kernels for non-fused and fused-versions of
 // QuantizedDepthwiseConv2D.
-REGISTER_KERNEL_BUILDER(
-    Name("_MklQuantizedDepthwiseConv2D")
-        .Device(DEVICE_CPU)
-        .TypeConstraint<quint8>("Tinput")
-        .TypeConstraint<qint8>("Tfilter")
-        .TypeConstraint<qint32>("out_type")
-        .Label(mkl_op_registry::kMklQuantizedOpLabel),
-    MklQuantizedConv2DOp<CPUDevice, float, qint32, qint32, false, true>);
+REGISTER_KERNEL_BUILDER(Name("_MklQuantizedDepthwiseConv2D")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<quint8>("Tinput")
+                            .TypeConstraint<qint8>("Tfilter")
+                            .TypeConstraint<qint32>("out_type")
+                            .Label(mkl_op_registry::kMklQuantizedOpLabel),
+                        MklQuantizedConv2DOp<CPUDevice, quint8, float, qint32,
+                                             qint32, false, true>);
 
 REGISTER_KERNEL_BUILDER(
     Name("_MklQuantizedDepthwiseConv2DWithBias")
@@ -2200,16 +2299,16 @@ REGISTER_KERNEL_BUILDER(
         .TypeConstraint<qint8>("Tfilter")
         .TypeConstraint<qint32>("out_type")
         .Label(mkl_op_registry::kMklQuantizedOpLabel),
-    MklQuantizedConv2DOp<CPUDevice, float, qint32, qint32, true, true>);
+    MklQuantizedConv2DOp<CPUDevice, quint8, float, qint32, qint32, true, true>);
 
-REGISTER_KERNEL_BUILDER(
-    Name("_MklQuantizedDepthwiseConv2DWithBiasAndRelu")
-        .Device(DEVICE_CPU)
-        .TypeConstraint<quint8>("Tinput")
-        .TypeConstraint<qint8>("Tfilter")
-        .TypeConstraint<qint32>("out_type")
-        .Label(mkl_op_registry::kMklQuantizedOpLabel),
-    MklQuantizedConv2DReluOp<CPUDevice, float, qint32, qint32, true, true>);
+REGISTER_KERNEL_BUILDER(Name("_MklQuantizedDepthwiseConv2DWithBiasAndRelu")
+                            .Device(DEVICE_CPU)
+                            .TypeConstraint<quint8>("Tinput")
+                            .TypeConstraint<qint8>("Tfilter")
+                            .TypeConstraint<qint32>("out_type")
+                            .Label(mkl_op_registry::kMklQuantizedOpLabel),
+                        MklQuantizedConv2DReluOp<CPUDevice, quint8, float,
+                                                 qint32, qint32, true, true>);
 
 // Tbias -> float
 REGISTER_KERNEL_BUILDER(
@@ -2220,7 +2319,8 @@ REGISTER_KERNEL_BUILDER(
         .TypeConstraint<float>("Tbias")
         .TypeConstraint<quint8>("out_type")
         .Label(mkl_op_registry::kMklQuantizedOpLabel),
-    MklQuantizedConv2DReluOp<CPUDevice, float, quint8, quint8, true, true>);
+    MklQuantizedConv2DReluOp<CPUDevice, quint8, float, quint8, quint8, true,
+                             true>);
 
 // Tbias -> qint32
 REGISTER_KERNEL_BUILDER(
@@ -2231,7 +2331,8 @@ REGISTER_KERNEL_BUILDER(
         .TypeConstraint<qint32>("Tbias")
         .TypeConstraint<quint8>("out_type")
         .Label(mkl_op_registry::kMklQuantizedOpLabel),
-    MklQuantizedConv2DReluOp<CPUDevice, qint32, quint8, quint8, true, true>);
+    MklQuantizedConv2DReluOp<CPUDevice, quint8, qint32, quint8, quint8, true,
+                             true>);
 
 // Register 2D operations
 #define REGISTER_MKL_CPU_2D(T) \


@@ -285,6 +285,111 @@ TEST_F(QuantizedConv2DTest, Small) {
   test::ExpectTensorNear<float>(expected_float, output_float, 1.0);
 }
 
+TEST_F(QuantizedConv2DTest, SmallS8) {
+  const int stride = 1;
+  const int depth = 1;
+  const int image_width = 3;
+  const int image_height = 3;
+  const int image_batch_count = 1;
+
+  // Image -> int8
+  const float image_min = -127.0f;
+  const float image_max = 127.0f;
+
+  TF_ASSERT_OK(NodeDefBuilder("quantized_conv_op", "_MklQuantizedConv2D")
+                   .Input(FakeInput(DT_QINT8))  // Input
+                   .Input(FakeInput(DT_QINT8))  // Filter
+                   .Input(FakeInput(DT_FLOAT))  // Min input
+                   .Input(FakeInput(DT_FLOAT))  // Max input
+                   .Input(FakeInput(DT_FLOAT))  // Min filter
+                   .Input(FakeInput(DT_FLOAT))  // Max filter
+                   // MKL metadata tensors //
+                   .Input(FakeInput(DT_UINT8))
+                   .Input(FakeInput(DT_UINT8))
+                   .Input(FakeInput(DT_UINT8))
+                   .Input(FakeInput(DT_UINT8))
+                   .Input(FakeInput(DT_UINT8))
+                   .Input(FakeInput(DT_UINT8))
+                   ///////////////////////////
+                   .Attr("Tinput", DataTypeToEnum<qint8>::v())
+                   .Attr("Tfilter", DataTypeToEnum<qint8>::v())
+                   .Attr("T", DataTypeToEnum<quint8>::v())
+                   .Attr("padding", "VALID")
+                   .Attr("out_type", DataTypeToEnum<qint32>::v())
+                   .Attr("strides", {1, stride, stride, 1})
+                   .Attr("_kernel", "QuantizedMklOp")
+                   .Finalize(node_def()));
+  TF_ASSERT_OK(InitOp());
+
+  // The image matrix is:
+  // |  2 |  3 |  4 |
+  // |  6 | -4 | -2 |
+  // |  3 |  0 |  4 |
+  Tensor image_float(DT_FLOAT,
+                     {image_batch_count, image_height, image_width, depth});
+  test::FillValues<float>(&image_float, {2, 3, 4, 6, -4, -2, 3, 0, 4});
+  Tensor image_quantized =
+      FloatTensorToQuantized<qint8>(image_float, image_min, image_max);
+
+  const int filter_size = 3;
+  const int filter_count = 1;
+
+  // Filter -> int8 with symmetric range
+  const float filter_min = -127.0f;
+  const float filter_max = 127.0f;
+
+  // The filter matrix is:
+  // | 1 | 4 | 2 |
+  // | 0 | 5 |-1 |
+  // | 3 |-1 |-3 |
+  Tensor filter_float(DT_FLOAT,
+                      {filter_size, filter_size, depth, filter_count});
+  test::FillValues<float>(&filter_float, {1, 4, 2, 0, 5, -1, 3, -1, -3});
+  Tensor filter_quantized =
+      FloatTensorToQuantized<qint8>(filter_float, filter_min, filter_max);
+
+  AddInputFromArray<qint8>(image_quantized.shape(),
+                           image_quantized.flat<qint8>());
+  AddInputFromArray<qint8>(filter_quantized.shape(),
+                           filter_quantized.flat<qint8>());
+  AddInputFromArray<float>(TensorShape({1}), {image_min});
+  AddInputFromArray<float>(TensorShape({1}), {image_max});
+  AddInputFromArray<float>(TensorShape({1}), {filter_min});
+  AddInputFromArray<float>(TensorShape({1}), {filter_max});
+  AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+  AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+  AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+  AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+  AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+  AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+
+  TF_ASSERT_OK(RunOpKernel());
+
+  // Output -> float
+  const int expected_width = 1;
+  const int expected_height = 1;
+  Tensor expected_float(
+      DT_FLOAT, TensorShape({image_batch_count, expected_height,
+                             expected_width, filter_count}));
+  test::FillValues<float>(&expected_float, {1});
+
+  const Tensor& output = *GetOutput(0);
+  const Tensor& output_mkl_metadata = *GetOutput(3);
+
+  ConvMklToTF conv_comp;
+  Tensor output_quantized;
+  conv_comp.ConvertMklToTF<qint32>(DT_QINT32, output, output_mkl_metadata,
+                                   output_quantized);
+
+  const float output_min = GetOutput(1)->flat<float>()(0);
+  const float output_max = GetOutput(2)->flat<float>()(0);
+  Tensor output_float =
+      QuantizedTensorToFloat<qint32>(output_quantized, output_min, output_max);
+
+  test::ExpectTensorNear<float>(expected_float, output_float, 1.0);
+}
+
 // Output -> qint32
 TEST_F(QuantizedConv2DTest, Small32Bit) {
   const int stride = 1;
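As a quick sanity check on the SmallS8 test above: with VALID padding and a 3x3 image, the single output element is just the elementwise dot product of the image and filter values. A standalone sketch (values copied from the test):

#include <cstdio>

// Recomputes the expected value {1} used in the SmallS8 test above.
int main() {
  const int image[9]  = {2, 3, 4, 6, -4, -2, 3, 0, 4};
  const int filter[9] = {1, 4, 2, 0, 5, -1, 3, -1, -3};
  int acc = 0;
  for (int i = 0; i < 9; ++i) acc += image[i] * filter[i];
  std::printf("%d\n", acc);  // prints 1
}

This matches the {1} filled into expected_float before the near-equality check.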