Internal change

PiperOrigin-RevId: 361090132
Change-Id: Ia8e3efb6e794d83c4ef506fccf830370ec6780d0
A. Unique TensorFlower 2021-03-05 00:32:39 -08:00 committed by TensorFlower Gardener
parent f0359d50ea
commit 460e000de3
26 changed files with 67 additions and 76 deletions

View File

@@ -172,20 +172,13 @@ build:mkl -c opt
build:mkl_threadpool --define=build_with_mkl=true --define=enable_mkl=true
build:mkl_threadpool --define=tensorflow_mkldnn_contraction_kernel=0
build:mkl_threadpool --define=build_with_mkl_opensource=true
build:mkl_threadpool --define=build_with_mkldnn_threadpool=true
build:mkl_threadpool -c opt
# Config setting to build with oneDNN and without the binary blob
build:mkl_opensource_only --define=build_with_mkl=true --define=enable_mkl=true
build:mkl_opensource_only --define=tensorflow_mkldnn_contraction_kernel=0
build:mkl_opensource_only --define=build_with_mkl_opensource=true
build:mkl_opensource_only --define=build_with_openmp=true
build:mkl_opensource_only -c opt
# Config setting to build with oneDNN for Arm.
build:mkl_aarch64 --define=build_with_mkl_aarch64=true --define=enable_mkl=true
build:mkl_aarch64 --define=tensorflow_mkldnn_contraction_kernel=0
build:mkl_aarch64 --define=build_with_mkl_opensource=true
build:mkl_aarch64 --define=build_with_openmp=true
build:mkl_aarch64 -c opt
# This config refers to building with CUDA available. It does not necessarily

View File

@@ -24,7 +24,7 @@ limitations under the License.
namespace tensorflow {
#if defined(_OPENMP) && !defined(ENABLE_MKLDNN_THREADPOOL)
#if defined(_OPENMP) && defined(ENABLE_ONEDNN_OPENMP)
TEST(MKLThreadPoolDeviceTest, TestOmpDefaults) {
SessionOptions options;
unsetenv("OMP_NUM_THREADS");
@@ -36,7 +36,7 @@ TEST(MKLThreadPoolDeviceTest, TestOmpDefaults) {
EXPECT_EQ(omp_get_max_threads(), (port::NumSchedulableCPUs() + ht - 1) / ht);
}
#endif // defined(_OPENMP) && !defined(ENABLE_MKLDNN_THREADPOOL)
#endif // defined(_OPENMP) && defined(ENABLE_ONEDNN_OPENMP)
} // namespace tensorflow
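
Note the polarity flip that accompanies the rename throughout this commit: code that used to compile when ENABLE_MKLDNN_THREADPOOL was undefined now compiles when ENABLE_ONEDNN_OPENMP is defined, making OpenMP the explicit opt-in. A minimal standalone sketch (not part of the commit) of what the new convention means for a translation unit:

#include <iostream>

int main() {
  // ENABLE_ONEDNN_OPENMP is defined only in OpenMP builds (see the
  // if_mkldnn_openmp copt wired into tf_copts later in this diff).
#ifdef ENABLE_ONEDNN_OPENMP
  std::cout << "OpenMP build: oneDNN parallelism uses OMP threads\n";
#else
  std::cout << "default build: oneDNN parallelism uses a threadpool\n";
#endif
  return 0;
}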

View File

@@ -101,7 +101,7 @@ int32 NumIntraOpThreadsFromEnvironment() {
const char* val = std::getenv("TF_NUM_INTRAOP_THREADS");
return (val && strings::safe_strto32(val, &num)) ? num : 0;
}
#if !defined(ENABLE_MKLDNN_THREADPOOL) && defined(INTEL_MKL)
#if defined(ENABLE_ONEDNN_OPENMP) && defined(INTEL_MKL)
int32 OMPThreadsFromEnvironment() {
// 1) std::getenv is thread-safe (as long as no other function modifies the
// host env) from C++11 onward. 2) Most of TF code (except tests and
@@ -121,14 +121,14 @@ int32 DefaultNumIntraOpThreads() {
// Default to the maximum parallelism for the current process.
return port::MaxParallelism();
}
#endif // !defined(ENABLE_MKLDNN_THREADPOOL) && defined(INTEL_MKL)
#endif // defined(ENABLE_ONEDNN_OPENMP) && defined(INTEL_MKL)
int32 NumInterOpThreadsFromSessionOptions(const SessionOptions& options) {
const int32 inter_op = options.config.inter_op_parallelism_threads();
if (inter_op > 0) return inter_op;
const int32 env_inter_op = GetEnvNumInterOpThreads();
if (env_inter_op > 0) return env_inter_op;
#if !defined(ENABLE_MKLDNN_THREADPOOL) && defined(INTEL_MKL)
#if defined(ENABLE_ONEDNN_OPENMP) && defined(INTEL_MKL)
if (!DisableMKL()) {
// MKL library executes ops in parallel using OMP threads.
// Setting inter_op conservatively to avoid thread oversubscription that
@@ -149,7 +149,7 @@ int32 NumInterOpThreadsFromSessionOptions(const SessionOptions& options) {
<< ". Tune using inter_op_parallelism_threads for best performance.";
return mkl_inter_op;
}
#endif // !defined(ENABLE_MKLDNN_THREADPOOL) && defined(INTEL_MKL)
#endif // defined(ENABLE_ONEDNN_OPENMP) && defined(INTEL_MKL)
return DefaultNumInterOpThreads();
}
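
The helpers in this hunk read thread counts from the environment; the rename only changes which of them get compiled in. As a self-contained approximation of the TF_NUM_INTRAOP_THREADS lookup (a hypothetical helper, not TF code), relying on the same property the original comment cites, namely that std::getenv is thread-safe from C++11 onward as long as nothing else modifies the host environment:

#include <cstdlib>

// Returns TF_NUM_INTRAOP_THREADS as an int, or 0 when the variable is
// unset or malformed; 0 means "fall back to a computed default" above.
int NumIntraOpThreadsFromEnv() {
  const char* val = std::getenv("TF_NUM_INTRAOP_THREADS");
  if (val == nullptr) return 0;
  char* end = nullptr;
  const long n = std::strtol(val, &end, 10);
  return (end != val && *end == '\0' && n > 0) ? static_cast<int>(n) : 0;
}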

View File

@@ -50,7 +50,7 @@ ThreadPoolDevice::ThreadPoolDevice(const SessionOptions& options,
name, DEVICE_CPU, memory_limit, locality)),
allocator_(allocator),
scoped_allocator_mgr_(new ScopedAllocatorMgr(name)) {
#if !defined(ENABLE_MKLDNN_THREADPOOL) && defined(INTEL_MKL)
#if defined(ENABLE_ONEDNN_OPENMP) && defined(INTEL_MKL)
// Early return when MKL is disabled
if (DisableMKL()) return;
#ifdef _OPENMP
@@ -65,7 +65,7 @@ ThreadPoolDevice::ThreadPoolDevice(const SessionOptions& options,
(mkl_intra_op + ht - 1) / ht);
}
#endif // _OPENMP
#endif // !defined(ENABLE_MKLDNN_THREADPOOL) && defined(INTEL_MKL)
#endif // defined(ENABLE_ONEDNN_OPENMP) && defined(INTEL_MKL)
}
ThreadPoolDevice::~ThreadPoolDevice() {}
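
The expression (mkl_intra_op + ht - 1) / ht in the hunk above is a ceiling division of the intra-op thread budget by the hyperthreads-per-core count, so OpenMP ends up with roughly one thread per physical core. As a hedged standalone helper (names illustrative, not TF's):

#include <algorithm>

// Ceiling division of the thread budget by threads-per-core.
int OmpThreadsFromBudget(int mkl_intra_op, int threads_per_core) {
  const int ht = std::max(threads_per_core, 1);  // guard against ht == 0
  return (mkl_intra_op + ht - 1) / ht;
}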

View File

@@ -280,7 +280,7 @@ class MklConcatFwdPrimitive : public MklPrimitive {
std::shared_ptr<stream> fwd_stream) {
DCHECK_EQ(in_data.size(), context_.data_mem.size());
for (size_t i = 0; i < concat_fwd_dims.num_inputs; i++) {
#ifdef ENABLE_MKLDNN_THREADPOOL
#ifndef ENABLE_ONEDNN_OPENMP
context_.data_mem_shdptr[i]->set_data_handle(
static_cast<void*>(in_data[i].get_data_handle()), *fwd_stream);
}
@@ -292,7 +292,7 @@ class MklConcatFwdPrimitive : public MklPrimitive {
}
context_.dst_mem->set_data_handle(
static_cast<void*>(dst_data.get_data_handle()));
#endif // ENABLE_MKLDNN_THREADPOOL
#endif // !ENABLE_ONEDNN_OPENMP
for (size_t i = 0; i < concat_fwd_dims.num_inputs; i++) {
context_.data_mem[i] = *context_.data_mem_shdptr[i];
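
This Execute() shape repeats across most of the kernels that follow: when the oneDNN threadpool runtime is in use (ENABLE_ONEDNN_OPENMP undefined), set_data_handle takes the execution stream, whereas the OpenMP runtime takes the raw handle alone. A condensed sketch of the shared pattern, assuming the dnnl.hpp memory/stream API these kernels already use:

#include "dnnl.hpp"  // oneDNN C++ API, as included by these kernels

// Rebind a user buffer to a dnnl::memory before executing primitives.
void BindHandle(dnnl::memory& mem, void* buffer, dnnl::stream& stream) {
#ifndef ENABLE_ONEDNN_OPENMP
  mem.set_data_handle(buffer, stream);  // threadpool runtime: stream-aware
#else
  mem.set_data_handle(buffer);          // OpenMP runtime: handle only
#endif  // !ENABLE_ONEDNN_OPENMP
}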

View File

@@ -104,7 +104,7 @@ class MklConvBwdFilterPrimitive : public MklPrimitive {
void Execute(const T* src_data, const T* diff_filter_data,
const T* diff_bias_data, const T* diff_dst_data,
std::shared_ptr<stream> bwd_filter_stream) {
#ifdef ENABLE_MKLDNN_THREADPOOL
#ifndef ENABLE_ONEDNN_OPENMP
// TODO: Create a common function and avoid the duplicate code
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)), *bwd_filter_stream);
@@ -129,7 +129,7 @@ class MklConvBwdFilterPrimitive : public MklPrimitive {
}
context_.diff_dst_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(diff_dst_data)));
#endif // ENABLE_MKLDNN_THREADPOOL
#endif // !ENABLE_ONEDNN_OPENMP
execute_primitives(context_.bwd_filter_primitives, bwd_filter_stream,
context_.bwd_filter_primitives_args);

View File

@@ -106,7 +106,7 @@ class MklConvBwdInputPrimitive : public MklPrimitive {
void Execute(const T* diff_src_data, const T* filter_data,
const T* diff_dst_data,
std::shared_ptr<stream> bwd_input_stream) {
#ifdef ENABLE_MKLDNN_THREADPOOL
#ifndef ENABLE_ONEDNN_OPENMP
// TODO: Create a common function and avoid the duplicate code
context_.diff_src_mem->set_data_handle(
static_cast<T*>(const_cast<T*>(diff_src_data)), *bwd_input_stream);
@@ -121,7 +121,7 @@ class MklConvBwdInputPrimitive : public MklPrimitive {
static_cast<T*>(const_cast<T*>(filter_data)));
context_.diff_dst_mem->set_data_handle(
static_cast<T*>(const_cast<T*>(diff_dst_data)));
#endif // ENABLE_MKLDNN_THREADPOOL
#endif // !ENABLE_ONEDNN_OPENMP
execute_primitives(context_.bwd_input_primitives, bwd_input_stream,
context_.bwd_input_primitives_args);

View File

@@ -114,7 +114,7 @@ class MklConvFwdPrimitive : public MklPrimitive {
void Execute(const Tinput* src_data, const Tfilter* filter_data,
const Tbias* bias_data, const Toutput* dst_data,
std::shared_ptr<stream> fwd_stream) {
#ifdef ENABLE_MKLDNN_THREADPOOL
#ifndef ENABLE_ONEDNN_OPENMP
// TODO: Create a common function and avoid the duplicate code
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<Tinput*>(src_data)), *fwd_stream);
@@ -137,7 +137,7 @@ class MklConvFwdPrimitive : public MklPrimitive {
}
context_.dst_mem->set_data_handle(
static_cast<void*>(const_cast<Toutput*>(dst_data)));
#endif // ENABLE_MKLDNN_THREADPOOL
#endif // !ENABLE_ONEDNN_OPENMP
DCHECK_EQ(context_.fwd_primitives.size(),
context_.fwd_primitives_args.size());

View File

@@ -78,7 +78,7 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
void Execute(const T* src_data, const U* weights_data, T* dst_data,
U* mean_data, U* variance_data,
std::shared_ptr<stream> fwd_stream, U* workspace_data) {
#ifdef ENABLE_MKLDNN_THREADPOOL
#ifndef ENABLE_ONEDNN_OPENMP
// TODO: Create a common function and avoid the duplicate code
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)), *fwd_stream);
@@ -116,7 +116,7 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
if (workspace_data != nullptr) {
context_.ws_mem->set_data_handle(workspace_data);
}
#endif // ENABLE_MKLDNN_THREADPOOL
#endif // !ENABLE_ONEDNN_OPENMP
// Execute batch-normalization forward primitives.
execute_primitives(context_.fwd_primitives, fwd_stream, context_.net_args);
@@ -422,7 +422,7 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive {
const T* diff_dst_data, const U* weights_data, T* diff_src_data,
U* diff_weights_data, U* res_space_data,
std::shared_ptr<stream> bwd_stream) {
#ifdef ENABLE_MKLDNN_THREADPOOL
#ifndef ENABLE_ONEDNN_OPENMP
// TODO: Create a common function and avoid the duplicate code
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)), *bwd_stream);
@@ -460,7 +460,7 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive {
}
context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data));
#endif // ENABLE_MKLDNN_THREADPOOL
#endif // !ENABLE_ONEDNN_OPENMP
// Execute backward batch-normalization primitives.
DCHECK_EQ(context_.bwd_primitives.size(), context_.net_args.size());
execute_primitives(context_.bwd_primitives, bwd_stream, context_.net_args);

View File

@@ -1211,7 +1211,7 @@ INSTANTIATE_TYPED_TEST_SUITE_P(Test, MklFusedMatMulOpTest,
// This test is flaky for --config=mkl_threadpool (The supposedly cached op
// sometimes took longer than even 0.9 * original_time.)
// TODO(intel-tf): Re-enable the test for --config=mkl_threadpool.
#ifndef ENABLE_MKLDNN_THREADPOOL
#ifdef ENABLE_ONEDNN_OPENMP
// Test the performance of MklFusedMatMul weight cache.
// For the first time B matrix will be reordered and cached which will be
// used for subsequent runs
@@ -1314,7 +1314,7 @@ TEST_F(MklFusedMatMulCacheTest, WeightCached) {
test::ExpectTensorNear<float>(expected, output_new, 1e-5);
}
}
#endif // ENABLE_MKLDNN_THREADPOOL
#endif // ENABLE_ONEDNN_OPENMP
class BiasCacheTest : public OpsTestBase {
public:

View File

@@ -155,14 +155,14 @@ class MklMatMulOp : public OpKernel {
char char_transa = transa ? 'T' : 'N';
char char_transb = transb ? 'T' : 'N';
VLOG(2) << "MKL DNN SGEMM called";
#ifdef ENABLE_MKLDNN_THREADPOOL
#ifndef ENABLE_ONEDNN_OPENMP
MklDnnThreadPool eigen_tp(ctx);
dnnl_sgemm_tp(char_transa, char_transb, m, n, k, alpha, a, lda, b, ldb,
beta, c, ldc, &eigen_tp);
#else
dnnl_sgemm(char_transa, char_transb, m, n, k, alpha, a, lda, b, ldb, beta,
c, ldc);
#endif // ENABLE_MKLDNN_THREADPOOL
#endif // !ENABLE_ONEDNN_OPENMP
}
void MklBlasGemm(OpKernelContext* ctx, bool transa, bool transb, const int m,
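
In OpenMP builds the kernel calls oneDNN's plain C entry point dnnl_sgemm; dnnl_sgemm_tp is TF's threadpool-aware wrapper. A tiny standalone usage sketch of the plain path, assuming oneDNN's C header and library are available:

#include <array>
#include <cstdio>
#include "dnnl.h"

// Computes C = 1.0 * A * B + 0.0 * C for row-major 2x2 matrices
// ('N' means no transpose), printing 19 22 43 50.
int main() {
  std::array<float, 4> a{1, 2, 3, 4}, b{5, 6, 7, 8}, c{};
  if (dnnl_sgemm('N', 'N', 2, 2, 2, 1.0f, a.data(), 2, b.data(), 2, 0.0f,
                 c.data(), 2) != dnnl_success)
    return 1;
  std::printf("%g %g %g %g\n", c[0], c[1], c[2], c[3]);
  return 0;
}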

View File

@@ -95,7 +95,7 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
void Execute(const Tinput* src_data, const Tweight* weight_data,
const Tbias* bias_data, Toutput* dst_data,
std::shared_ptr<stream> fwd_stream) {
#ifdef ENABLE_MKLDNN_THREADPOOL
#ifndef ENABLE_ONEDNN_OPENMP
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<Tinput*>(src_data)), *fwd_stream);
context_.weight_mem->set_data_handle(
@@ -112,7 +112,7 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
context_.bias_mem->set_data_handle(
static_cast<void*>(const_cast<Tbias*>(bias_data)));
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
#endif // ENABLE_MKLDNN_THREADPOOL
#endif // !ENABLE_ONEDNN_OPENMP
execute_primitives(context_.fwd_primitives, fwd_stream, context_.net_args);
@@ -534,7 +534,7 @@ class MklMatMulPrimitive : public MklPrimitive {
void Execute(const T* a_data, const T* b_data, T* c_data,
std::shared_ptr<stream> stream) {
#ifdef ENABLE_MKLDNN_THREADPOOL
#ifndef ENABLE_ONEDNN_OPENMP
context_.a_mem->set_data_handle(static_cast<void*>(const_cast<T*>(a_data)),
*stream);
context_.b_mem->set_data_handle(static_cast<void*>(const_cast<T*>(b_data)),
@@ -545,7 +545,7 @@ class MklMatMulPrimitive : public MklPrimitive {
context_.a_mem->set_data_handle(static_cast<void*>(const_cast<T*>(a_data)));
context_.b_mem->set_data_handle(static_cast<void*>(const_cast<T*>(b_data)));
context_.c_mem->set_data_handle(static_cast<void*>(const_cast<T*>(c_data)));
#endif // ENABLE_MKLDNN_THREADPOOL
#endif // !ENABLE_ONEDNN_OPENMP
execute_primitives(context_.matmul_primitives, stream, context_.net_args);
// After execution, set data handle back

View File

@@ -86,7 +86,7 @@ template <typename T>
void MklPoolingFwdPrimitive<T>::Execute(const T* src_data, T* dst_data,
void* ws_data,
std::shared_ptr<stream> fwd_stream) {
#ifdef ENABLE_MKLDNN_THREADPOOL
#ifndef ENABLE_ONEDNN_OPENMP
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)), *fwd_stream);
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data), *fwd_stream);
@@ -106,7 +106,7 @@ void MklPoolingFwdPrimitive<T>::Execute(const T* src_data, T* dst_data,
DCHECK(ws_data != nullptr);
context_.ws_mem->set_data_handle(ws_data);
}
#endif // ENABLE_MKLDNN_THREADPOOL
#endif // !ENABLE_ONEDNN_OPENMP
execute_primitives(context_.fwd_primitives, fwd_stream, context_.net_args);
// Set back data handle.
@@ -188,7 +188,7 @@ template <typename T>
void MklPoolingBwdPrimitive<T>::Execute(const T* diff_dst_data,
T* diff_src_data, const void* ws_data,
std::shared_ptr<stream> bwd_stream) {
#ifdef ENABLE_MKLDNN_THREADPOOL
#ifndef ENABLE_ONEDNN_OPENMP
context_.diff_dst_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(diff_dst_data)), *bwd_stream);
context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data),
@@ -205,7 +205,7 @@ void MklPoolingBwdPrimitive<T>::Execute(const T* diff_dst_data,
DCHECK(ws_data != nullptr);
context_.ws_mem->set_data_handle(const_cast<void*>(ws_data));
}
#endif // ENABLE_MKLDNN_THREADPOOL
#endif // !ENABLE_ONEDNN_OPENMP
execute_primitives(context_.bwd_primitives, bwd_stream, context_.net_args);

View File

@@ -431,7 +431,7 @@ class MklDnnQuantizedMatMulOp : public MklDnnMatMulOpBase<Tweight, Toutput> {
((max_input - min_input) *
std::max(std::abs(max_weight), std::abs(min_weight)));
#ifdef ENABLE_MKLDNN_THREADPOOL
#ifndef ENABLE_ONEDNN_OPENMP
auto parallel_func = [&](int64 start, int64 end) {
for (int64 j = start; j < end; j++) {
int x = 0;
@@ -460,7 +460,7 @@ class MklDnnQuantizedMatMulOp : public MklDnnMatMulOpBase<Tweight, Toutput> {
comp_bias[j] =
((bias_buf[j] * out_scale) + static_cast<float>(x * qa_amin));
}
#endif // ENABLE_MKLDNN_THREADPOOL
#endif // !ENABLE_ONEDNN_OPENMP
return reinterpret_cast<Tbias*>(comp_bias_);
} else if (mode_ == QUANTIZE_MODE_SCALED) {

View File

@@ -87,13 +87,13 @@ class MklReorderWithScalePrimitive : public MklPrimitive {
void Execute(void* src_data, void* dst_data,
std::shared_ptr<stream> reorder_stream) {
#ifdef ENABLE_MKLDNN_THREADPOOL
#ifndef ENABLE_ONEDNN_OPENMP
context_.src_mem->set_data_handle(src_data, *reorder_stream);
context_.dst_mem->set_data_handle(dst_data, *reorder_stream);
#else
context_.src_mem->set_data_handle(src_data);
context_.dst_mem->set_data_handle(dst_data);
#endif // ENABLE_MKLDNN_THREADPOOL
#endif // !ENABLE_ONEDNN_OPENMP
context_.reorder_prim->execute(*reorder_stream, context_.prim_args);
// After execution, set data handle back.
context_.src_mem->set_data_handle(DummyData);

View File

@@ -70,9 +70,9 @@ void MklQuantizationRangeForMultiplication(float min_a, float max_a,
float* min_c = (*min_c_vector)->flat<float>().data();
float* max_c = (*max_c_vector)->flat<float>().data();
#ifndef ENABLE_MKLDNN_THREADPOOL
#ifdef ENABLE_ONEDNN_OPENMP
#pragma omp parallel for
#endif // !ENABLE_MKLDNN_THREADPOOL
#endif // ENABLE_ONEDNN_OPENMP
// TODO: Add eigen parallel_for
for (int64_t n = 0; n < n_channel; ++n) {
float a_float_for_one_quant_level =
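
The pragma's guard inverts here as well: #pragma omp parallel for is emitted only when ENABLE_ONEDNN_OPENMP is defined, and other builds run the per-channel loop serially until the TODO'd Eigen parallel_for replacement lands. The shape of the pattern as a hedged standalone sketch:

#include <cstdint>
#include <vector>

// Scales every element; parallel only in OpenMP builds.
void ScaleChannels(std::vector<float>& v, float scale) {
#ifdef ENABLE_ONEDNN_OPENMP
#pragma omp parallel for
#endif  // ENABLE_ONEDNN_OPENMP
  for (std::int64_t i = 0; i < static_cast<std::int64_t>(v.size()); ++i) {
    v[i] *= scale;
  }
}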

View File

@@ -74,7 +74,7 @@ class MklEltwiseFwdPrimitive : public MklPrimitive {
// dst_data: output data buffer of dst
void Execute(const T* src_data, T* dst_data,
std::shared_ptr<stream> fwd_stream) {
#ifdef ENABLE_MKLDNN_THREADPOOL
#ifndef ENABLE_ONEDNN_OPENMP
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)), *fwd_stream);
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data),
@@ -83,7 +83,7 @@ class MklEltwiseFwdPrimitive : public MklPrimitive {
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)));
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
#endif // ENABLE_MKLDNN_THREADPOOL
#endif // !ENABLE_ONEDNN_OPENMP
DCHECK_EQ(context_.fwd_primitives.size(),
context_.fwd_primitives_args.size());
execute_primitives(context_.fwd_primitives, fwd_stream,
@@ -255,7 +255,7 @@ class MklEltwiseBwdPrimitive : public MklPrimitive {
// diff_src_data: output data buffer of diff_src
void Execute(const T* src_data, const T* diff_dst_data, T* diff_src_data,
std::shared_ptr<stream> bwd_stream) {
#ifdef ENABLE_MKLDNN_THREADPOOL
#ifndef ENABLE_ONEDNN_OPENMP
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)), *bwd_stream);
context_.diff_dst_mem->set_data_handle(
@@ -268,7 +268,7 @@ class MklEltwiseBwdPrimitive : public MklPrimitive {
context_.diff_dst_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(diff_dst_data)));
context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data));
#endif // ENABLE_MKLDNN_THREADPOOL
#endif // !ENABLE_ONEDNN_OPENMP
DCHECK_EQ(context_.bwd_primitives.size(),
context_.bwd_primitives_args.size());
execute_primitives(context_.bwd_primitives, bwd_stream,

View File

@@ -76,13 +76,13 @@ class MklRequantizationRangePerChannelOp : public OpKernel {
// Find the ranges of each channel in parallel.
float out_min_max = std::numeric_limits<float>::min();
#ifndef ENABLE_MKLDNN_THREADPOOL
#ifdef ENABLE_ONEDNN_OPENMP
#ifdef _MSC_VER
#pragma omp parallel for
#else
#pragma omp parallel for reduction(max : out_min_max)
#endif
#endif // !ENABLE_MKLDNN_THREADPOOL
#endif // ENABLE_ONEDNN_OPENMP
// TODO: Add eigen parallel_for
for (int64_t i = 0; i < depth; ++i) {
Eigen::Tensor<qint32, 0, Eigen::RowMajor> min =

View File

@@ -185,7 +185,7 @@ class MklSlicePrimitive : public MklPrimitive {
void Execute(const MklSliceParams& sliceParams,
std::shared_ptr<stream> slice_stream) {
#ifdef ENABLE_MKLDNN_THREADPOOL
#ifndef ENABLE_ONEDNN_OPENMP
context_.src_mem->set_data_handle(sliceParams.from->get_data_handle(),
*slice_stream);
context_.dst_mem->set_data_handle(sliceParams.to->get_data_handle(),
@@ -193,7 +193,7 @@ class MklSlicePrimitive : public MklPrimitive {
#else
context_.src_mem->set_data_handle(sliceParams.from->get_data_handle());
context_.dst_mem->set_data_handle(sliceParams.to->get_data_handle());
#endif // ENABLE_MKLDNN_THREADPOOL
#endif // !ENABLE_ONEDNN_OPENMP
execute_primitives(context_.slice_primitives, slice_stream,
context_.slice_primitives_args);

View File

@@ -58,7 +58,7 @@ class MklSoftmaxPrimitive : public MklPrimitive {
// dst_data: output data buffer of dst
void Execute(const T* src_data, T* dst_data,
std::shared_ptr<stream> fwd_cpu_stream) {
#ifdef ENABLE_MKLDNN_THREADPOOL
#ifndef ENABLE_ONEDNN_OPENMP
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)), *fwd_cpu_stream);
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data),
@@ -67,7 +67,7 @@ class MklSoftmaxPrimitive : public MklPrimitive {
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)));
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
#endif // ENABLE_MKLDNN_THREADPOOL
#endif // !ENABLE_ONEDNN_OPENMP
DCHECK_EQ(context_.fwd_primitives.size(), context_.fwd_net_args.size());
execute_primitives(context_.fwd_primitives, fwd_cpu_stream,

View File

@@ -32,7 +32,7 @@ limitations under the License.
namespace tensorflow {
#ifdef ENABLE_MKLDNN_THREADPOOL
#ifndef ENABLE_ONEDNN_OPENMP
using dnnl::stream_attr;
using dnnl::threadpool_iface;
@@ -116,7 +116,7 @@ struct MklDnnThreadPool {
MklDnnThreadPool(OpKernelContext* ctx) {}
};
#endif // ENABLE_MKLDNN_THREADPOOL
#endif // !ENABLE_ONEDNN_OPENMP
} // namespace tensorflow
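
This header keeps a single MklDnnThreadPool name with two definitions: a real adapter over the context's Eigen threadpool in threadpool builds, and the empty stub shown above in OpenMP builds, so call sites such as MklDnnThreadPool eigen_tp(ctx) compile either way. A sketch of the dual-definition idiom with placeholder types (Context stands in for OpKernelContext):

struct Context;  // stand-in for OpKernelContext

#ifndef ENABLE_ONEDNN_OPENMP
struct Pool {
  explicit Pool(Context* ctx) { /* wrap ctx's Eigen threadpool here */ }
};
#else
struct Pool {
  explicit Pool(Context*) {}  // no-op stub in OpenMP builds
};
#endif  // !ENABLE_ONEDNN_OPENMP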

View File

@@ -224,7 +224,7 @@ inline bool array_cmp(const T* a1, const T* a2, size_t size) {
inline mkldnn::stream* CreateStream(MklDnnThreadPool* eigen_tp,
const engine& engine) {
#ifdef ENABLE_MKLDNN_THREADPOOL
#ifndef ENABLE_ONEDNN_OPENMP
stream_attr tp_stream_attr(engine::kind::cpu);
if (eigen_tp != nullptr) {
tp_stream_attr.set_threadpool(eigen_tp);
@@ -238,7 +238,7 @@ inline mkldnn::stream* CreateStream(MklDnnThreadPool* eigen_tp,
#else
stream* tp_stream = new stream(engine);
return tp_stream;
#endif // ENABLE_MKLDNN_THREADPOOL
#endif // !ENABLE_ONEDNN_OPENMP
}
class MklDnnShape {
@@ -1390,11 +1390,11 @@ class MklDnnData {
std::shared_ptr<stream> t_stream = nullptr) {
CHECK_NOTNULL(user_memory_);
CHECK_NOTNULL(data_buffer);
#ifdef ENABLE_MKLDNN_THREADPOOL
#ifndef ENABLE_ONEDNN_OPENMP
user_memory_->set_data_handle(data_buffer, *t_stream);
#else
user_memory_->set_data_handle(data_buffer);
#endif // ENABLE_MKLDNN_THREADPOOL
#endif // !ENABLE_ONEDNN_OPENMP
}
/// Set function for data buffer of user memory primitive.

View File

@@ -39,7 +39,7 @@ load(
load(
"//third_party/mkl_dnn:build_defs.bzl",
"if_mkl_open_source_only",
"if_mkldnn_threadpool",
"if_mkldnn_openmp",
)
load("@bazel_skylib//lib:new_sets.bzl", "sets")
load("@bazel_skylib//rules:common_settings.bzl", "BuildSettingInfo")
@@ -364,7 +364,7 @@ def tf_copts(
if_xla_available(["-DTENSORFLOW_USE_XLA=1"]) +
if_tensorrt(["-DGOOGLE_TENSORRT=1"]) +
if_mkl(["-DINTEL_MKL=1"]) +
if_mkldnn_threadpool(["-DENABLE_MKLDNN_THREADPOOL"]) +
if_mkldnn_openmp(["-DENABLE_ONEDNN_OPENMP"]) +
if_enable_mkl(["-DENABLE_MKL"]) +
if_android_arm(["-mfpu=neon"]) +
if_linux_x86_64(["-msse3"]) +

View File

@@ -19,11 +19,10 @@ config_setting(
)
config_setting(
name = "build_with_mkldnn_threadpool",
name = "build_with_mkldnn_openmp",
define_values = {
"build_with_mkl": "true",
"build_with_mkl_opensource": "true",
"build_with_mkldnn_threadpool": "true",
"build_with_openmp": "true",
},
visibility = ["//visibility:public"],
)

View File

@@ -14,18 +14,18 @@ def if_mkl_open_source_only(if_true, if_false = []):
"//conditions:default": if_false,
})
def if_mkldnn_threadpool(if_true, if_false = []):
"""Returns `if_true` if MKL-DNN v1.x is used.
def if_mkldnn_openmp(if_true, if_false = []):
"""Returns `if_true` if OpenMP is used with oneDNN.
Shorthand for select()'ing on whether we're building with
MKL-DNN v1.x open source library only with user specified threadpool, without depending on MKL binary form.
oneDNN open source library only with OpenMP.
Returns a select statement which evaluates to if_true if we're building
with MKL-DNN v1.x open source library only with user specified threadpool. Otherwise, the
with oneDNN open source library only with OpenMP. Otherwise, the
select statement evaluates to if_false.
"""
return select({
"@org_tensorflow//third_party/mkl_dnn:build_with_mkldnn_threadpool": if_true,
"@org_tensorflow//third_party/mkl_dnn:build_with_mkldnn_openmp": if_true,
"//conditions:default": if_false,
})

View File

@@ -10,8 +10,7 @@ load(
)
load(
"@org_tensorflow//third_party/mkl_dnn:build_defs.bzl",
"if_mkl_open_source_only",
"if_mkldnn_threadpool",
"if_mkldnn_openmp",
)
load(
"@org_tensorflow//third_party/mkl:build_defs.bzl",
@@ -45,8 +44,8 @@ template_rule(
src = "include/dnnl_config.h.in",
out = "include/dnnl_config.h",
substitutions = select({
"@org_tensorflow//third_party/mkl_dnn:build_with_mkldnn_threadpool": _DNNL_RUNTIME_THREADPOOL,
"@org_tensorflow//third_party/mkl:build_with_mkl": _DNNL_RUNTIME_OMP,
"@org_tensorflow//third_party/mkl_dnn:build_with_mkldnn_openmp": _DNNL_RUNTIME_OMP,
"@org_tensorflow//third_party/mkl:build_with_mkl": _DNNL_RUNTIME_THREADPOOL,
"//conditions:default": _DNNL_RUNTIME_SEQ,
}),
)