Merge pull request #25203 from Intel-tensorflow:requantization_op_perchannel_support

PiperOrigin-RevId: 233491842
This commit is contained in:
TensorFlower Gardener 2019-02-11 15:38:17 -08:00
commit d3e43986e2
11 changed files with 863 additions and 2 deletions

View File

@ -1473,6 +1473,7 @@ cc_library(
"//tensorflow/core/kernels:mkl_identity_op", "//tensorflow/core/kernels:mkl_identity_op",
"//tensorflow/core/kernels:mkl_input_conversion_op", "//tensorflow/core/kernels:mkl_input_conversion_op",
"//tensorflow/core/kernels:mkl_lrn_op", "//tensorflow/core/kernels:mkl_lrn_op",
"//tensorflow/core/kernels:mkl_requantize_ops",
"//tensorflow/core/kernels:mkl_pooling_ops", "//tensorflow/core/kernels:mkl_pooling_ops",
"//tensorflow/core/kernels:mkl_relu_op", "//tensorflow/core/kernels:mkl_relu_op",
"//tensorflow/core/kernels:mkl_reshape_op", "//tensorflow/core/kernels:mkl_reshape_op",

View File

@ -0,0 +1,48 @@
op {
graph_op_name: "RequantizationRangePerChannel"
visibility : HIDDEN
in_arg {
name: "input"
description: <<END
The original input tensor.
END
}
in_arg {
name: "input_min"
description: <<END
The minimum value of the input tensor
END
}
in_arg {
name: "input_max"
description: <<END
The maximum value of the input tensor.
END
}
out_arg {
name: "output_min"
description: <<END
The minimum value of the final output tensor
END
}
out_arg {
name: "output_max"
description: <<END
The maximum value of the final output tensor.
END
}
attr {
name: "T"
description: <<END
The quantized type of input tensor that needs to be converted.
END
}
attr {
name: "clip_value_max"
description: <<END
The maximum value of the output that needs to be clipped.
Example: set this to 6 for Relu6.
END
}
summary: "Computes requantization range per channel."
}

View File

@ -0,0 +1,65 @@
op {
graph_op_name: "RequantizePerChannel"
visibility : HIDDEN
in_arg {
name: "input"
description: <<END
The original input tensor.
END
}
in_arg {
name: "input_min"
description: <<END
The minimum value of the input tensor
END
}
in_arg {
name: "input_max"
description: <<END
The maximum value of the input tensor.
END
}
in_arg {
name: "requested_output_min"
description: <<END
The minimum value of the output tensor requested.
END
}
in_arg {
name: "requested_output_max"
description: <<END
The maximum value of the output tensor requested.
END
}
out_arg {
name: "output"
description: <<END
Output tensor.
END
}
out_arg {
name: "output_min"
description: <<END
The minimum value of the final output tensor
END
}
out_arg {
name: "output_max"
description: <<END
The maximum value of the final output tensor.
END
}
attr {
name: "T"
description: <<END
The quantized type of input tensor that needs to be converted.
END
}
attr {
name: "out_type"
description: <<END
The quantized type of output tensor that needs to be converted.
END
}
summary: "Requantizes input with min and max values known per channel."
}

View File

@ -6987,6 +6987,68 @@ tf_mkl_kernel_library(
deps = NN_DEPS + mkl_deps() + [":cwise_op"], deps = NN_DEPS + mkl_deps() + [":cwise_op"],
) )
tf_mkl_kernel_library(
name = "mkl_requantize_ops",
srcs = [
"mkl_requantization_range_per_channel_op.cc",
"mkl_requantize_per_channel_op.cc",
],
hdrs = [
"meta_support.h",
"no_op.h",
"reference_gemm.h",
],
deps = if_mkl(
[
":concat_lib_hdrs",
":conv_ops",
":cwise_op",
":eigen_helpers",
":image_resizer_state",
":ops_util",
":pooling_ops",
":quantization_utils",
":transpose_functor",
"//tensorflow/core:array_ops_op_lib",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:math_ops_op_lib",
"//tensorflow/core:nn_ops_op_lib",
"//third_party/eigen3",
"//third_party/mkl:intel_binary_blob",
"@gemmlowp",
"@mkl_dnn",
],
),
)
tf_cc_test_mkl(
name = "mkl_requantize_ops_test",
size = "small",
srcs = ["mkl_requantize_ops_test.cc"],
tags = ["no_mac"], #TODO(penporn): Re-enable the test on MacOS.
deps = [
":mkl_requantize_ops",
":ops_testutil",
":ops_util",
":quantization_utils",
":quantized_ops",
"//tensorflow/cc:cc_ops",
"//tensorflow/core:array_ops_op_lib",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
"//tensorflow/core:math_ops_op_lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
tf_cc_test_mkl( tf_cc_test_mkl(
name = "mkl_fused_ops_test", name = "mkl_fused_ops_test",
size = "small", size = "small",

View File

@ -24,8 +24,13 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
template <class T> template <class T>
float MklFloatForOneQuantizedLevel(float range_min, float range_max) { float MklFloatForOneQuantizedLevel(float range_min, float range_max) {
const int64 highest = static_cast<int64>(Eigen::NumTraits<T>::highest()); int64 highest = static_cast<int64>(Eigen::NumTraits<T>::highest());
const int64 lowest = static_cast<int64>(Eigen::NumTraits<T>::lowest()); int64 lowest = static_cast<int64>(Eigen::NumTraits<T>::lowest());
// Adjusting for having a symmetric range.
// for example: for 8-bit [-127, 127] as opposed to [-128, 127].
if (lowest < -highest) ++lowest;
const float float_for_one_quantized_level = const float float_for_one_quantized_level =
(range_max - range_min) / (highest - lowest); (range_max - range_min) / (highest - lowest);
return float_for_one_quantized_level; return float_for_one_quantized_level;
@ -48,6 +53,35 @@ void MklQuantizationRangeForMultiplication(float min_a, float max_a,
*min_c = c_float_for_one_quant_level * c_lowest; *min_c = c_float_for_one_quant_level * c_lowest;
*max_c = c_float_for_one_quant_level * c_highest; *max_c = c_float_for_one_quant_level * c_highest;
} }
template <class T1, class T2, class T3>
void MklQuantizationRangeForMultiplication(float min_a, float max_a,
const Tensor& min_b_vector,
const Tensor& max_b_vector,
Tensor** min_c_vector,
Tensor** max_c_vector) {
DCHECK(min_b_vector.NumElements() == (*min_c_vector)->NumElements());
DCHECK(max_b_vector.NumElements() == (*max_c_vector)->NumElements());
size_t n_channel = min_b_vector.NumElements();
const int64 c_highest = static_cast<int64>(Eigen::NumTraits<T3>::highest());
const int64 c_lowest = static_cast<int64>(Eigen::NumTraits<T3>::lowest());
const float* min_b = min_b_vector.flat<float>().data();
const float* max_b = max_b_vector.flat<float>().data();
float* min_c = (*min_c_vector)->flat<float>().data();
float* max_c = (*max_c_vector)->flat<float>().data();
#pragma omp parallel for
for (size_t n = 0; n < n_channel; ++n) {
float a_float_for_one_quant_level =
MklFloatForOneQuantizedLevel<T1>(min_a, max_a);
float b_float_for_one_quant_level =
MklFloatForOneQuantizedLevel<T2>(min_b[n], max_b[n]);
float c_float_for_one_quant_level =
a_float_for_one_quant_level * b_float_for_one_quant_level;
min_c[n] = c_float_for_one_quant_level * c_lowest;
max_c[n] = c_float_for_one_quant_level * c_highest;
}
}
} // namespace tensorflow } // namespace tensorflow
#endif // INTEL_MKL #endif // INTEL_MKL

View File

@ -0,0 +1,124 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// See docs in ../ops/array_ops.cc.
#ifdef INTEL_MKL
#define EIGEN_USE_THREADS
#include <math.h>
#include <limits>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/meta_support.h"
#include "tensorflow/core/kernels/no_op.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/mkl_util.h"
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
class MklRequantizationRangePerChannelOp : public OpKernel {
public:
explicit MklRequantizationRangePerChannelOp(OpKernelConstruction* ctx)
: OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("clip_value_max", &clip_value_max_));
}
void Compute(OpKernelContext* ctx) override {
const Tensor& input = ctx->input(kInputTensorIndex);
const Tensor& input_min = ctx->input(kInputMinIndex);
const Tensor& input_max = ctx->input(kInputMaxIndex);
const size_t depth = input_max.NumElements();
OP_REQUIRES(
ctx, input_min.dim_size(0) == depth,
errors::InvalidArgument("input_min has incorrect size, expected ",
depth, " was ", input_min.dim_size(0)));
OP_REQUIRES(
ctx, input_max.dim_size(0) == depth,
errors::InvalidArgument("input_max has incorrect size, expected ",
depth, " was ", input_max.dim_size(0)));
const float* input_min_data = input_min.flat<float>().data();
const float* input_max_data = input_max.flat<float>().data();
std::vector<float> ranges(depth);
bool is_non_negative = true;
Eigen::array<int, 2> shuffling({1, 0});
auto input_matrix = input.flat_inner_dims<qint32>();
// TODO: verify performance of not transposing and finding the min max
// directly from input_matrix vs the one presented below of transposing and
// using the transposed matrix as the transposing operation in itself might
// be more costly.
// Note that this operation is a calibration step for quantization and will
// cease to exist in the final inference graph(will exist as a const node).
auto transposed_input = input_matrix.shuffle(shuffling);
// Find the ranges of each channel in parallel.
float out_min_max = std::numeric_limits<float>::min();
#pragma omp parallel for reduction(max : out_min_max)
for (size_t i = 0; i < depth; ++i) {
Eigen::Tensor<qint32, 0, Eigen::RowMajor> min =
transposed_input.chip<0>(i).minimum();
Eigen::Tensor<qint32, 0, Eigen::RowMajor> max =
transposed_input.chip<0>(i).maximum();
const int32_t min_per_channel = min();
const int32_t max_per_channel = max();
const int32_t abs_max =
std::max(std::abs(min_per_channel), std::abs(max_per_channel));
float scale =
std::max(std::abs(input_min_data[i]), std::abs(input_max_data[i]));
ranges[i] =
scale * static_cast<float>(abs_max) / static_cast<float>(1L << 31);
if (min_per_channel < 0) is_non_negative = false;
// Thread-local out_min_max.
out_min_max = std::max(out_min_max, ranges[i]);
}
// All local out_min_max gets max-reduced into one global out_min_max at
// the end of the loop by specifying reduction(max:out_min_max) along with
// omp parallel for.
// Fixing max to clip_value_max_ (example 6.0 to support relu6)
if (out_min_max > clip_value_max_) out_min_max = clip_value_max_;
Tensor* output_min = nullptr;
Tensor* output_max = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(kOutputMinIndex, {}, &output_min));
OP_REQUIRES_OK(ctx, ctx->allocate_output(kOutputMaxIndex, {}, &output_max));
output_min->flat<float>()(0) = is_non_negative ? 0.0f : -out_min_max;
output_max->flat<float>()(0) = out_min_max;
}
private:
float clip_value_max_ = std::numeric_limits<float>::infinity();
const int kInputTensorIndex = 0;
const int kInputMinIndex = 1;
const int kInputMaxIndex = 2;
const int kOutputMinIndex = 0;
const int kOutputMaxIndex = 1;
};
REGISTER_KERNEL_BUILDER(Name("RequantizationRangePerChannel")
.Device(DEVICE_CPU)
.TypeConstraint<qint32>("T"),
MklRequantizationRangePerChannelOp);
} // namespace tensorflow
#endif // INTEL_MKL

