Removing OpenMP dependency from Mkl-dnn supporting threadpool

Authored by sshiddib on 2020-05-26 15:03:42 -07:00; committed by Sharada Shiddibhavi
parent be46769cee
commit 4da2360572
6 changed files with 64 additions and 7 deletions

View File

@@ -25,7 +25,7 @@ limitations under the License.
 namespace tensorflow {
-#ifdef _OPENMP
+#if defined(_OPENMP) && !defined(ENABLE_MKLDNN_THREADPOOL)
 TEST(MKLThreadPoolDeviceTest, TestOmpDefaults) {
   SessionOptions options;
   unsetenv("OMP_NUM_THREADS");
@@ -46,7 +46,7 @@ TEST(MKLThreadPoolDeviceTest, TestOmpPreSets) {
   EXPECT_EQ(omp_get_max_threads(), 314);
 }
-#endif  // _OPENMP
+#endif  // defined(_OPENMP) && !defined(ENABLE_MKLDNN_THREADPOOL)
 }  // namespace tensorflow

View File

@@ -28,7 +28,7 @@ limitations under the License.
 #include "tensorflow/core/public/session.h"
 #if defined(INTEL_MKL_DNN_ONLY)
-#include "third_party/intel_mkl_dnn/include/mkldnn.h"
+#include "mkldnn.hpp"
 #include "tensorflow/core/util/mkl_util.h"
 #endif

View File

@@ -91,12 +91,15 @@ limitations under the License.
 // https://software.intel.com/en-us/articles/lower-numerical-precision-deep-learning-inference-and-training
 #ifdef INTEL_MKL
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/kernels/fill_functor.h"
 #include "tensorflow/core/kernels/mkl_matmul_ops_common.h"
 #include "tensorflow/core/kernels/mkl_quantized_conv_ops.h"
 #include "tensorflow/core/kernels/no_op.h"
 #include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/util/mkl_threadpool.h"
+#include "tensorflow/core/util/work_sharder.h"
 namespace {
 enum {
@@ -428,6 +431,26 @@ class MklDnnQuantizedMatMulOp : public MklDnnMatMulOpBase<Tweight, Toutput> {
           ((max_input - min_input) *
            std::max(std::abs(max_weight), std::abs(min_weight)));
+#ifdef ENABLE_MKLDNN_THREADPOOL
+      auto parallel_func = [&](int64 start, int64 end) {
+        for (int64 j = start; j < end; j++) {
+          int x = 0;
+          for (int64 i = 0; i < k; ++i) {
+            x += wt_buf[i * n + j];
+          }
+          comp_bias[j] =
+              ((bias_buf[j] * out_scale) + static_cast<float>(x * qa_amin));
+        }
+      };
+      const float kArithCost = 2.5f;
+      const float kMovCost = 1.0f;
+      float shard_cost = 4 * kArithCost + kMovCost;
+      const DeviceBase::CpuWorkerThreads& worker_threads =
+          *(context->device()->tensorflow_cpu_worker_threads());
+      Shard(worker_threads.num_threads, worker_threads.workers, n, shard_cost,
+            parallel_func);
+#else
 #pragma omp parallel for schedule(static)
       for (int j = 0; j < n; ++j) {
         int x = 0;
@@ -437,7 +460,7 @@ class MklDnnQuantizedMatMulOp : public MklDnnMatMulOpBase<Tweight, Toutput> {
         comp_bias[j] =
             ((bias_buf[j] * out_scale) + static_cast<float>(x * qa_amin));
       }
+#endif  // ENABLE_MKLDNN_THREADPOOL
       return reinterpret_cast<Tbias*>(comp_bias_);
     } else if (mode_ == QUANTIZE_MODE_SCALED) {
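
For context: in the threadpool build, the omp parallel for above is replaced by TensorFlow's Shard helper (tensorflow/core/util/work_sharder.h), which splits [0, n) into contiguous subranges and runs parallel_func(start, end) on the device's inter-op thread pool, sizing shards from a per-element cost estimate; kArithCost and kMovCost are heuristic cycle figures for the work done per output column. Its declaration at the time of this commit looked like the following (note that cost_per_unit is an int64, so the float shard_cost is truncated at the call site):

    // Runs work(start, end) over disjoint subranges covering [0, total).
    // cost_per_unit estimates cycles per element and controls how finely
    // the range is split across the thread pool.
    void Shard(int max_parallelism, thread::ThreadPool* workers, int64 total,
               int64 cost_per_unit, std::function<void(int64, int64)> work);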

View File

@@ -69,6 +69,19 @@ void MklQuantizationRangeForMultiplication(float min_a, float max_a,
   const float* max_b = max_b_vector.flat<float>().data();
   float* min_c = (*min_c_vector)->flat<float>().data();
   float* max_c = (*max_c_vector)->flat<float>().data();
+#ifdef ENABLE_MKLDNN_THREADPOOL
+  // TODO: Add eigen parallel_for
+  for (size_t n = 0; n < n_channel; ++n) {
+    float a_float_for_one_quant_level =
+        MklFloatForOneQuantizedLevel<T1>(min_a, max_a);
+    float b_float_for_one_quant_level =
+        MklFloatForOneQuantizedLevel<T2>(min_b[n], max_b[n]);
+    float c_float_for_one_quant_level =
+        a_float_for_one_quant_level * b_float_for_one_quant_level;
+    min_c[n] = c_float_for_one_quant_level * c_lowest;
+    max_c[n] = c_float_for_one_quant_level * c_highest;
+  }
+#else
 #pragma omp parallel for
 for (size_t n = 0; n < n_channel; ++n) {
   float a_float_for_one_quant_level =
@@ -80,6 +93,7 @@ void MklQuantizationRangeForMultiplication(float min_a, float max_a,
   min_c[n] = c_float_for_one_quant_level * c_lowest;
   max_c[n] = c_float_for_one_quant_level * c_highest;
 }
+#endif  // ENABLE_MKLDNN_THREADPOOL
 }
 }  // namespace tensorflow
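
The TODO above points at Eigen's cost-model-driven parallelFor, which is how TensorFlow's CPU device schedules loops without OpenMP. A minimal, self-contained sketch of that direction follows; the pool size, cost figures, and the placeholder per-channel kernel are illustrative assumptions, not code from this commit:

    #define EIGEN_USE_THREADS
    #include <cstdio>
    #include <vector>
    #include "unsupported/Eigen/CXX11/Tensor"

    int main() {
      Eigen::ThreadPool pool(/*num_threads=*/4);
      Eigen::ThreadPoolDevice device(&pool, /*num_cores=*/4);
      const Eigen::Index n_channel = 1024;
      std::vector<float> min_c(n_channel), max_c(n_channel);
      // Rough per-channel cost (bytes loaded, bytes stored, compute cycles);
      // Eigen uses it to decide how many channels each task gets.
      const Eigen::TensorOpCost cost(8.0, 8.0, 4.0);
      device.parallelFor(n_channel, cost,
                         [&](Eigen::Index start, Eigen::Index end) {
                           for (Eigen::Index n = start; n < end; ++n) {
                             const float one_quant_level = 0.5f;  // placeholder
                             min_c[n] = one_quant_level * -127.0f;
                             max_c[n] = one_quant_level * 127.0f;
                           }
                         });
      std::printf("min_c[0] = %f, max_c[0] = %f\n", min_c[0], max_c[0]);
      return 0;
    }

Compile with the Eigen headers on the include path, e.g. g++ -std=c++11 -pthread example.cc.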

View File

@@ -28,6 +28,7 @@ limitations under the License.
 #include "tensorflow/core/kernels/meta_support.h"
 #include "tensorflow/core/kernels/no_op.h"
 #include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/util/mkl_threadpool.h"
 #include "tensorflow/core/util/mkl_util.h"
 namespace tensorflow {
@@ -73,6 +74,26 @@ class MklRequantizationRangePerChannelOp : public OpKernel {
     // Find the ranges of each channel in parallel.
     float out_min_max = std::numeric_limits<float>::min();
+#ifdef ENABLE_MKLDNN_THREADPOOL
+    // TODO: Add eigen parallel_for
+    for (size_t i = 0; i < depth; ++i) {
+      Eigen::Tensor<qint32, 0, Eigen::RowMajor> min =
+          transposed_input.chip<0>(i).minimum();
+      Eigen::Tensor<qint32, 0, Eigen::RowMajor> max =
+          transposed_input.chip<0>(i).maximum();
+      const int32_t min_per_channel = min();
+      const int32_t max_per_channel = max();
+      const int32_t abs_max =
+          std::max(std::abs(min_per_channel), std::abs(max_per_channel));
+      float scale =
+          std::max(std::abs(input_min_data[i]), std::abs(input_max_data[i]));
+      ranges[i] =
+          scale * static_cast<float>(abs_max) / static_cast<float>(1L << 31);
+      if (min_per_channel < 0) is_non_negative = false;
+      out_min_max = std::max(out_min_max, ranges[i]);
+    }
+#else
 #pragma omp parallel for reduction(max : out_min_max)
     for (size_t i = 0; i < depth; ++i) {
       Eigen::Tensor<qint32, 0, Eigen::RowMajor> min =
@@ -92,6 +113,7 @@ class MklRequantizationRangePerChannelOp : public OpKernel {
       // Thread-local out_min_max.
      out_min_max = std::max(out_min_max, ranges[i]);
     }
+#endif  // ENABLE_MKLDNN_THREADPOOL
     // All local out_min_max gets max-reduced into one global out_min_max at
     // the end of the loop by specifying reduction(max:out_min_max) along with
     // omp parallel for.
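
As the trailing comment notes, the OpenMP path combines thread-local maxima through reduction(max : out_min_max); the threadpool branch above is a sequential stand-in, so a future parallel version must reintroduce that reduction by hand, e.g. per-shard partial maxima combined after the workers finish. A standalone sketch of that pattern (the shard count and input data are illustrative assumptions):

    #include <algorithm>
    #include <cstdio>
    #include <limits>
    #include <thread>
    #include <vector>

    int main() {
      std::vector<float> ranges(1024);
      for (size_t i = 0; i < ranges.size(); ++i) ranges[i] = 0.001f * i;

      // One slot per shard; no locking needed since each thread owns its slot.
      const size_t kNumShards = 4;
      std::vector<float> partial(kNumShards,
                                 std::numeric_limits<float>::lowest());
      std::vector<std::thread> workers;
      const size_t block = ranges.size() / kNumShards;
      for (size_t s = 0; s < kNumShards; ++s) {
        workers.emplace_back([&, s] {
          const size_t begin = s * block;
          const size_t end =
              (s + 1 == kNumShards) ? ranges.size() : begin + block;
          for (size_t i = begin; i < end; ++i)
            partial[s] = std::max(partial[s], ranges[i]);
        });
      }
      for (std::thread& t : workers) t.join();
      // This final pass replaces OpenMP's reduction(max : out_min_max).
      const float out_min_max =
          *std::max_element(partial.begin(), partial.end());
      std::printf("out_min_max = %f\n", out_min_max);
      return 0;
    }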

View File

@@ -354,9 +354,7 @@ def tf_copts(
 )
 def tf_openmp_copts():
-    # TODO(intel-mkl): Remove -fopenmp for threadpool after removing all
-    # omp pragmas in tensorflow/core.
-    return if_mkl_lnx_x64(["-fopenmp"]) + if_mkldnn_threadpool(["-fopenmp"])
+    return (if_mkl_lnx_x64(["-fopenmp"]) + if_mkldnn_threadpool(["-fno-openmp"]))
 def tfe_xla_copts():
     return select({