diff --git a/tensorflow/core/common_runtime/mkl_threadpool_device_test.cc b/tensorflow/core/common_runtime/mkl_threadpool_device_test.cc
index 5d583a8360b..c29752d3c2c 100644
--- a/tensorflow/core/common_runtime/mkl_threadpool_device_test.cc
+++ b/tensorflow/core/common_runtime/mkl_threadpool_device_test.cc
@@ -25,7 +25,7 @@ limitations under the License.
 
 namespace tensorflow {
 
-#ifdef _OPENMP
+#if defined(_OPENMP) && !defined(ENABLE_MKLDNN_THREADPOOL)
 TEST(MKLThreadPoolDeviceTest, TestOmpDefaults) {
   SessionOptions options;
   unsetenv("OMP_NUM_THREADS");
@@ -46,7 +46,7 @@ TEST(MKLThreadPoolDeviceTest, TestOmpPreSets) {
 
   EXPECT_EQ(omp_get_max_threads(), 314);
 }
 
-#endif  // _OPENMP
+#endif  // defined(_OPENMP) && !defined(ENABLE_MKLDNN_THREADPOOL)
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/mkl_conv_ops_test.cc b/tensorflow/core/kernels/mkl_conv_ops_test.cc
index a055351337c..9d11b0fb006 100644
--- a/tensorflow/core/kernels/mkl_conv_ops_test.cc
+++ b/tensorflow/core/kernels/mkl_conv_ops_test.cc
@@ -28,7 +28,7 @@ limitations under the License.
 #include "tensorflow/core/public/session.h"
 
 #if defined(INTEL_MKL_DNN_ONLY)
-#include "third_party/intel_mkl_dnn/include/mkldnn.h"
+#include "mkldnn.hpp"
 #include "tensorflow/core/util/mkl_util.h"
 #endif
 
diff --git a/tensorflow/core/kernels/mkl_qmatmul_op.cc b/tensorflow/core/kernels/mkl_qmatmul_op.cc
index cc7127e0559..d8bbc130c55 100644
--- a/tensorflow/core/kernels/mkl_qmatmul_op.cc
+++ b/tensorflow/core/kernels/mkl_qmatmul_op.cc
@@ -91,12 +91,15 @@ limitations under the License.
 // https://software.intel.com/en-us/articles/lower-numerical-precision-deep-learning-inference-and-training
 
 #ifdef INTEL_MKL
+#include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/kernels/fill_functor.h"
 #include "tensorflow/core/kernels/mkl_matmul_ops_common.h"
 #include "tensorflow/core/kernels/mkl_quantized_conv_ops.h"
 #include "tensorflow/core/kernels/no_op.h"
 #include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/util/mkl_threadpool.h"
+#include "tensorflow/core/util/work_sharder.h"
 
 namespace {
 enum {
@@ -428,6 +431,26 @@ class MklDnnQuantizedMatMulOp : public MklDnnMatMulOpBase<Tweight, Toutput> {
           ((max_input - min_input) *
            std::max(std::abs(max_weight), std::abs(min_weight)));
 
+#ifdef ENABLE_MKLDNN_THREADPOOL
+      auto parallel_func = [&](int64 start, int64 end) {
+        for (int64 j = start; j < end; j++) {
+          int x = 0;
+          for (int64 i = 0; i < k; ++i) {
+            x += wt_buf[i * n + j];
+          }
+          comp_bias[j] =
+              ((bias_buf[j] * out_scale) + static_cast<float>(x * qa_amin));
+        }
+      };
+
+      const float kArithCost = 2.5f;
+      const float kMovCost = 1.0f;
+      float shard_cost = 4 * kArithCost + kMovCost;
+      const DeviceBase::CpuWorkerThreads& worker_threads =
+          *(context->device()->tensorflow_cpu_worker_threads());
+      Shard(worker_threads.num_threads, worker_threads.workers, n, shard_cost,
+            parallel_func);
+#else
 #pragma omp parallel for schedule(static)
       for (int j = 0; j < n; ++j) {
         int x = 0;
@@ -437,7 +460,7 @@ class MklDnnQuantizedMatMulOp : public MklDnnMatMulOpBase<Tweight, Toutput> {
         comp_bias[j] =
             ((bias_buf[j] * out_scale) + static_cast<float>(x * qa_amin));
       }
-
+#endif  // ENABLE_MKLDNN_THREADPOOL
       return reinterpret_cast<Tbias*>(comp_bias_);
 
     } else if (mode_ == QUANTIZE_MODE_SCALED) {
diff --git a/tensorflow/core/kernels/mkl_quantized_conv_ops.h b/tensorflow/core/kernels/mkl_quantized_conv_ops.h
index fef2d837cf2..037a3a5f3ff 100644
--- a/tensorflow/core/kernels/mkl_quantized_conv_ops.h
+++ b/tensorflow/core/kernels/mkl_quantized_conv_ops.h
@@ -69,6 +69,19 @@ void MklQuantizationRangeForMultiplication(float min_a, float max_a,
   const float* max_b = max_b_vector.flat<float>().data();
   float* min_c = (*min_c_vector)->flat<float>().data();
   float* max_c = (*max_c_vector)->flat<float>().data();
+#ifdef ENABLE_MKLDNN_THREADPOOL
+  // TODO: Add eigen parallel_for
+  for (size_t n = 0; n < n_channel; ++n) {
+    float a_float_for_one_quant_level =
+        MklFloatForOneQuantizedLevel<T1>(min_a, max_a);
+    float b_float_for_one_quant_level =
+        MklFloatForOneQuantizedLevel<T2>(min_b[n], max_b[n]);
+    float c_float_for_one_quant_level =
+        a_float_for_one_quant_level * b_float_for_one_quant_level;
+    min_c[n] = c_float_for_one_quant_level * c_lowest;
+    max_c[n] = c_float_for_one_quant_level * c_highest;
+  }
+#else
 #pragma omp parallel for
   for (size_t n = 0; n < n_channel; ++n) {
     float a_float_for_one_quant_level =
@@ -80,6 +93,7 @@ void MklQuantizationRangeForMultiplication(float min_a, float max_a,
     min_c[n] = c_float_for_one_quant_level * c_lowest;
     max_c[n] = c_float_for_one_quant_level * c_highest;
   }
+#endif  // ENABLE_MKLDNN_THREADPOOL
 }
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/mkl_requantization_range_per_channel_op.cc b/tensorflow/core/kernels/mkl_requantization_range_per_channel_op.cc
index 767a6f1c397..0a19573d901 100644
--- a/tensorflow/core/kernels/mkl_requantization_range_per_channel_op.cc
+++ b/tensorflow/core/kernels/mkl_requantization_range_per_channel_op.cc
@@ -28,6 +28,7 @@ limitations under the License.
 #include "tensorflow/core/kernels/meta_support.h"
 #include "tensorflow/core/kernels/no_op.h"
 #include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/util/mkl_threadpool.h"
 #include "tensorflow/core/util/mkl_util.h"
 
 namespace tensorflow {
@@ -73,6 +74,26 @@ class MklRequantizationRangePerChannelOp : public OpKernel {
 
     // Find the ranges of each channel in parallel.
     float out_min_max = std::numeric_limits<float>::min();
+#ifdef ENABLE_MKLDNN_THREADPOOL
+    // TODO: Add eigen parallel_for
+    for (size_t i = 0; i < depth; ++i) {
+      Eigen::Tensor<qint32, 0, Eigen::RowMajor> min =
+          transposed_input.chip<0>(i).minimum();
+      Eigen::Tensor<qint32, 0, Eigen::RowMajor> max =
+          transposed_input.chip<0>(i).maximum();
+      const int32_t min_per_channel = min();
+      const int32_t max_per_channel = max();
+      const int32_t abs_max =
+          std::max(std::abs(min_per_channel), std::abs(max_per_channel));
+      float scale =
+          std::max(std::abs(input_min_data[i]), std::abs(input_max_data[i]));
+      ranges[i] =
+          scale * static_cast<float>(abs_max) / static_cast<float>(1L << 31);
+      if (min_per_channel < 0) is_non_negative = false;
+
+      out_min_max = std::max(out_min_max, ranges[i]);
+    }
+#else
 #pragma omp parallel for reduction(max : out_min_max)
     for (size_t i = 0; i < depth; ++i) {
       Eigen::Tensor<qint32, 0, Eigen::RowMajor> min =
@@ -92,6 +113,7 @@ class MklRequantizationRangePerChannelOp : public OpKernel {
       // Thread-local out_min_max.
      out_min_max = std::max(out_min_max, ranges[i]);
     }
+#endif  // ENABLE_MKLDNN_THREADPOOL
     // All local out_min_max gets max-reduced into one global out_min_max at
     // the end of the loop by specifying reduction(max:out_min_max) along with
     // omp parallel for.
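Editorial note (not part of the patch): the two serial ENABLE_MKLDNN_THREADPOOL fallbacks above, in mkl_quantized_conv_ops.h and mkl_requantization_range_per_channel_op.cc, both carry a "TODO: Add eigen parallel_for" comment. The following is a minimal standalone sketch of the Eigen parallelFor pattern that TODO refers to, assuming Eigen's unsupported Tensor/ThreadPool headers are available; the pool size, cost estimate, and per-channel body are placeholders, not the kernels' actual math.

#define EIGEN_USE_THREADS
#include <cstdio>
#include <vector>

#include <unsupported/Eigen/CXX11/Tensor>
#include <unsupported/Eigen/CXX11/ThreadPool>

int main() {
  const int kNumThreads = 4;          // placeholder pool size
  const Eigen::Index n_channel = 64;  // placeholder channel count

  Eigen::ThreadPool pool(kNumThreads);
  Eigen::ThreadPoolDevice device(&pool, kNumThreads);

  std::vector<float> min_c(n_channel), max_c(n_channel);

  // Rough per-channel cost estimate (bytes loaded, bytes stored, cycles);
  // Eigen uses it to decide how to split the range into blocks.
  const Eigen::TensorOpCost cost(2 * sizeof(float), 2 * sizeof(float), 8);

  // Each invocation of the lambda handles the half-open range [start, end).
  device.parallelFor(n_channel, cost,
                     [&](Eigen::Index start, Eigen::Index end) {
                       for (Eigen::Index n = start; n < end; ++n) {
                         // Stand-in for the per-channel quantization-range
                         // math; writes are disjoint, so no locking is needed.
                         min_c[n] = -0.5f * static_cast<float>(n);
                         max_c[n] = 0.5f * static_cast<float>(n);
                       }
                     });

  std::printf("max_c[%lld] = %f\n",
              static_cast<long long>(n_channel - 1), max_c[n_channel - 1]);
  return 0;
}

Inside TensorFlow kernels the same pool is normally reached through context->device()->tensorflow_cpu_worker_threads() or the Shard() helper that the mkl_qmatmul_op.cc hunk above already uses, rather than by constructing an Eigen::ThreadPool directly.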
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index 9a780839be3..5dc5877367b 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -354,9 +354,7 @@ def tf_copts(
 )
 
 def tf_openmp_copts():
-    # TODO(intel-mkl): Remove -fopenmp for threadpool after removing all
-    # omp pragmas in tensorflow/core.
-    return if_mkl_lnx_x64(["-fopenmp"]) + if_mkldnn_threadpool(["-fopenmp"])
+    return (if_mkl_lnx_x64(["-fopenmp"]) + if_mkldnn_threadpool(["-fno-openmp"]))
 
 def tfe_xla_copts():
     return select({
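Editorial note (not part of the patch): the OpenMP branch in mkl_requantization_range_per_channel_op.cc relies on reduction(max : out_min_max), while the new threadpool branch is serial and simply keeps a running maximum. If that loop is later converted to a parallel_for per the TODO, the reduction must be reproduced by hand. Below is a hedged, standalone sketch of one race-free way to do that: each shard writes only its own ranges[i] entries, and the scalar maximum is folded in a second, serial pass. ChannelRange and the single serial shard_func call are illustrative placeholders, not TensorFlow code.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <limits>
#include <vector>

// Illustrative stand-in for one channel's scale * abs_max computation.
static float ChannelRange(int64_t i) {
  return static_cast<float>(std::fabs(std::sin(static_cast<double>(i))));
}

int main() {
  const int64_t depth = 1024;  // placeholder channel count
  std::vector<float> ranges(depth);

  // Phase 1: each shard writes a disjoint slice of ranges[], so the shards
  // could run concurrently (e.g. under Shard() or an Eigen parallelFor).
  auto shard_func = [&](int64_t start, int64_t end) {
    for (int64_t i = start; i < end; ++i) {
      ranges[i] = ChannelRange(i);
    }
  };
  shard_func(0, depth);  // called serially here; a work sharder would split it

  // Phase 2: the max-reduction that OpenMP performed via
  // reduction(max : out_min_max) is done once, after all shards finish.
  float out_min_max = std::numeric_limits<float>::min();
  for (int64_t i = 0; i < depth; ++i) {
    out_min_max = std::max(out_min_max, ranges[i]);
  }

  std::printf("out_min_max = %f\n", out_min_max);
  return 0;
}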