View File

@ -0,0 +1,300 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#if defined(INTEL_MKL) && defined(ENABLE_MKL)
#include <cmath>
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
namespace tensorflow {
class MklRequantizatedOpsTest : public OpsTestBase {};
class MklRequantizatedOpsTestHelper : public OpsTestBase {
public:
void Setup(Tensor &input_tensor_qint32, float &range_weights_ch1,
float &range_weights_ch2);
void TestBody() {}
};
void MklRequantizatedOpsTestHelper::Setup(Tensor &input_tensor_qint32,
float &range_weights_ch1,
float &range_weights_ch2) {
// Step 1: Input range assumptions
// -------------------------------
// Assume input tensor T (NHWC) in FP32 has range [0, 5.0] size nt*ht*wt*ct
// Assume input filter W (NHWC) with 2 output channels of size nw*ht*wt*2
// logically, filter W has 2 channels W1 and W2 each of size nw*ht*wt*1
// Assume input filter W1(NHWC) in FP32 has range [-2.0, 2.0]size nw*ht*wt*1
// Assume input filter W2(NHWC) in FP32 has range [-3.0, 3.0]size nw*ht*wt*1
// Step 2: Quantization details (per channel)
// ------------------------------------------
// T and W are quantized using a quantize op.
// The input tensor T (NHWC) is quantized to unsigned int8.
// Hence T's max value is mapped to ((2^8-1) = 255).
// The input filter W (NHWC) is quantized to signed int8.
// Hence W's max value is mapped to ((2^7)-1 = 127)).
// Range of quantized T in uint8[0 , 255] maps to orig T in FP32[0 , 5.0]
// Range of quantized W1 in int8[-127, 127] maps to orig W1 in FP32[-2.0, 2.0]
// Range of quantized W2 in int8[-127, 127] maps to orig W2 in FP32[-3.0, 3.0]
// Hence the resolution of quantized T will be 5.0/255
// Hence the resolution of quantized W1 will be 2.0/127
// Hence the resolution of quantized W2 will be 3.0/127
// Step 3: Assumption of quantizedconv on quantized input&weights(per channel)
// ---------------------------------------------------------------------------
// The input T and weights W1 (or W2) will be convolved.
// The output tensor T is in int32 whose range is [-2^31, 2^31).
// For simplicity and symmetry, we truncate the above range to (-2^31, 2^31).
// The range of convolved T*W1 is ((2^31)-1) * 5.0/255 * 2.0/127 = 663110.59
// So the range of convolved T*W1 in int32(-2^31, 2^31) that maps to
// orig T range in FP32[0, 5.0] * [-2.0, 2.0] is [-663110.59, 663110.59].
// The range of convolved T*W2 is (2^31-1) * 5.0/255 * 3.0/127 = 994665.88
// So the range of convolved T*W2 in int32(-2^31, 2^31) that maps to
// orig T range in FP32 [0, 5.0] * [-3.0, 3.0] is [-994665.88, 994665.88]
// Step 4: Assumption output above is fed to requantization_range_perchannel
// --------------------------------------------------------------------------
// Here we recalculate the new range for convolved T*W so that we
// make good use in int8 quantization from int32 to int8.
// We assume the above operations are performed and use these values above
// as ranges for requantization_range_perchannel_op.
range_weights_ch1 = 663110.59; // For W1 channel
range_weights_ch2 = 994665.88; // For W2 Channel
// We Fill the input tensor T qint32 with arbitrary int32 values
test::FillValues<qint32>(
&input_tensor_qint32,
{-1000, -2000, 2000, 4000, -3000, -6000, 4000, 8000,
5000, 10000, -6000, -12000, 7000, 14000, 8000, 16000,
9000, -18000, -10000, -20000, 11000, 22000, -12000, -24000,
13000, 26000, 14000, 28000, -15000, -30000, 16000, 32000});
// Step 5: Define and run requantization_range_perchannel
// -------------------------------------------------------
// See test RequantizationRangePerChannelTest_Basic and/or
// test RequantizationRangePerChannelTest_ClipMax
}
// Tests the RequantizationRangePerChannel op wherein the range
// of the weights is calculated per channel.
TEST_F(MklRequantizatedOpsTest, RequantizationRangePerChannelTest_Basic) {
// Let us set up the tensor and inputs before we run this op.
float clip_value_max = static_cast<float>((1L << 31) - 1);
float range_weights_ch1 = 0.0;
float range_weights_ch2 = 0.0;
// Create the input tensor
const int input_height = 4;
const int input_width = 4;
const int input_channels = 2;
// Define the shape of T.
Tensor input_tensor_qint32(DT_QINT32,
{1, input_height, input_width, input_channels});
// Explanation and setup prior to this op. Fill T and populate range values.
MklRequantizatedOpsTestHelper helper;
helper.Setup(input_tensor_qint32, range_weights_ch1, range_weights_ch2);
// Step 5: Define and run requantization_range_perchannel
// -------------------------------------------------------
// Define, create and initialize the op in question.
TF_ASSERT_OK(NodeDefBuilder("requantization_range_per_channel",
"RequantizationRangePerChannel")
.Input(FakeInput(DT_QINT32))
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_FLOAT))
.Attr("T", DataTypeToEnum<qint32>::v())
.Attr("clip_value_max", clip_value_max)
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
// Add the input nodes to the op.
AddInputFromArray<qint32>(input_tensor_qint32.shape(),
input_tensor_qint32.flat<qint32>());
// Calculate the min and max from the ranges
float ch1_min = -range_weights_ch1;
float ch1_max = range_weights_ch1;
float ch2_min = -range_weights_ch2;
float ch2_max = range_weights_ch2;
// Add the perchannel range Nodes to the op.
AddInputFromArray<float>(TensorShape({input_channels}), {ch1_min, ch2_min});
AddInputFromArray<float>(TensorShape({input_channels}), {ch1_max, ch2_max});
// Run the kernel
TF_ASSERT_OK(RunOpKernel());
// Step 6: Verify output and store values to test requantize_perchannel
// --------------------------------------------------------------------
// Verify the Expected Outputs
const float output_min = GetOutput(0)->flat<float>()(0);
const float output_max = GetOutput(1)->flat<float>()(0);
EXPECT_NEAR(-14.8217, output_min, 0.002);
EXPECT_NEAR(14.8217, output_max, 0.002);
// Output range is made use in RequantizePerChannelTest_Basic
}
TEST_F(MklRequantizatedOpsTest, RequantizationRangePerChannelTest_ClipMax) {
// Let us setup the tensor and inputs before we run this op.
float clip_value_max = 6; // Can be used as 6 for Relu 6 activations.
float range_weights_ch1 = 0.0;
float range_weights_ch2 = 0.0;
// Create the input tensor
const int input_height = 4;
const int input_width = 4;
const int input_channels = 2;
// define and input tensor T shape.
Tensor input_tensor_qint32(DT_QINT32,
{1, input_height, input_width, input_channels});
// Explanation and setup prior to this op. Fill T and populate range values.
MklRequantizatedOpsTestHelper helper;
helper.Setup(input_tensor_qint32, range_weights_ch1, range_weights_ch2);
// Step 5: Define and run requantization_range_perchannel
// -------------------------------------------------------
// Define, create and initialize the op in question.
TF_ASSERT_OK(NodeDefBuilder("requantization_range_per_channel",
"RequantizationRangePerChannel")
.Input(FakeInput(DT_QINT32))
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_FLOAT))
.Attr("T", DataTypeToEnum<qint32>::v())
.Attr("clip_value_max", clip_value_max)
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
// Add the input nodes to the op.
AddInputFromArray<qint32>(input_tensor_qint32.shape(),
input_tensor_qint32.flat<qint32>());
// Calculate the min and max from the ranges
float ch1_min = -range_weights_ch1;
float ch1_max = range_weights_ch1;
float ch2_min = -range_weights_ch2;
float ch2_max = range_weights_ch2;
// Add the perchannel range nodes to the op.
AddInputFromArray<float>(TensorShape({input_channels}), {ch1_min, ch2_min});
AddInputFromArray<float>(TensorShape({input_channels}), {ch1_max, ch2_max});
// Run the kernel
TF_ASSERT_OK(RunOpKernel());
// Step 6: Verify output and store values to test requantize_perchannel
// --------------------------------------------------------------------
// Verify the expected outputs
const float output_min = GetOutput(0)->flat<float>()(0);
const float output_max = GetOutput(1)->flat<float>()(0);
EXPECT_NEAR(-6.0, output_min, 0.002); // Values are aligned with clip_value.
EXPECT_NEAR(6.0, output_max, 0.002); // Values are aligned with clip_value.
}
TEST_F(MklRequantizatedOpsTest, RequantizePerChannelTest_Basic) {
// Let us setup the tensor and inputs before we run this op.
float range_weights_ch1 = 0.0;
float range_weights_ch2 = 0.0;
// Create the input tensor
const int input_height = 4;
const int input_width = 4;
const int input_channels = 2;
// define an input tensor T shape.
Tensor input_tensor_qint32(DT_QINT32,
{1, input_height, input_width, input_channels});
// Explanation and setup prior to this op. Fill T and populate range values.
MklRequantizatedOpsTestHelper helper;
helper.Setup(input_tensor_qint32, range_weights_ch1, range_weights_ch2);
// Step 7: Define and run requantize_perchannel
// --------------------------------------------
// The output of requantization_range_op_per_channel which calculated the
// new ranges of int8 is fed to the requantize per channel op.
// Here the values of convolved T*W is converted from int32 to int8.
TF_ASSERT_OK(NodeDefBuilder("requantize_per_channel", "RequantizePerChannel")
.Input(FakeInput(DT_QINT32))
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_FLOAT))
.Input(FakeInput(DT_FLOAT))
.Attr("T", DataTypeToEnum<qint32>::v())
.Attr("out_type", DataTypeToEnum<qint8>::v())
.Finalize(node_def()));
TF_ASSERT_OK(InitOp());
// Add the input Nodes to the op.
AddInputFromArray<qint32>(input_tensor_qint32.shape(),
input_tensor_qint32.flat<qint32>());
// Calculate the min and max from the ranges
float ch1_min = -range_weights_ch1;
float ch1_max = range_weights_ch1;
float ch2_min = -range_weights_ch2;
float ch2_max = range_weights_ch2;
// Add the perchannel range nodes to the op.
AddInputFromArray<float>(TensorShape({input_channels}), {ch1_min, ch2_min});
AddInputFromArray<float>(TensorShape({input_channels}), {ch1_max, ch2_max});
// Calculate the min and max from Step 6 above
// in RequantizationRangePerChannelTest_Basic
float range_op_output_min = -14.8217;
float range_op_output_max = 14.8217;
// Add the requested_min and requested_max stored from Step 6.
AddInputFromArray<float>(TensorShape({1}), {range_op_output_min});
AddInputFromArray<float>(TensorShape({1}), {range_op_output_max});
// Run the kernel
TF_ASSERT_OK(RunOpKernel());
// Verify the output with the expected output
Tensor output = *GetOutput(0);
const float output_min = GetOutput(1)->flat<float>()(0);
const float output_max = GetOutput(2)->flat<float>()(0);
EXPECT_NEAR(range_op_output_min, output_min, 0.002);
EXPECT_NEAR(range_op_output_max, output_max, 0.002);
}
} // namespace tensorflow
#endif // INTEL_MKL && ENABLE_MKL

