Merge pull request #39893 from Intel-tensorflow:sshiddib/dnnl_threadpool_openmp
PiperOrigin-RevId: 315516788 Change-Id: Id473c4c2589d5650f488dfa663031879ca8f80bb
This commit is contained in:
commit
4ff7e65477
@ -25,7 +25,7 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
#ifdef _OPENMP
|
#if defined(_OPENMP) && !defined(ENABLE_MKLDNN_THREADPOOL)
|
||||||
TEST(MKLThreadPoolDeviceTest, TestOmpDefaults) {
|
TEST(MKLThreadPoolDeviceTest, TestOmpDefaults) {
|
||||||
SessionOptions options;
|
SessionOptions options;
|
||||||
unsetenv("OMP_NUM_THREADS");
|
unsetenv("OMP_NUM_THREADS");
|
||||||
@ -46,7 +46,7 @@ TEST(MKLThreadPoolDeviceTest, TestOmpPreSets) {
|
|||||||
|
|
||||||
EXPECT_EQ(omp_get_max_threads(), 314);
|
EXPECT_EQ(omp_get_max_threads(), 314);
|
||||||
}
|
}
|
||||||
#endif // _OPENMP
|
#endif // defined(_OPENMP) && !defined(ENABLE_MKLDNN_THREADPOOL)
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
@ -28,7 +28,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/public/session.h"
|
#include "tensorflow/core/public/session.h"
|
||||||
|
|
||||||
#if defined(INTEL_MKL_DNN_ONLY)
|
#if defined(INTEL_MKL_DNN_ONLY)
|
||||||
#include "third_party/intel_mkl_dnn/include/mkldnn.h"
|
#include "mkldnn.hpp"
|
||||||
#include "tensorflow/core/util/mkl_util.h"
|
#include "tensorflow/core/util/mkl_util.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
@ -91,12 +91,15 @@ limitations under the License.
|
|||||||
// https://software.intel.com/en-us/articles/lower-numerical-precision-deep-learning-inference-and-training
|
// https://software.intel.com/en-us/articles/lower-numerical-precision-deep-learning-inference-and-training
|
||||||
#ifdef INTEL_MKL
|
#ifdef INTEL_MKL
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/register_types.h"
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
#include "tensorflow/core/kernels/fill_functor.h"
|
#include "tensorflow/core/kernels/fill_functor.h"
|
||||||
#include "tensorflow/core/kernels/mkl_matmul_ops_common.h"
|
#include "tensorflow/core/kernels/mkl_matmul_ops_common.h"
|
||||||
#include "tensorflow/core/kernels/mkl_quantized_conv_ops.h"
|
#include "tensorflow/core/kernels/mkl_quantized_conv_ops.h"
|
||||||
#include "tensorflow/core/kernels/no_op.h"
|
#include "tensorflow/core/kernels/no_op.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
#include "tensorflow/core/util/mkl_threadpool.h"
|
||||||
|
#include "tensorflow/core/util/work_sharder.h"
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
enum {
|
enum {
|
||||||
@ -437,6 +440,26 @@ class MklDnnQuantizedMatMulOp : public MklDnnMatMulOpBase<Tweight, Toutput> {
|
|||||||
((max_input - min_input) *
|
((max_input - min_input) *
|
||||||
std::max(std::abs(max_weight), std::abs(min_weight)));
|
std::max(std::abs(max_weight), std::abs(min_weight)));
|
||||||
|
|
||||||
|
#ifdef ENABLE_MKLDNN_THREADPOOL
|
||||||
|
auto parallel_func = [&](int64 start, int64 end) {
|
||||||
|
for (int64 j = start; j < end; j++) {
|
||||||
|
int x = 0;
|
||||||
|
for (int64 i = 0; i < k; ++i) {
|
||||||
|
x += wt_buf[i * n + j];
|
||||||
|
}
|
||||||
|
comp_bias[j] =
|
||||||
|
((bias_buf[j] * out_scale) + static_cast<float>(x * qa_amin));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
const float kArithCost = 2.5f;
|
||||||
|
const float kMovCost = 1.0f;
|
||||||
|
float shard_cost = 4 * kArithCost + kMovCost;
|
||||||
|
const DeviceBase::CpuWorkerThreads& worker_threads =
|
||||||
|
*(context->device()->tensorflow_cpu_worker_threads());
|
||||||
|
Shard(worker_threads.num_threads, worker_threads.workers, n, shard_cost,
|
||||||
|
parallel_func);
|
||||||
|
#else
|
||||||
#pragma omp parallel for schedule(static)
|
#pragma omp parallel for schedule(static)
|
||||||
for (int j = 0; j < n; ++j) {
|
for (int j = 0; j < n; ++j) {
|
||||||
int x = 0;
|
int x = 0;
|
||||||
@ -446,7 +469,7 @@ class MklDnnQuantizedMatMulOp : public MklDnnMatMulOpBase<Tweight, Toutput> {
|
|||||||
comp_bias[j] =
|
comp_bias[j] =
|
||||||
((bias_buf[j] * out_scale) + static_cast<float>(x * qa_amin));
|
((bias_buf[j] * out_scale) + static_cast<float>(x * qa_amin));
|
||||||
}
|
}
|
||||||
|
#endif // ENABLE_MKLDNN_THREADPOOL
|
||||||
return reinterpret_cast<Tbias*>(comp_bias_);
|
return reinterpret_cast<Tbias*>(comp_bias_);
|
||||||
|
|
||||||
} else if (mode_ == QUANTIZE_MODE_SCALED) {
|
} else if (mode_ == QUANTIZE_MODE_SCALED) {
|
||||||
|
@ -69,7 +69,11 @@ void MklQuantizationRangeForMultiplication(float min_a, float max_a,
|
|||||||
const float* max_b = max_b_vector.flat<float>().data();
|
const float* max_b = max_b_vector.flat<float>().data();
|
||||||
float* min_c = (*min_c_vector)->flat<float>().data();
|
float* min_c = (*min_c_vector)->flat<float>().data();
|
||||||
float* max_c = (*max_c_vector)->flat<float>().data();
|
float* max_c = (*max_c_vector)->flat<float>().data();
|
||||||
|
|
||||||
|
#ifndef ENABLE_MKLDNN_THREADPOOL
|
||||||
#pragma omp parallel for
|
#pragma omp parallel for
|
||||||
|
#endif // !ENABLE_MKLDNN_THREADPOOL
|
||||||
|
// TODO: Add eigen parallel_for
|
||||||
for (size_t n = 0; n < n_channel; ++n) {
|
for (size_t n = 0; n < n_channel; ++n) {
|
||||||
float a_float_for_one_quant_level =
|
float a_float_for_one_quant_level =
|
||||||
MklFloatForOneQuantizedLevel<T1>(min_a, max_a);
|
MklFloatForOneQuantizedLevel<T1>(min_a, max_a);
|
||||||
|
@ -18,6 +18,7 @@ limitations under the License.
|
|||||||
#define EIGEN_USE_THREADS
|
#define EIGEN_USE_THREADS
|
||||||
|
|
||||||
#include <math.h>
|
#include <math.h>
|
||||||
|
|
||||||
#include <limits>
|
#include <limits>
|
||||||
|
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
@ -28,6 +29,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/kernels/meta_support.h"
|
#include "tensorflow/core/kernels/meta_support.h"
|
||||||
#include "tensorflow/core/kernels/no_op.h"
|
#include "tensorflow/core/kernels/no_op.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
#include "tensorflow/core/util/mkl_threadpool.h"
|
||||||
#include "tensorflow/core/util/mkl_util.h"
|
#include "tensorflow/core/util/mkl_util.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -73,7 +75,11 @@ class MklRequantizationRangePerChannelOp : public OpKernel {
|
|||||||
|
|
||||||
// Find the ranges of each channel in parallel.
|
// Find the ranges of each channel in parallel.
|
||||||
float out_min_max = std::numeric_limits<float>::min();
|
float out_min_max = std::numeric_limits<float>::min();
|
||||||
|
|
||||||
|
#ifndef ENABLE_MKLDNN_THREADPOOL
|
||||||
#pragma omp parallel for reduction(max : out_min_max)
|
#pragma omp parallel for reduction(max : out_min_max)
|
||||||
|
#endif // !ENABLE_MKLDNN_THREADPOOL
|
||||||
|
// TODO: Add eigen parallel_for
|
||||||
for (size_t i = 0; i < depth; ++i) {
|
for (size_t i = 0; i < depth; ++i) {
|
||||||
Eigen::Tensor<qint32, 0, Eigen::RowMajor> min =
|
Eigen::Tensor<qint32, 0, Eigen::RowMajor> min =
|
||||||
transposed_input.chip<0>(i).minimum();
|
transposed_input.chip<0>(i).minimum();
|
||||||
@ -92,6 +98,7 @@ class MklRequantizationRangePerChannelOp : public OpKernel {
|
|||||||
// Thread-local out_min_max.
|
// Thread-local out_min_max.
|
||||||
out_min_max = std::max(out_min_max, ranges[i]);
|
out_min_max = std::max(out_min_max, ranges[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
// All local out_min_max gets max-reduced into one global out_min_max at
|
// All local out_min_max gets max-reduced into one global out_min_max at
|
||||||
// the end of the loop by specifying reduction(max:out_min_max) along with
|
// the end of the loop by specifying reduction(max:out_min_max) along with
|
||||||
// omp parallel for.
|
// omp parallel for.
|
||||||
|
@ -354,9 +354,7 @@ def tf_copts(
|
|||||||
)
|
)
|
||||||
|
|
||||||
def tf_openmp_copts():
|
def tf_openmp_copts():
|
||||||
# TODO(intel-mkl): Remove -fopenmp for threadpool after removing all
|
return (if_mkl_lnx_x64(["-fopenmp"]) + if_mkldnn_threadpool(["-fno-openmp"]))
|
||||||
# omp pragmas in tensorflow/core.
|
|
||||||
return if_mkl_lnx_x64(["-fopenmp"]) + if_mkldnn_threadpool(["-fopenmp"])
|
|
||||||
|
|
||||||
def tfe_xla_copts():
|
def tfe_xla_copts():
|
||||||
return select({
|
return select({
|
||||||
|
Loading…
Reference in New Issue
Block a user