Merge pull request #25203 from Intel-tensorflow:requantization_op_perchannel_support
PiperOrigin-RevId: 233491842
This commit is contained in:
commit d3e43986e2
@@ -1473,6 +1473,7 @@ cc_library(
         "//tensorflow/core/kernels:mkl_identity_op",
         "//tensorflow/core/kernels:mkl_input_conversion_op",
         "//tensorflow/core/kernels:mkl_lrn_op",
+        "//tensorflow/core/kernels:mkl_requantize_ops",
         "//tensorflow/core/kernels:mkl_pooling_ops",
         "//tensorflow/core/kernels:mkl_relu_op",
         "//tensorflow/core/kernels:mkl_reshape_op",
@@ -0,0 +1,48 @@
op {
  graph_op_name: "RequantizationRangePerChannel"
  visibility: HIDDEN
  in_arg {
    name: "input"
    description: <<END
The original input tensor.
END
  }
  in_arg {
    name: "input_min"
    description: <<END
The minimum value of the input tensor.
END
  }
  in_arg {
    name: "input_max"
    description: <<END
The maximum value of the input tensor.
END
  }
  out_arg {
    name: "output_min"
    description: <<END
The minimum value of the final output tensor.
END
  }
  out_arg {
    name: "output_max"
    description: <<END
The maximum value of the final output tensor.
END
  }
  attr {
    name: "T"
    description: <<END
The quantized type of input tensor that needs to be converted.
END
  }
  attr {
    name: "clip_value_max"
    description: <<END
The maximum value of the output that needs to be clipped.
Example: set this to 6 for Relu6.
END
  }
  summary: "Computes requantization range per channel."
}
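For intuition, a minimal standalone C++ sketch of the contract this op describes (the helper name and the per-channel range values here are illustrative, not part of the commit): reduce the per-channel ranges to one symmetric range, clipped at clip_value_max, e.g. 6 for Relu6.

#include <algorithm>
#include <cstdio>
#include <vector>

// Illustrative helper (hypothetical): one range per channel in, a single
// clipped symmetric [output_min, output_max] out. The real kernel also
// outputs output_min = 0 when every channel is non-negative.
void RequantizationRangeSketch(const std::vector<float>& per_channel_ranges,
                               float clip_value_max, float* output_min,
                               float* output_max) {
  float out_min_max = 0.0f;
  for (float r : per_channel_ranges) out_min_max = std::max(out_min_max, r);
  out_min_max = std::min(out_min_max, clip_value_max);  // e.g. 6 for Relu6
  *output_min = -out_min_max;
  *output_max = out_min_max;
}

int main() {
  float mn = 0.0f, mx = 0.0f;
  RequantizationRangeSketch({4.94f, 14.82f}, 6.0f, &mn, &mx);
  printf("[%f, %f]\n", mn, mx);  // [-6.0, 6.0]
  return 0;
}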
@@ -0,0 +1,65 @@
op {
  graph_op_name: "RequantizePerChannel"
  visibility: HIDDEN
  in_arg {
    name: "input"
    description: <<END
The original input tensor.
END
  }
  in_arg {
    name: "input_min"
    description: <<END
The minimum value of the input tensor.
END
  }
  in_arg {
    name: "input_max"
    description: <<END
The maximum value of the input tensor.
END
  }
  in_arg {
    name: "requested_output_min"
    description: <<END
The minimum value of the output tensor requested.
END
  }
  in_arg {
    name: "requested_output_max"
    description: <<END
The maximum value of the output tensor requested.
END
  }
  out_arg {
    name: "output"
    description: <<END
Output tensor.
END
  }
  out_arg {
    name: "output_min"
    description: <<END
The minimum value of the final output tensor.
END
  }
  out_arg {
    name: "output_max"
    description: <<END
The maximum value of the final output tensor.
END
  }
  attr {
    name: "T"
    description: <<END
The quantized type of input tensor that needs to be converted.
END
  }
  attr {
    name: "out_type"
    description: <<END
The quantized type of output tensor that needs to be converted.
END
  }
  summary: "Requantizes input with min and max values known per channel."
}
@@ -6987,6 +6987,68 @@ tf_mkl_kernel_library(
     deps = NN_DEPS + mkl_deps() + [":cwise_op"],
 )
 
+tf_mkl_kernel_library(
+    name = "mkl_requantize_ops",
+    srcs = [
+        "mkl_requantization_range_per_channel_op.cc",
+        "mkl_requantize_per_channel_op.cc",
+    ],
+    hdrs = [
+        "meta_support.h",
+        "no_op.h",
+        "reference_gemm.h",
+    ],
+    deps = if_mkl(
+        [
+            ":concat_lib_hdrs",
+            ":conv_ops",
+            ":cwise_op",
+            ":eigen_helpers",
+            ":image_resizer_state",
+            ":ops_util",
+            ":pooling_ops",
+            ":quantization_utils",
+            ":transpose_functor",
+            "//tensorflow/core:array_ops_op_lib",
+            "//tensorflow/core:core_cpu",
+            "//tensorflow/core:framework",
+            "//tensorflow/core:lib",
+            "//tensorflow/core:math_ops_op_lib",
+            "//tensorflow/core:nn_ops_op_lib",
+            "//third_party/eigen3",
+            "//third_party/mkl:intel_binary_blob",
+            "@gemmlowp",
+            "@mkl_dnn",
+        ],
+    ),
+)
+
+tf_cc_test_mkl(
+    name = "mkl_requantize_ops_test",
+    size = "small",
+    srcs = ["mkl_requantize_ops_test.cc"],
+    tags = ["no_mac"],  # TODO(penporn): Re-enable the test on macOS.
+    deps = [
+        ":mkl_requantize_ops",
+        ":ops_testutil",
+        ":ops_util",
+        ":quantization_utils",
+        ":quantized_ops",
+        "//tensorflow/cc:cc_ops",
+        "//tensorflow/core:array_ops_op_lib",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:framework_internal",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:math_ops_op_lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:tensorflow",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+    ],
+)
+
 tf_cc_test_mkl(
     name = "mkl_fused_ops_test",
     size = "small",
@@ -24,8 +24,13 @@ limitations under the License.
 namespace tensorflow {
 template <class T>
 float MklFloatForOneQuantizedLevel(float range_min, float range_max) {
-  const int64 highest = static_cast<int64>(Eigen::NumTraits<T>::highest());
-  const int64 lowest = static_cast<int64>(Eigen::NumTraits<T>::lowest());
+  int64 highest = static_cast<int64>(Eigen::NumTraits<T>::highest());
+  int64 lowest = static_cast<int64>(Eigen::NumTraits<T>::lowest());
+
+  // Adjust for having a symmetric range: for example, for 8-bit use
+  // [-127, 127] as opposed to [-128, 127].
+  if (lowest < -highest) ++lowest;
+
   const float float_for_one_quantized_level =
       (range_max - range_min) / (highest - lowest);
   return float_for_one_quantized_level;
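As a quick check of the symmetric-range adjustment above, a self-contained sketch (numeric limits hard-coded rather than taken from Eigen::NumTraits): with the adjustment, a qint8 weight range of [-2.0, 2.0] divides across 254 levels instead of 255.

#include <cstdint>
#include <cstdio>

// Sketch of MklFloatForOneQuantizedLevel for a signed 8-bit type.
float FloatForOneQuantizedLevelQint8(float range_min, float range_max) {
  int64_t highest = 127;   // NumTraits<qint8>::highest()
  int64_t lowest = -128;   // NumTraits<qint8>::lowest()
  // Symmetric range: use [-127, 127] instead of [-128, 127].
  if (lowest < -highest) ++lowest;
  return (range_max - range_min) / static_cast<float>(highest - lowest);
}

int main() {
  // Weight channel W1 in [-2.0, 2.0]: one quantized level = 4.0 / 254.
  printf("%.8f\n", FloatForOneQuantizedLevelQint8(-2.0f, 2.0f));  // ~0.01574803
  return 0;
}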
@@ -48,6 +53,35 @@ void MklQuantizationRangeForMultiplication(float min_a, float max_a,
   *min_c = c_float_for_one_quant_level * c_lowest;
   *max_c = c_float_for_one_quant_level * c_highest;
 }
+
+template <class T1, class T2, class T3>
+void MklQuantizationRangeForMultiplication(float min_a, float max_a,
+                                           const Tensor& min_b_vector,
+                                           const Tensor& max_b_vector,
+                                           Tensor** min_c_vector,
+                                           Tensor** max_c_vector) {
+  DCHECK(min_b_vector.NumElements() == (*min_c_vector)->NumElements());
+  DCHECK(max_b_vector.NumElements() == (*max_c_vector)->NumElements());
+  size_t n_channel = min_b_vector.NumElements();
+  const int64 c_highest = static_cast<int64>(Eigen::NumTraits<T3>::highest());
+  const int64 c_lowest = static_cast<int64>(Eigen::NumTraits<T3>::lowest());
+  const float* min_b = min_b_vector.flat<float>().data();
+  const float* max_b = max_b_vector.flat<float>().data();
+  float* min_c = (*min_c_vector)->flat<float>().data();
+  float* max_c = (*max_c_vector)->flat<float>().data();
+#pragma omp parallel for
+  for (size_t n = 0; n < n_channel; ++n) {
+    float a_float_for_one_quant_level =
+        MklFloatForOneQuantizedLevel<T1>(min_a, max_a);
+    float b_float_for_one_quant_level =
+        MklFloatForOneQuantizedLevel<T2>(min_b[n], max_b[n]);
+    float c_float_for_one_quant_level =
+        a_float_for_one_quant_level * b_float_for_one_quant_level;
+    min_c[n] = c_float_for_one_quant_level * c_lowest;
+    max_c[n] = c_float_for_one_quant_level * c_highest;
+  }
+}
+
 } // namespace tensorflow
 
 #endif  // INTEL_MKL
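The new per-channel overload multiplies the two per-level values and scales by the qint32 extremes. A standalone sketch (assuming the example ranges from the test added later in this commit: activation T in [0, 5.0] as quint8, weight channel W1 in [-2.0, 2.0] as symmetric qint8) reproduces the 663110.59 range the test uses:

#include <cstdint>
#include <cstdio>

int main() {
  const float a_level = 5.0f / 255.0f;  // FloatForOneQuantizedLevel<quint8>
  const float b_level = 4.0f / 254.0f;  // FloatForOneQuantizedLevel<qint8>
  const float c_level = a_level * b_level;
  const int64_t c_highest = (1LL << 31) - 1;  // NumTraits<qint32>::highest()
  // max_c for this channel: ~663110.6, the range the test assumes for W1.
  printf("max_c = %f\n", c_level * static_cast<float>(c_highest));
  return 0;
}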
@@ -0,0 +1,124 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.

#ifdef INTEL_MKL
#define EIGEN_USE_THREADS

#include <math.h>
#include <limits>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/meta_support.h"
#include "tensorflow/core/kernels/no_op.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/mkl_util.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

class MklRequantizationRangePerChannelOp : public OpKernel {
 public:
  explicit MklRequantizationRangePerChannelOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("clip_value_max", &clip_value_max_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& input = ctx->input(kInputTensorIndex);
    const Tensor& input_min = ctx->input(kInputMinIndex);
    const Tensor& input_max = ctx->input(kInputMaxIndex);

    const size_t depth = input_max.NumElements();
    OP_REQUIRES(
        ctx, input_min.dim_size(0) == depth,
        errors::InvalidArgument("input_min has incorrect size, expected ",
                                depth, " was ", input_min.dim_size(0)));
    OP_REQUIRES(
        ctx, input_max.dim_size(0) == depth,
        errors::InvalidArgument("input_max has incorrect size, expected ",
                                depth, " was ", input_max.dim_size(0)));

    const float* input_min_data = input_min.flat<float>().data();
    const float* input_max_data = input_max.flat<float>().data();
    std::vector<float> ranges(depth);
    bool is_non_negative = true;
    Eigen::array<int, 2> shuffling({1, 0});
    auto input_matrix = input.flat_inner_dims<qint32>();

    // TODO: verify the performance of finding the min and max directly from
    // input_matrix versus transposing first and using the transposed matrix,
    // as below, since the transpose itself might be more costly.
    // Note that this operation is a calibration step for quantization and
    // will cease to exist in the final inference graph (it will exist as a
    // const node).
    auto transposed_input = input_matrix.shuffle(shuffling);

    // Find the ranges of each channel in parallel.
    float out_min_max = std::numeric_limits<float>::min();
#pragma omp parallel for reduction(max : out_min_max)
    for (size_t i = 0; i < depth; ++i) {
      Eigen::Tensor<qint32, 0, Eigen::RowMajor> min =
          transposed_input.chip<0>(i).minimum();
      Eigen::Tensor<qint32, 0, Eigen::RowMajor> max =
          transposed_input.chip<0>(i).maximum();
      const int32_t min_per_channel = min();
      const int32_t max_per_channel = max();
      const int32_t abs_max =
          std::max(std::abs(min_per_channel), std::abs(max_per_channel));
      float scale =
          std::max(std::abs(input_min_data[i]), std::abs(input_max_data[i]));
      ranges[i] =
          scale * static_cast<float>(abs_max) / static_cast<float>(1L << 31);
      if (min_per_channel < 0) is_non_negative = false;

      // Thread-local out_min_max.
      out_min_max = std::max(out_min_max, ranges[i]);
    }
    // All local out_min_max values get max-reduced into one global out_min_max
    // at the end of the loop, via reduction(max : out_min_max) on the omp
    // parallel for above.

    // Fix max to clip_value_max_ (e.g. 6.0 to support Relu6).
    if (out_min_max > clip_value_max_) out_min_max = clip_value_max_;

    Tensor* output_min = nullptr;
    Tensor* output_max = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(kOutputMinIndex, {}, &output_min));
    OP_REQUIRES_OK(ctx, ctx->allocate_output(kOutputMaxIndex, {}, &output_max));
    output_min->flat<float>()(0) = is_non_negative ? 0.0f : -out_min_max;
    output_max->flat<float>()(0) = out_min_max;
  }

 private:
  float clip_value_max_ = std::numeric_limits<float>::infinity();
  const int kInputTensorIndex = 0;
  const int kInputMinIndex = 1;
  const int kInputMaxIndex = 2;
  const int kOutputMinIndex = 0;
  const int kOutputMaxIndex = 1;
};

REGISTER_KERNEL_BUILDER(Name("RequantizationRangePerChannel")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<qint32>("T"),
                        MklRequantizationRangePerChannelOp);
} // namespace tensorflow
#endif  // INTEL_MKL
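The per-channel range formula above is ranges[i] = scale_i * abs_max_i / 2^31. A standalone sketch (using channel 2 of the test file that follows: an assumed T*W2 range of 994665.88 and a largest-magnitude qint32 value of 32000) reproduces the 14.8217 the test expects:

#include <cstdio>

int main() {
  const float scale = 994665.88f;   // assumed range of T*W2 in the test below
  const float abs_max = 32000.0f;   // largest |qint32| value in channel 2
  const float range = scale * abs_max / static_cast<float>(1LL << 31);
  printf("range = %f\n", range);    // ~14.8217, the value the test expects
  return 0;
}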
tensorflow/core/kernels/mkl_requantize_ops_test.cc (new file, 300 lines)
@@ -0,0 +1,300 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#if defined(INTEL_MKL) && defined(ENABLE_MKL)

#include <cmath>

#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"

namespace tensorflow {

class MklRequantizatedOpsTest : public OpsTestBase {};

class MklRequantizatedOpsTestHelper : public OpsTestBase {
 public:
  void Setup(Tensor& input_tensor_qint32, float& range_weights_ch1,
             float& range_weights_ch2);
  void TestBody() {}
};

void MklRequantizatedOpsTestHelper::Setup(Tensor& input_tensor_qint32,
                                          float& range_weights_ch1,
                                          float& range_weights_ch2) {
  // Step 1: Input range assumptions
  // -------------------------------
  // Assume input tensor T (NHWC) in FP32 has range [0, 5.0], size nt*ht*wt*ct.
  // Assume input filter W (NHWC) with 2 output channels, size nw*ht*wt*2;
  // logically, filter W has 2 channels, W1 and W2, each of size nw*ht*wt*1.
  // Assume filter W1 (NHWC) in FP32 has range [-2.0, 2.0], size nw*ht*wt*1.
  // Assume filter W2 (NHWC) in FP32 has range [-3.0, 3.0], size nw*ht*wt*1.

  // Step 2: Quantization details (per channel)
  // ------------------------------------------
  // T and W are quantized using a quantize op.
  // The input tensor T (NHWC) is quantized to unsigned int8,
  // so T's max value is mapped to ((2^8)-1 = 255).
  // The input filter W (NHWC) is quantized to signed int8,
  // so W's max value is mapped to ((2^7)-1 = 127).

  // Range of quantized T in uint8[0, 255] maps to orig T in FP32[0, 5.0].
  // Range of quantized W1 in int8[-127, 127] maps to orig W1 in FP32[-2.0, 2.0].
  // Range of quantized W2 in int8[-127, 127] maps to orig W2 in FP32[-3.0, 3.0].

  // Hence the resolution of quantized T will be 5.0/255,
  // the resolution of quantized W1 will be 2.0/127,
  // and the resolution of quantized W2 will be 3.0/127.

  // Step 3: Assume quantized conv on quantized input and weights (per channel)
  // ---------------------------------------------------------------------------
  // The input T and weights W1 (or W2) are convolved.
  // The output tensor T is int32, whose range is [-2^31, 2^31).
  // For simplicity and symmetry, we truncate the above range to (-2^31, 2^31).
  // The range of convolved T*W1 is ((2^31)-1) * 5.0/255 * 2.0/127 = 663110.59,
  // so the range of convolved T*W1 in int32(-2^31, 2^31) that maps to
  // orig T range in FP32 [0, 5.0] * [-2.0, 2.0] is [-663110.59, 663110.59].

  // The range of convolved T*W2 is ((2^31)-1) * 5.0/255 * 3.0/127 = 994665.88,
  // so the range of convolved T*W2 in int32(-2^31, 2^31) that maps to
  // orig T range in FP32 [0, 5.0] * [-3.0, 3.0] is [-994665.88, 994665.88].

  // Step 4: Assume the output above is fed to requantization_range_perchannel
  // --------------------------------------------------------------------------
  // Here we recalculate the new range for convolved T*W so that we
  // make good use of int8 quantization from int32 to int8.

  // We assume the above operations were performed and use these values
  // as the ranges for requantization_range_perchannel_op.
  range_weights_ch1 = 663110.59;  // For W1 channel.
  range_weights_ch2 = 994665.88;  // For W2 channel.

  // We fill the input tensor T qint32 with arbitrary int32 values.
  test::FillValues<qint32>(
      &input_tensor_qint32,
      {-1000, -2000, 2000, 4000, -3000, -6000, 4000, 8000,
       5000, 10000, -6000, -12000, 7000, 14000, 8000, 16000,
       9000, -18000, -10000, -20000, 11000, 22000, -12000, -24000,
       13000, 26000, 14000, 28000, -15000, -30000, 16000, 32000});

  // Step 5: Define and run requantization_range_perchannel
  // -------------------------------------------------------
  // See test RequantizationRangePerChannelTest_Basic and/or
  // test RequantizationRangePerChannelTest_ClipMax.
}

// Tests the RequantizationRangePerChannel op wherein the range
// of the weights is calculated per channel.
TEST_F(MklRequantizatedOpsTest, RequantizationRangePerChannelTest_Basic) {
  // Set up the tensor and inputs before we run this op.
  float clip_value_max = static_cast<float>((1L << 31) - 1);
  float range_weights_ch1 = 0.0;
  float range_weights_ch2 = 0.0;

  // Create the input tensor.
  const int input_height = 4;
  const int input_width = 4;
  const int input_channels = 2;

  // Define the shape of T.
  Tensor input_tensor_qint32(DT_QINT32,
                             {1, input_height, input_width, input_channels});

  // Explanation and setup prior to this op. Fill T and populate range values.
  MklRequantizatedOpsTestHelper helper;
  helper.Setup(input_tensor_qint32, range_weights_ch1, range_weights_ch2);

  // Step 5: Define and run requantization_range_perchannel
  // -------------------------------------------------------
  // Define, create, and initialize the op in question.
  TF_ASSERT_OK(NodeDefBuilder("requantization_range_per_channel",
                              "RequantizationRangePerChannel")
                   .Input(FakeInput(DT_QINT32))
                   .Input(FakeInput(DT_FLOAT))
                   .Input(FakeInput(DT_FLOAT))
                   .Attr("T", DataTypeToEnum<qint32>::v())
                   .Attr("clip_value_max", clip_value_max)
                   .Finalize(node_def()));
  TF_ASSERT_OK(InitOp());

  // Add the input nodes to the op.
  AddInputFromArray<qint32>(input_tensor_qint32.shape(),
                            input_tensor_qint32.flat<qint32>());

  // Calculate the min and max from the ranges.
  float ch1_min = -range_weights_ch1;
  float ch1_max = range_weights_ch1;
  float ch2_min = -range_weights_ch2;
  float ch2_max = range_weights_ch2;

  // Add the per-channel range nodes to the op.
  AddInputFromArray<float>(TensorShape({input_channels}), {ch1_min, ch2_min});
  AddInputFromArray<float>(TensorShape({input_channels}), {ch1_max, ch2_max});

  // Run the kernel.
  TF_ASSERT_OK(RunOpKernel());

  // Step 6: Verify the output and store values to test requantize_perchannel
  // --------------------------------------------------------------------

  // Verify the expected outputs.
  const float output_min = GetOutput(0)->flat<float>()(0);
  const float output_max = GetOutput(1)->flat<float>()(0);
  EXPECT_NEAR(-14.8217, output_min, 0.002);
  EXPECT_NEAR(14.8217, output_max, 0.002);

  // The output range is used in RequantizePerChannelTest_Basic.
}

TEST_F(MklRequantizatedOpsTest, RequantizationRangePerChannelTest_ClipMax) {
  // Set up the tensor and inputs before we run this op.
  float clip_value_max = 6;  // Can be set to 6 for Relu6 activations.
  float range_weights_ch1 = 0.0;
  float range_weights_ch2 = 0.0;

  // Create the input tensor.
  const int input_height = 4;
  const int input_width = 4;
  const int input_channels = 2;

  // Define an input tensor T shape.
  Tensor input_tensor_qint32(DT_QINT32,
                             {1, input_height, input_width, input_channels});

  // Explanation and setup prior to this op. Fill T and populate range values.
  MklRequantizatedOpsTestHelper helper;
  helper.Setup(input_tensor_qint32, range_weights_ch1, range_weights_ch2);

  // Step 5: Define and run requantization_range_perchannel
  // -------------------------------------------------------
  // Define, create, and initialize the op in question.
  TF_ASSERT_OK(NodeDefBuilder("requantization_range_per_channel",
                              "RequantizationRangePerChannel")
                   .Input(FakeInput(DT_QINT32))
                   .Input(FakeInput(DT_FLOAT))
                   .Input(FakeInput(DT_FLOAT))
                   .Attr("T", DataTypeToEnum<qint32>::v())
                   .Attr("clip_value_max", clip_value_max)
                   .Finalize(node_def()));
  TF_ASSERT_OK(InitOp());

  // Add the input nodes to the op.
  AddInputFromArray<qint32>(input_tensor_qint32.shape(),
                            input_tensor_qint32.flat<qint32>());

  // Calculate the min and max from the ranges.
  float ch1_min = -range_weights_ch1;
  float ch1_max = range_weights_ch1;
  float ch2_min = -range_weights_ch2;
  float ch2_max = range_weights_ch2;

  // Add the per-channel range nodes to the op.
  AddInputFromArray<float>(TensorShape({input_channels}), {ch1_min, ch2_min});
  AddInputFromArray<float>(TensorShape({input_channels}), {ch1_max, ch2_max});

  // Run the kernel.
  TF_ASSERT_OK(RunOpKernel());

  // Step 6: Verify the output and store values to test requantize_perchannel
  // --------------------------------------------------------------------

  // Verify the expected outputs.
  const float output_min = GetOutput(0)->flat<float>()(0);
  const float output_max = GetOutput(1)->flat<float>()(0);
  EXPECT_NEAR(-6.0, output_min, 0.002);  // Values are aligned with clip_value.
  EXPECT_NEAR(6.0, output_max, 0.002);   // Values are aligned with clip_value.
}

TEST_F(MklRequantizatedOpsTest, RequantizePerChannelTest_Basic) {
  // Set up the tensor and inputs before we run this op.
  float range_weights_ch1 = 0.0;
  float range_weights_ch2 = 0.0;

  // Create the input tensor.
  const int input_height = 4;
  const int input_width = 4;
  const int input_channels = 2;

  // Define an input tensor T shape.
  Tensor input_tensor_qint32(DT_QINT32,
                             {1, input_height, input_width, input_channels});

  // Explanation and setup prior to this op. Fill T and populate range values.
  MklRequantizatedOpsTestHelper helper;
  helper.Setup(input_tensor_qint32, range_weights_ch1, range_weights_ch2);

  // Step 7: Define and run requantize_perchannel
  // --------------------------------------------
  // The output of requantization_range_op_per_channel, which calculated the
  // new int8 ranges, is fed to the requantize per channel op. Here the values
  // of convolved T*W are converted from int32 to int8.
  TF_ASSERT_OK(NodeDefBuilder("requantize_per_channel", "RequantizePerChannel")
                   .Input(FakeInput(DT_QINT32))
                   .Input(FakeInput(DT_FLOAT))
                   .Input(FakeInput(DT_FLOAT))
                   .Input(FakeInput(DT_FLOAT))
                   .Input(FakeInput(DT_FLOAT))
                   .Attr("T", DataTypeToEnum<qint32>::v())
                   .Attr("out_type", DataTypeToEnum<qint8>::v())
                   .Finalize(node_def()));
  TF_ASSERT_OK(InitOp());

  // Add the input nodes to the op.
  AddInputFromArray<qint32>(input_tensor_qint32.shape(),
                            input_tensor_qint32.flat<qint32>());

  // Calculate the min and max from the ranges.
  float ch1_min = -range_weights_ch1;
  float ch1_max = range_weights_ch1;
  float ch2_min = -range_weights_ch2;
  float ch2_max = range_weights_ch2;

  // Add the per-channel range nodes to the op.
  AddInputFromArray<float>(TensorShape({input_channels}), {ch1_min, ch2_min});
  AddInputFromArray<float>(TensorShape({input_channels}), {ch1_max, ch2_max});

  // The min and max calculated in Step 6 above,
  // in RequantizationRangePerChannelTest_Basic.
  float range_op_output_min = -14.8217;
  float range_op_output_max = 14.8217;

  // Add the requested_min and requested_max stored from Step 6.
  AddInputFromArray<float>(TensorShape({1}), {range_op_output_min});
  AddInputFromArray<float>(TensorShape({1}), {range_op_output_max});

  // Run the kernel.
  TF_ASSERT_OK(RunOpKernel());

  // Verify the output against the expected output.
  Tensor output = *GetOutput(0);
  const float output_min = GetOutput(1)->flat<float>()(0);
  const float output_max = GetOutput(2)->flat<float>()(0);
  EXPECT_NEAR(range_op_output_min, output_min, 0.002);
  EXPECT_NEAR(range_op_output_max, output_max, 0.002);
}

} // namespace tensorflow
#endif  // INTEL_MKL && ENABLE_MKL
tensorflow/core/kernels/mkl_requantize_per_channel_op.cc (new file, 172 lines)
@@ -0,0 +1,172 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.

#ifdef INTEL_MKL
#define EIGEN_USE_THREADS
#include <math.h>

#include "mkldnn.hpp"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/meta_support.h"
#include "tensorflow/core/kernels/no_op.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/mkl_util.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

template <typename Device, typename Toutput>
class MklRequantizePerChannelOp : public OpKernel {
 public:
  explicit MklRequantizePerChannelOp(OpKernelConstruction* ctx)
      : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_type_));
    OP_REQUIRES(ctx, out_type_ == DT_QINT8 || out_type_ == DT_QUINT8,
                errors::InvalidArgument("out_type must be qint8 or quint8, "
                                        "but got: " +
                                        DataTypeString(out_type_)));
  }
  virtual ~MklRequantizePerChannelOp() {}
  void Compute(OpKernelContext* ctx) override {
    try {
      const Tensor& input = ctx->input(kInputTensorIndex);
      const Tensor& input_min_vec = ctx->input(kInputMinVecIndex);
      float* input_min_vec_data = (float*)const_cast<void*>(
          static_cast<const void*>(input_min_vec.flat<float>().data()));
      const Tensor& input_max_vec = ctx->input(kInputMaxVecIndex);
      float* input_max_vec_data = (float*)const_cast<void*>(
          static_cast<const void*>(input_max_vec.flat<float>().data()));

      const Tensor& input_requested_min = ctx->input(this->kRequestMinIndex);
      const float input_requested_min_float =
          input_requested_min.flat<float>()(0);
      const Tensor& input_requested_max = ctx->input(this->kRequestMaxIndex);
      const float input_requested_max_float =
          input_requested_max.flat<float>()(0);

      size_t depth = input_min_vec.NumElements();
      OP_REQUIRES(
          ctx, input_min_vec.dim_size(0) == depth,
          errors::InvalidArgument("input_min has incorrect size, expected ",
                                  depth, " was ", input_min_vec.dim_size(0)));
      OP_REQUIRES(
          ctx, input_max_vec.dim_size(0) == depth,
          errors::InvalidArgument("input_max has incorrect size, expected ",
                                  depth, " was ", input_max_vec.dim_size(0)));

      if (out_type_ == DT_QINT8) DCHECK(input_requested_min_float < 0.0f);

      const float factor = (out_type_ == DT_QINT8) ? 127.0f : 255.0f;
      const float requested_min_max =
          std::max(std::abs(input_requested_min_float),
                   std::abs(input_requested_max_float));
      Tensor* output = nullptr;
      OP_REQUIRES_OK(ctx, ctx->allocate_output(kOutputTensorIndex,
                                               input.shape(), &output));

      std::vector<float> scales(depth);
      for (int i = 0; i < depth; ++i) {
        float min_max_from_vec = std::max(std::abs(input_min_vec_data[i]),
                                          std::abs(input_max_vec_data[i]));
        scales[i] = factor * (min_max_from_vec / requested_min_max /
                              static_cast<float>(1L << 31));
      }

      mkldnn::primitive_attr reorder_attr;
      reorder_attr.set_output_scales(2, scales);

      memory::dims dims_mkl_order =
          TFShapeToMklDnnDimsInNCHW(input.shape(), FORMAT_NHWC);
      memory::desc input_md = memory::desc(dims_mkl_order, MklDnnType<qint32>(),
                                           memory::format::nhwc);
      memory::desc output_md =
          (out_type_ == DT_QINT8)
              ? memory::desc(dims_mkl_order, MklDnnType<qint8>(),
                             memory::format::nhwc)
              : memory::desc(dims_mkl_order, MklDnnType<quint8>(),
                             memory::format::nhwc);

      memory::primitive_desc input_pd =
          memory::primitive_desc(input_md, cpu_engine_);
      memory::primitive_desc output_pd =
          memory::primitive_desc(output_md, cpu_engine_);

      void* input_buf =
          static_cast<void*>(const_cast<qint32*>(input.flat<qint32>().data()));
      void* output_buf;
      if (out_type_ == DT_QINT8) {
        output_buf = static_cast<void*>(
            const_cast<qint8*>(output->flat<qint8>().data()));
      } else {
        output_buf = static_cast<void*>(
            const_cast<quint8*>(output->flat<quint8>().data()));
      }

      std::unique_ptr<memory> input_mem_prim_(new memory(input_pd, input_buf));
      std::unique_ptr<memory> output_mem_prim_(
          new memory(output_pd, output_buf));

      mkldnn::reorder::primitive_desc reorder_pd =
          mkldnn::reorder::primitive_desc(input_pd, output_pd, reorder_attr);
      std::vector<mkldnn::primitive> net;
      net.push_back(
          mkldnn::reorder(reorder_pd, *input_mem_prim_, *output_mem_prim_));
      stream(stream::kind::eager).submit(net).wait();

      Tensor* output_min = nullptr;
      Tensor* output_max = nullptr;
      OP_REQUIRES_OK(ctx,
                     ctx->allocate_output(kOutputMinIndex, {}, &output_min));
      OP_REQUIRES_OK(ctx,
                     ctx->allocate_output(kOutputMaxIndex, {}, &output_max));

      output_min->flat<float>()(0) = input_requested_min_float;
      output_max->flat<float>()(0) = input_requested_max_float;
    } catch (mkldnn::error& e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + std::string(e.message) + ", in file " +
                         std::string(__FILE__) + ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          ctx, errors::Aborted("Operation received an exception:", error_msg));
    }
  }

 private:
  const int kInputTensorIndex = 0;
  const int kInputMinVecIndex = 1;
  const int kInputMaxVecIndex = 2;
  const int kRequestMinIndex = 3;
  const int kRequestMaxIndex = 4;
  const int kOutputTensorIndex = 0;
  const int kOutputMinIndex = 1;
  const int kOutputMaxIndex = 2;
  DataType out_type_;
  engine cpu_engine_ = engine(engine::cpu, 0);
};

REGISTER_KERNEL_BUILDER(Name("RequantizePerChannel")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<qint32>("T")
                            .TypeConstraint<qint8>("out_type"),
                        MklRequantizePerChannelOp<CPUDevice, qint8>);

} // namespace tensorflow
#endif  // INTEL_MKL
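The reorder scale above is scales[i] = factor * min_max_from_vec / requested_min_max / 2^31. A standalone sketch (assuming out_type = qint8, so factor = 127, and the channel-2 values from the test above) shows the largest qint32 input mapping onto the qint8 extreme:

#include <cstdio>

int main() {
  const float factor = 127.0f;                // out_type = qint8
  const float min_max_from_vec = 994665.88f;  // |range| of test channel 2
  const float requested_min_max = 14.8217f;   // from RequantizationRangePerChannel
  const float scale = factor * (min_max_from_vec / requested_min_max /
                                static_cast<float>(1LL << 31));
  // The largest qint32 magnitude in test channel 2 is 32000; after the scaled
  // reorder it lands on the qint8 extreme, ~127.
  printf("scale = %.8f, 32000 * scale = %.3f\n", scale, 32000.0f * scale);
  return 0;
}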
@@ -1749,6 +1749,45 @@ inputs: Must all be the same size and shape.
 
 #endif  // INTEL_MKL
 
+REGISTER_OP("RequantizePerChannel")
+    .Input("input: T")
+    .Input("input_min: float")
+    .Input("input_max: float")
+    .Input("requested_output_min: float")
+    .Input("requested_output_max: float")
+    .Output("output: out_type")
+    .Output("output_min: float")
+    .Output("output_max: float")
+    .Attr("T: quantizedtype = DT_QINT32")
+    .Attr("out_type: quantizedtype = DT_QUINT8")
+    .SetShapeFn([](InferenceContext* c) {
+      TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
+      ShapeHandle unused;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
+      c->set_output(1, c->Scalar());
+      c->set_output(2, c->Scalar());
+      return Status::OK();
+    });
+REGISTER_OP("RequantizationRangePerChannel")
+    .Input("input: T")
+    .Input("input_min: float")
+    .Input("input_max: float")
+    .Output("output_min: float")
+    .Output("output_max: float")
+    .Attr("T: quantizedtype = DT_QINT32")
+    .Attr("clip_value_max: float")
+    .SetShapeFn([](InferenceContext* c) {
+      ShapeHandle unused;
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
+      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
+      c->set_output(0, c->Scalar());
+      c->set_output(1, c->Scalar());
+      return Status::OK();
+    });
+
 REGISTER_OP("NextAfter")
     .Attr("T: {float64, float32} = DT_FLOAT")
     .Input("x1: T")
@@ -2644,10 +2644,18 @@ tf_module {
     name: "RequantizationRange"
     argspec: "args=[\'input\', \'input_min\', \'input_max\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "RequantizationRangePerChannel"
+    argspec: "args=[\'input\', \'input_min\', \'input_max\', \'clip_value_max\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "Requantize"
     argspec: "args=[\'input\', \'input_min\', \'input_max\', \'requested_output_min\', \'requested_output_max\', \'out_type\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "RequantizePerChannel"
+    argspec: "args=[\'input\', \'input_min\', \'input_max\', \'requested_output_min\', \'requested_output_max\', \'out_type\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "Reshape"
     argspec: "args=[\'tensor\', \'shape\'], varargs=None, keywords=None, defaults=None"
@@ -2644,10 +2644,18 @@ tf_module {
     name: "RequantizationRange"
     argspec: "args=[\'input\', \'input_min\', \'input_max\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "RequantizationRangePerChannel"
+    argspec: "args=[\'input\', \'input_min\', \'input_max\', \'clip_value_max\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "Requantize"
     argspec: "args=[\'input\', \'input_min\', \'input_max\', \'requested_output_min\', \'requested_output_max\', \'out_type\'], varargs=None, keywords=None, defaults=None"
   }
+  member_method {
+    name: "RequantizePerChannel"
+    argspec: "args=[\'input\', \'input_min\', \'input_max\', \'requested_output_min\', \'requested_output_max\', \'out_type\'], varargs=None, keywords=None, defaults=None"
+  }
   member_method {
     name: "Reshape"
     argspec: "args=[\'tensor\', \'shape\'], varargs=None, keywords=None, defaults=None"