View File

@ -0,0 +1,172 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// See docs in ../ops/array_ops.cc.
#ifdef INTEL_MKL
#define EIGEN_USE_THREADS
#include <math.h>
#include "mkldnn.hpp"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/meta_support.h"
#include "tensorflow/core/kernels/no_op.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/mkl_util.h"
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
template <typename Device, typename Toutput>
class MklRequantizePerChannelOp : public OpKernel {
public:
explicit MklRequantizePerChannelOp(OpKernelConstruction* ctx)
: OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_type_));
OP_REQUIRES(ctx, out_type_ == DT_QINT8 || out_type_ == DT_QUINT8,
errors::InvalidArgument(
"out_type must be qint8 or quint8, but got: " + out_type_));
}
virtual ~MklRequantizePerChannelOp() {}
void Compute(OpKernelContext* ctx) override {
try {
const Tensor& input = ctx->input(kInputTensorIndex);
const Tensor& input_min_vec = ctx->input(kInputMinVecIndex);
float* input_min_vec_data = (float*)const_cast<void*>(
static_cast<const void*>(input_min_vec.flat<float>().data()));
const Tensor& input_max_vec = ctx->input(kInputMaxVecIndex);
float* input_max_vec_data = (float*)const_cast<void*>(
static_cast<const void*>(input_max_vec.flat<float>().data()));
const Tensor& input_requested_min = ctx->input(this->kRequestMinIndex);
const float input_requested_min_float =
input_requested_min.flat<float>()(0);
const Tensor& input_requested_max = ctx->input(this->kRequestMaxIndex);
const float input_requested_max_float =
input_requested_max.flat<float>()(0);
size_t depth = input_min_vec.NumElements();
OP_REQUIRES(
ctx, input_min_vec.dim_size(0) == depth,
errors::InvalidArgument("input_min has incorrect size, expected ",
depth, " was ", input_min_vec.dim_size(0)));
OP_REQUIRES(
ctx, input_max_vec.dim_size(0) == depth,
errors::InvalidArgument("input_max has incorrect size, expected ",
depth, " was ", input_max_vec.dim_size(0)));
if (out_type_ == DT_QINT8) DCHECK(input_requested_min_float < 0.0f);
const float factor = (out_type_ == DT_QINT8) ? 127.0f : 255.0f;
const float requested_min_max =
std::max(std::abs(input_requested_min_float),
std::abs(input_requested_max_float));
Tensor* output = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(kOutputTensorIndex,
input.shape(), &output));
std::vector<float> scales(depth);
for (int i = 0; i < depth; ++i) {
float min_max_from_vec = std::max(std::abs(input_min_vec_data[i]),
std::abs(input_max_vec_data[i]));
scales[i] = factor * (min_max_from_vec / requested_min_max /
static_cast<float>(1L << 31));
}
mkldnn::primitive_attr reorder_attr;
reorder_attr.set_output_scales(2, scales);
memory::dims dims_mkl_order =
TFShapeToMklDnnDimsInNCHW(input.shape(), FORMAT_NHWC);
memory::desc input_md = memory::desc(dims_mkl_order, MklDnnType<qint32>(),
memory::format::nhwc);
memory::desc output_md =
(out_type_ == DT_QINT8)
? memory::desc(dims_mkl_order, MklDnnType<qint8>(),
memory::format::nhwc)
: memory::desc(dims_mkl_order, MklDnnType<quint8>(),
memory::format::nhwc);
memory::primitive_desc input_pd =
memory::primitive_desc(input_md, cpu_engine_);
memory::primitive_desc output_pd =
memory::primitive_desc(output_md, cpu_engine_);
void* input_buf =
static_cast<void*>(const_cast<qint32*>(input.flat<qint32>().data()));
void* output_buf;
if (out_type_ == DT_QINT8) {
output_buf = static_cast<void*>(
const_cast<qint8*>(output->flat<qint8>().data()));
} else {
output_buf = static_cast<void*>(
const_cast<quint8*>(output->flat<quint8>().data()));
}
std::unique_ptr<memory> input_mem_prim_(new memory(input_pd, input_buf));
std::unique_ptr<memory> output_mem_prim_(
new memory(output_pd, output_buf));
mkldnn::reorder::primitive_desc reorder_pd =
mkldnn::reorder::primitive_desc(input_pd, output_pd, reorder_attr);
std::vector<mkldnn::primitive> net;
net.push_back(
mkldnn::reorder(reorder_pd, *input_mem_prim_, *output_mem_prim_));
stream(stream::kind::eager).submit(net).wait();
Tensor* output_min = nullptr;
Tensor* output_max = nullptr;
OP_REQUIRES_OK(ctx,
ctx->allocate_output(kOutputMinIndex, {}, &output_min));
OP_REQUIRES_OK(ctx,
ctx->allocate_output(kOutputMaxIndex, {}, &output_max));
output_min->flat<float>()(0) = input_requested_min_float;
output_max->flat<float>()(0) = input_requested_max_float;
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + std::string(e.message) + ", in file " +
std::string(__FILE__) + ":" + std::to_string(__LINE__);
OP_REQUIRES_OK(
ctx, errors::Aborted("Operation received an exception:", error_msg));
}
}
private:
const int kInputTensorIndex = 0;
const int kInputMinVecIndex = 1;
const int kInputMaxVecIndex = 2;
const int kRequestMinIndex = 3;
const int kRequestMaxIndex = 4;
const int kOutputTensorIndex = 0;
const int kOutputMinIndex = 1;
const int kOutputMaxIndex = 2;
DataType out_type_;
engine cpu_engine_ = engine(engine::cpu, 0);
};
REGISTER_KERNEL_BUILDER(Name("RequantizePerChannel")
.Device(DEVICE_CPU)
.TypeConstraint<qint32>("T")
.TypeConstraint<qint8>("out_type"),
MklRequantizePerChannelOp<CPUDevice, qint8>);
} // namespace tensorflow
#endif // INTEL_MKL

View File

@ -1749,6 +1749,45 @@ inputs: Must all be the same size and shape.
#endif // INTEL_MKL #endif // INTEL_MKL
REGISTER_OP("RequantizePerChannel")
.Input("input: T")
.Input("input_min: float")
.Input("input_max: float")
.Input("requested_output_min: float")
.Input("requested_output_max: float")
.Output("output: out_type")
.Output("output_min: float")
.Output("output_max: float")
.Attr("T: quantizedtype = DT_QINT32")
.Attr("out_type: quantizedtype = DT_QUINT8")
.SetShapeFn([](InferenceContext* c) {
TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
c->set_output(1, c->Scalar());
c->set_output(2, c->Scalar());
return Status::OK();
});
REGISTER_OP("RequantizationRangePerChannel")
.Input("input: T")
.Input("input_min: float")
.Input("input_max: float")
.Output("output_min: float")
.Output("output_max: float")
.Attr("T: quantizedtype = DT_QINT32")
.Attr("clip_value_max: float")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
c->set_output(0, c->Scalar());
c->set_output(1, c->Scalar());
return Status::OK();
});
REGISTER_OP("NextAfter") REGISTER_OP("NextAfter")
.Attr("T: {float64, float32} = DT_FLOAT") .Attr("T: {float64, float32} = DT_FLOAT")
.Input("x1: T") .Input("x1: T")

View File

@ -2644,10 +2644,18 @@ tf_module {
name: "RequantizationRange" name: "RequantizationRange"
argspec: "args=[\'input\', \'input_min\', \'input_max\'], varargs=None, keywords=None, defaults=None" argspec: "args=[\'input\', \'input_min\', \'input_max\'], varargs=None, keywords=None, defaults=None"
} }
member_method {
name: "RequantizationRangePerChannel"
argspec: "args=[\'input\', \'input_min\', \'input_max\', \'clip_value_max\'], varargs=None, keywords=None, defaults=None"
}
member_method { member_method {
name: "Requantize" name: "Requantize"
argspec: "args=[\'input\', \'input_min\', \'input_max\', \'requested_output_min\', \'requested_output_max\', \'out_type\'], varargs=None, keywords=None, defaults=None" argspec: "args=[\'input\', \'input_min\', \'input_max\', \'requested_output_min\', \'requested_output_max\', \'out_type\'], varargs=None, keywords=None, defaults=None"
} }
member_method {
name: "RequantizePerChannel"
argspec: "args=[\'input\', \'input_min\', \'input_max\', \'requested_output_min\', \'requested_output_max\', \'out_type\'], varargs=None, keywords=None, defaults=None"
}
member_method { member_method {
name: "Reshape" name: "Reshape"
argspec: "args=[\'tensor\', \'shape\'], varargs=None, keywords=None, defaults=None" argspec: "args=[\'tensor\', \'shape\'], varargs=None, keywords=None, defaults=None"

View File

@ -2644,10 +2644,18 @@ tf_module {
name: "RequantizationRange" name: "RequantizationRange"
argspec: "args=[\'input\', \'input_min\', \'input_max\'], varargs=None, keywords=None, defaults=None" argspec: "args=[\'input\', \'input_min\', \'input_max\'], varargs=None, keywords=None, defaults=None"
} }
member_method {
name: "RequantizationRangePerChannel"
argspec: "args=[\'input\', \'input_min\', \'input_max\', \'clip_value_max\'], varargs=None, keywords=None, defaults=None"
}
member_method { member_method {
name: "Requantize" name: "Requantize"
argspec: "args=[\'input\', \'input_min\', \'input_max\', \'requested_output_min\', \'requested_output_max\', \'out_type\'], varargs=None, keywords=None, defaults=None" argspec: "args=[\'input\', \'input_min\', \'input_max\', \'requested_output_min\', \'requested_output_max\', \'out_type\'], varargs=None, keywords=None, defaults=None"
} }
member_method {
name: "RequantizePerChannel"
argspec: "args=[\'input\', \'input_min\', \'input_max\', \'requested_output_min\', \'requested_output_max\', \'out_type\'], varargs=None, keywords=None, defaults=None"
}
member_method { member_method {
name: "Reshape" name: "Reshape"
argspec: "args=[\'tensor\', \'shape\'], varargs=None, keywords=None, defaults=None" argspec: "args=[\'tensor\', \'shape\'], varargs=None, keywords=None, defaults=None"