Internal change
PiperOrigin-RevId: 361090132 Change-Id: Ia8e3efb6e794d83c4ef506fccf830370ec6780d0
This commit is contained in:
parent
f0359d50ea
commit
460e000de3
9
.bazelrc
9
.bazelrc
@ -172,20 +172,13 @@ build:mkl -c opt
|
|||||||
build:mkl_threadpool --define=build_with_mkl=true --define=enable_mkl=true
|
build:mkl_threadpool --define=build_with_mkl=true --define=enable_mkl=true
|
||||||
build:mkl_threadpool --define=tensorflow_mkldnn_contraction_kernel=0
|
build:mkl_threadpool --define=tensorflow_mkldnn_contraction_kernel=0
|
||||||
build:mkl_threadpool --define=build_with_mkl_opensource=true
|
build:mkl_threadpool --define=build_with_mkl_opensource=true
|
||||||
build:mkl_threadpool --define=build_with_mkldnn_threadpool=true
|
|
||||||
build:mkl_threadpool -c opt
|
build:mkl_threadpool -c opt
|
||||||
|
|
||||||
# Config setting to build with oneDNN and without the binary blob
|
|
||||||
build:mkl_opensource_only --define=build_with_mkl=true --define=enable_mkl=true
|
|
||||||
build:mkl_opensource_only --define=tensorflow_mkldnn_contraction_kernel=0
|
|
||||||
build:mkl_opensource_only --define=build_with_mkl_opensource=true
|
|
||||||
build:mkl_opensource_only --define=build_with_openmp=true
|
|
||||||
build:mkl_opensource_only -c opt
|
|
||||||
|
|
||||||
# Config setting to build with oneDNN for Arm.
|
# Config setting to build with oneDNN for Arm.
|
||||||
build:mkl_aarch64 --define=build_with_mkl_aarch64=true --define=enable_mkl=true
|
build:mkl_aarch64 --define=build_with_mkl_aarch64=true --define=enable_mkl=true
|
||||||
build:mkl_aarch64 --define=tensorflow_mkldnn_contraction_kernel=0
|
build:mkl_aarch64 --define=tensorflow_mkldnn_contraction_kernel=0
|
||||||
build:mkl_aarch64 --define=build_with_mkl_opensource=true
|
build:mkl_aarch64 --define=build_with_mkl_opensource=true
|
||||||
|
build:mkl_aarch64 --define=build_with_openmp=true
|
||||||
build:mkl_aarch64 -c opt
|
build:mkl_aarch64 -c opt
|
||||||
|
|
||||||
# This config refers to building with CUDA available. It does not necessarily
|
# This config refers to building with CUDA available. It does not necessarily
|
||||||
|
@ -24,7 +24,7 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
#if defined(_OPENMP) && !defined(ENABLE_MKLDNN_THREADPOOL)
|
#if defined(_OPENMP) && defined(ENABLE_ONEDNN_OPENMP)
|
||||||
TEST(MKLThreadPoolDeviceTest, TestOmpDefaults) {
|
TEST(MKLThreadPoolDeviceTest, TestOmpDefaults) {
|
||||||
SessionOptions options;
|
SessionOptions options;
|
||||||
unsetenv("OMP_NUM_THREADS");
|
unsetenv("OMP_NUM_THREADS");
|
||||||
@ -36,7 +36,7 @@ TEST(MKLThreadPoolDeviceTest, TestOmpDefaults) {
|
|||||||
EXPECT_EQ(omp_get_max_threads(), (port::NumSchedulableCPUs() + ht - 1) / ht);
|
EXPECT_EQ(omp_get_max_threads(), (port::NumSchedulableCPUs() + ht - 1) / ht);
|
||||||
}
|
}
|
||||||
|
|
||||||
#endif // defined(_OPENMP) && !defined(ENABLE_MKLDNN_THREADPOOL)
|
#endif // defined(_OPENMP) && defined(ENABLE_ONEDNN_OPENMP)
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
@ -101,7 +101,7 @@ int32 NumIntraOpThreadsFromEnvironment() {
|
|||||||
const char* val = std::getenv("TF_NUM_INTRAOP_THREADS");
|
const char* val = std::getenv("TF_NUM_INTRAOP_THREADS");
|
||||||
return (val && strings::safe_strto32(val, &num)) ? num : 0;
|
return (val && strings::safe_strto32(val, &num)) ? num : 0;
|
||||||
}
|
}
|
||||||
#if !defined(ENABLE_MKLDNN_THREADPOOL) && defined(INTEL_MKL)
|
#if defined(ENABLE_ONEDNN_OPENMP) && defined(INTEL_MKL)
|
||||||
int32 OMPThreadsFromEnvironment() {
|
int32 OMPThreadsFromEnvironment() {
|
||||||
// 1) std::getenv is thread-safe (as long as no other function modifies the
|
// 1) std::getenv is thread-safe (as long as no other function modifies the
|
||||||
// host env) from C++11 onward. 2) Most of TF code (except tests and
|
// host env) from C++11 onward. 2) Most of TF code (except tests and
|
||||||
@ -121,14 +121,14 @@ int32 DefaultNumIntraOpThreads() {
|
|||||||
// Default to the maximum parallelism for the current process.
|
// Default to the maximum parallelism for the current process.
|
||||||
return port::MaxParallelism();
|
return port::MaxParallelism();
|
||||||
}
|
}
|
||||||
#endif // !defined(ENABLE_MKLDNN_THREADPOOL) && defined(INTEL_MKL)
|
#endif // defined(ENABLE_ONEDNN_OPENMP) && defined(INTEL_MKL)
|
||||||
int32 NumInterOpThreadsFromSessionOptions(const SessionOptions& options) {
|
int32 NumInterOpThreadsFromSessionOptions(const SessionOptions& options) {
|
||||||
const int32 inter_op = options.config.inter_op_parallelism_threads();
|
const int32 inter_op = options.config.inter_op_parallelism_threads();
|
||||||
if (inter_op > 0) return inter_op;
|
if (inter_op > 0) return inter_op;
|
||||||
const int32 env_inter_op = GetEnvNumInterOpThreads();
|
const int32 env_inter_op = GetEnvNumInterOpThreads();
|
||||||
if (env_inter_op > 0) return env_inter_op;
|
if (env_inter_op > 0) return env_inter_op;
|
||||||
|
|
||||||
#if !defined(ENABLE_MKLDNN_THREADPOOL) && defined(INTEL_MKL)
|
#if defined(ENABLE_ONEDNN_OPENMP) && defined(INTEL_MKL)
|
||||||
if (!DisableMKL()) {
|
if (!DisableMKL()) {
|
||||||
// MKL library executes ops in parallel using OMP threads.
|
// MKL library executes ops in parallel using OMP threads.
|
||||||
// Setting inter_op conservatively to avoid thread oversubscription that
|
// Setting inter_op conservatively to avoid thread oversubscription that
|
||||||
@ -149,7 +149,7 @@ int32 NumInterOpThreadsFromSessionOptions(const SessionOptions& options) {
|
|||||||
<< ". Tune using inter_op_parallelism_threads for best performance.";
|
<< ". Tune using inter_op_parallelism_threads for best performance.";
|
||||||
return mkl_inter_op;
|
return mkl_inter_op;
|
||||||
}
|
}
|
||||||
#endif // !defined(ENABLE_MKLDNN_THREADPOOL) && defined(INTEL_MKL)
|
#endif // defined(ENABLE_ONEDNN_OPENMP) && defined(INTEL_MKL)
|
||||||
return DefaultNumInterOpThreads();
|
return DefaultNumInterOpThreads();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -50,7 +50,7 @@ ThreadPoolDevice::ThreadPoolDevice(const SessionOptions& options,
|
|||||||
name, DEVICE_CPU, memory_limit, locality)),
|
name, DEVICE_CPU, memory_limit, locality)),
|
||||||
allocator_(allocator),
|
allocator_(allocator),
|
||||||
scoped_allocator_mgr_(new ScopedAllocatorMgr(name)) {
|
scoped_allocator_mgr_(new ScopedAllocatorMgr(name)) {
|
||||||
#if !defined(ENABLE_MKLDNN_THREADPOOL) && defined(INTEL_MKL)
|
#if defined(ENABLE_ONEDNN_OPENMP) && defined(INTEL_MKL)
|
||||||
// Early return when MKL is disabled
|
// Early return when MKL is disabled
|
||||||
if (DisableMKL()) return;
|
if (DisableMKL()) return;
|
||||||
#ifdef _OPENMP
|
#ifdef _OPENMP
|
||||||
@ -65,7 +65,7 @@ ThreadPoolDevice::ThreadPoolDevice(const SessionOptions& options,
|
|||||||
(mkl_intra_op + ht - 1) / ht);
|
(mkl_intra_op + ht - 1) / ht);
|
||||||
}
|
}
|
||||||
#endif // _OPENMP
|
#endif // _OPENMP
|
||||||
#endif // !defined(ENABLE_MKLDNN_THREADPOOL) && defined(INTEL_MKL)
|
#endif // defined(ENABLE_ONEDNN_OPENMP) && defined(INTEL_MKL)
|
||||||
}
|
}
|
||||||
|
|
||||||
ThreadPoolDevice::~ThreadPoolDevice() {}
|
ThreadPoolDevice::~ThreadPoolDevice() {}
|
||||||
|
@ -280,7 +280,7 @@ class MklConcatFwdPrimitive : public MklPrimitive {
|
|||||||
std::shared_ptr<stream> fwd_stream) {
|
std::shared_ptr<stream> fwd_stream) {
|
||||||
DCHECK_EQ(in_data.size(), context_.data_mem.size());
|
DCHECK_EQ(in_data.size(), context_.data_mem.size());
|
||||||
for (size_t i = 0; i < concat_fwd_dims.num_inputs; i++) {
|
for (size_t i = 0; i < concat_fwd_dims.num_inputs; i++) {
|
||||||
#ifdef ENABLE_MKLDNN_THREADPOOL
|
#ifndef ENABLE_ONEDNN_OPENMP
|
||||||
context_.data_mem_shdptr[i]->set_data_handle(
|
context_.data_mem_shdptr[i]->set_data_handle(
|
||||||
static_cast<void*>(in_data[i].get_data_handle()), *fwd_stream);
|
static_cast<void*>(in_data[i].get_data_handle()), *fwd_stream);
|
||||||
}
|
}
|
||||||
@ -292,7 +292,7 @@ class MklConcatFwdPrimitive : public MklPrimitive {
|
|||||||
}
|
}
|
||||||
context_.dst_mem->set_data_handle(
|
context_.dst_mem->set_data_handle(
|
||||||
static_cast<void*>(dst_data.get_data_handle()));
|
static_cast<void*>(dst_data.get_data_handle()));
|
||||||
#endif // ENABLE_MKLDNN_THREADPOOL
|
#endif // !ENABLE_ONEDNN_OPENMP
|
||||||
|
|
||||||
for (size_t i = 0; i < concat_fwd_dims.num_inputs; i++) {
|
for (size_t i = 0; i < concat_fwd_dims.num_inputs; i++) {
|
||||||
context_.data_mem[i] = *context_.data_mem_shdptr[i];
|
context_.data_mem[i] = *context_.data_mem_shdptr[i];
|
||||||
|
@ -104,7 +104,7 @@ class MklConvBwdFilterPrimitive : public MklPrimitive {
|
|||||||
void Execute(const T* src_data, const T* diff_filter_data,
|
void Execute(const T* src_data, const T* diff_filter_data,
|
||||||
const T* diff_bias_data, const T* diff_dst_data,
|
const T* diff_bias_data, const T* diff_dst_data,
|
||||||
std::shared_ptr<stream> bwd_filter_stream) {
|
std::shared_ptr<stream> bwd_filter_stream) {
|
||||||
#ifdef ENABLE_MKLDNN_THREADPOOL
|
#ifndef ENABLE_ONEDNN_OPENMP
|
||||||
// TODO: Create a common function and avoid the duplicate code
|
// TODO: Create a common function and avoid the duplicate code
|
||||||
context_.src_mem->set_data_handle(
|
context_.src_mem->set_data_handle(
|
||||||
static_cast<void*>(const_cast<T*>(src_data)), *bwd_filter_stream);
|
static_cast<void*>(const_cast<T*>(src_data)), *bwd_filter_stream);
|
||||||
@ -129,7 +129,7 @@ class MklConvBwdFilterPrimitive : public MklPrimitive {
|
|||||||
}
|
}
|
||||||
context_.diff_dst_mem->set_data_handle(
|
context_.diff_dst_mem->set_data_handle(
|
||||||
static_cast<void*>(const_cast<T*>(diff_dst_data)));
|
static_cast<void*>(const_cast<T*>(diff_dst_data)));
|
||||||
#endif // ENABLE_MKLDNN_THREADPOOL
|
#endif // !ENABLE_ONEDNN_OPENMP
|
||||||
execute_primitives(context_.bwd_filter_primitives, bwd_filter_stream,
|
execute_primitives(context_.bwd_filter_primitives, bwd_filter_stream,
|
||||||
context_.bwd_filter_primitives_args);
|
context_.bwd_filter_primitives_args);
|
||||||
|
|
||||||
|
@ -106,7 +106,7 @@ class MklConvBwdInputPrimitive : public MklPrimitive {
|
|||||||
void Execute(const T* diff_src_data, const T* filter_data,
|
void Execute(const T* diff_src_data, const T* filter_data,
|
||||||
const T* diff_dst_data,
|
const T* diff_dst_data,
|
||||||
std::shared_ptr<stream> bwd_input_stream) {
|
std::shared_ptr<stream> bwd_input_stream) {
|
||||||
#ifdef ENABLE_MKLDNN_THREADPOOL
|
#ifndef ENABLE_ONEDNN_OPENMP
|
||||||
// TODO: Create a common function and avoid the duplicate code
|
// TODO: Create a common function and avoid the duplicate code
|
||||||
context_.diff_src_mem->set_data_handle(
|
context_.diff_src_mem->set_data_handle(
|
||||||
static_cast<T*>(const_cast<T*>(diff_src_data)), *bwd_input_stream);
|
static_cast<T*>(const_cast<T*>(diff_src_data)), *bwd_input_stream);
|
||||||
@ -121,7 +121,7 @@ class MklConvBwdInputPrimitive : public MklPrimitive {
|
|||||||
static_cast<T*>(const_cast<T*>(filter_data)));
|
static_cast<T*>(const_cast<T*>(filter_data)));
|
||||||
context_.diff_dst_mem->set_data_handle(
|
context_.diff_dst_mem->set_data_handle(
|
||||||
static_cast<T*>(const_cast<T*>(diff_dst_data)));
|
static_cast<T*>(const_cast<T*>(diff_dst_data)));
|
||||||
#endif // ENABLE_MKLDNN_THREADPOOL
|
#endif // !ENABLE_ONEDNN_OPENMP
|
||||||
execute_primitives(context_.bwd_input_primitives, bwd_input_stream,
|
execute_primitives(context_.bwd_input_primitives, bwd_input_stream,
|
||||||
context_.bwd_input_primitives_args);
|
context_.bwd_input_primitives_args);
|
||||||
|
|
||||||
|
@ -114,7 +114,7 @@ class MklConvFwdPrimitive : public MklPrimitive {
|
|||||||
void Execute(const Tinput* src_data, const Tfilter* filter_data,
|
void Execute(const Tinput* src_data, const Tfilter* filter_data,
|
||||||
const Tbias* bias_data, const Toutput* dst_data,
|
const Tbias* bias_data, const Toutput* dst_data,
|
||||||
std::shared_ptr<stream> fwd_stream) {
|
std::shared_ptr<stream> fwd_stream) {
|
||||||
#ifdef ENABLE_MKLDNN_THREADPOOL
|
#ifndef ENABLE_ONEDNN_OPENMP
|
||||||
// TODO: Create a common function and avoid the duplicate code
|
// TODO: Create a common function and avoid the duplicate code
|
||||||
context_.src_mem->set_data_handle(
|
context_.src_mem->set_data_handle(
|
||||||
static_cast<void*>(const_cast<Tinput*>(src_data)), *fwd_stream);
|
static_cast<void*>(const_cast<Tinput*>(src_data)), *fwd_stream);
|
||||||
@ -137,7 +137,7 @@ class MklConvFwdPrimitive : public MklPrimitive {
|
|||||||
}
|
}
|
||||||
context_.dst_mem->set_data_handle(
|
context_.dst_mem->set_data_handle(
|
||||||
static_cast<void*>(const_cast<Toutput*>(dst_data)));
|
static_cast<void*>(const_cast<Toutput*>(dst_data)));
|
||||||
#endif // ENABLE_MKLDNN_THREADPOOL
|
#endif // !ENABLE_ONEDNN_OPENMP
|
||||||
|
|
||||||
DCHECK_EQ(context_.fwd_primitives.size(),
|
DCHECK_EQ(context_.fwd_primitives.size(),
|
||||||
context_.fwd_primitives_args.size());
|
context_.fwd_primitives_args.size());
|
||||||
|
@ -78,7 +78,7 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
|
|||||||
void Execute(const T* src_data, const U* weights_data, T* dst_data,
|
void Execute(const T* src_data, const U* weights_data, T* dst_data,
|
||||||
U* mean_data, U* variance_data,
|
U* mean_data, U* variance_data,
|
||||||
std::shared_ptr<stream> fwd_stream, U* workspace_data) {
|
std::shared_ptr<stream> fwd_stream, U* workspace_data) {
|
||||||
#ifdef ENABLE_MKLDNN_THREADPOOL
|
#ifndef ENABLE_ONEDNN_OPENMP
|
||||||
// TODO: Create a common function and avoid the duplicate code
|
// TODO: Create a common function and avoid the duplicate code
|
||||||
context_.src_mem->set_data_handle(
|
context_.src_mem->set_data_handle(
|
||||||
static_cast<void*>(const_cast<T*>(src_data)), *fwd_stream);
|
static_cast<void*>(const_cast<T*>(src_data)), *fwd_stream);
|
||||||
@ -116,7 +116,7 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
|
|||||||
if (workspace_data != nullptr) {
|
if (workspace_data != nullptr) {
|
||||||
context_.ws_mem->set_data_handle(workspace_data);
|
context_.ws_mem->set_data_handle(workspace_data);
|
||||||
}
|
}
|
||||||
#endif // ENABLE_MKLDNN_THREADPOOL
|
#endif // !ENABLE_ONEDNN_OPENMP
|
||||||
|
|
||||||
// Execute batch-normalization forward primitives.
|
// Execute batch-normalization forward primitives.
|
||||||
execute_primitives(context_.fwd_primitives, fwd_stream, context_.net_args);
|
execute_primitives(context_.fwd_primitives, fwd_stream, context_.net_args);
|
||||||
@ -422,7 +422,7 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive {
|
|||||||
const T* diff_dst_data, const U* weights_data, T* diff_src_data,
|
const T* diff_dst_data, const U* weights_data, T* diff_src_data,
|
||||||
U* diff_weights_data, U* res_space_data,
|
U* diff_weights_data, U* res_space_data,
|
||||||
std::shared_ptr<stream> bwd_stream) {
|
std::shared_ptr<stream> bwd_stream) {
|
||||||
#ifdef ENABLE_MKLDNN_THREADPOOL
|
#ifndef ENABLE_ONEDNN_OPENMP
|
||||||
// TODO: Create a common function and avoid the duplicate code
|
// TODO: Create a common function and avoid the duplicate code
|
||||||
context_.src_mem->set_data_handle(
|
context_.src_mem->set_data_handle(
|
||||||
static_cast<void*>(const_cast<T*>(src_data)), *bwd_stream);
|
static_cast<void*>(const_cast<T*>(src_data)), *bwd_stream);
|
||||||
@ -460,7 +460,7 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive {
|
|||||||
}
|
}
|
||||||
|
|
||||||
context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data));
|
context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data));
|
||||||
#endif // ENABLE_MKLDNN_THREADPOOL
|
#endif // !ENABLE_ONEDNN_OPENMP
|
||||||
// Execute backward batch-normalization primitives.
|
// Execute backward batch-normalization primitives.
|
||||||
DCHECK_EQ(context_.bwd_primitives.size(), context_.net_args.size());
|
DCHECK_EQ(context_.bwd_primitives.size(), context_.net_args.size());
|
||||||
execute_primitives(context_.bwd_primitives, bwd_stream, context_.net_args);
|
execute_primitives(context_.bwd_primitives, bwd_stream, context_.net_args);
|
||||||
|
@ -1211,7 +1211,7 @@ INSTANTIATE_TYPED_TEST_SUITE_P(Test, MklFusedMatMulOpTest,
|
|||||||
// This test is flaky for --config=mkl_threadpool (The supposedly cached op
|
// This test is flaky for --config=mkl_threadpool (The supposedly cached op
|
||||||
// sometimes took longer than even 0.9 * original_time.)
|
// sometimes took longer than even 0.9 * original_time.)
|
||||||
// TODO(intel-tf): Re-enable the test for --config=mkl_threadpool.
|
// TODO(intel-tf): Re-enable the test for --config=mkl_threadpool.
|
||||||
#ifndef ENABLE_MKLDNN_THREADPOOL
|
#ifdef ENABLE_ONEDNN_OPENMP
|
||||||
// Test the performance of MklFusedMatMul weight cache.
|
// Test the performance of MklFusedMatMul weight cache.
|
||||||
// For the first time B matrix will be reordered and cached which will be
|
// For the first time B matrix will be reordered and cached which will be
|
||||||
// used for subsequent runs
|
// used for subsequent runs
|
||||||
@ -1314,7 +1314,7 @@ TEST_F(MklFusedMatMulCacheTest, WeightCached) {
|
|||||||
test::ExpectTensorNear<float>(expected, output_new, 1e-5);
|
test::ExpectTensorNear<float>(expected, output_new, 1e-5);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif // ENABLE_MKLDNN_THREADPOOL
|
#endif // ENABLE_ONEDNN_OPENMP
|
||||||
|
|
||||||
class BiasCacheTest : public OpsTestBase {
|
class BiasCacheTest : public OpsTestBase {
|
||||||
public:
|
public:
|
||||||
|
@ -155,14 +155,14 @@ class MklMatMulOp : public OpKernel {
|
|||||||
char char_transa = transa ? 'T' : 'N';
|
char char_transa = transa ? 'T' : 'N';
|
||||||
char char_transb = transb ? 'T' : 'N';
|
char char_transb = transb ? 'T' : 'N';
|
||||||
VLOG(2) << "MKL DNN SGEMM called";
|
VLOG(2) << "MKL DNN SGEMM called";
|
||||||
#ifdef ENABLE_MKLDNN_THREADPOOL
|
#ifndef ENABLE_ONEDNN_OPENMP
|
||||||
MklDnnThreadPool eigen_tp(ctx);
|
MklDnnThreadPool eigen_tp(ctx);
|
||||||
dnnl_sgemm_tp(char_transa, char_transb, m, n, k, alpha, a, lda, b, ldb,
|
dnnl_sgemm_tp(char_transa, char_transb, m, n, k, alpha, a, lda, b, ldb,
|
||||||
beta, c, ldc, &eigen_tp);
|
beta, c, ldc, &eigen_tp);
|
||||||
#else
|
#else
|
||||||
dnnl_sgemm(char_transa, char_transb, m, n, k, alpha, a, lda, b, ldb, beta,
|
dnnl_sgemm(char_transa, char_transb, m, n, k, alpha, a, lda, b, ldb, beta,
|
||||||
c, ldc);
|
c, ldc);
|
||||||
#endif // ENABLE_MKLDNN_THREADPOOL
|
#endif // !ENABLE_ONEDNN_OPENMP
|
||||||
}
|
}
|
||||||
|
|
||||||
void MklBlasGemm(OpKernelContext* ctx, bool transa, bool transb, const int m,
|
void MklBlasGemm(OpKernelContext* ctx, bool transa, bool transb, const int m,
|
||||||
|
@ -95,7 +95,7 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
|
|||||||
void Execute(const Tinput* src_data, const Tweight* weight_data,
|
void Execute(const Tinput* src_data, const Tweight* weight_data,
|
||||||
const Tbias* bias_data, Toutput* dst_data,
|
const Tbias* bias_data, Toutput* dst_data,
|
||||||
std::shared_ptr<stream> fwd_stream) {
|
std::shared_ptr<stream> fwd_stream) {
|
||||||
#ifdef ENABLE_MKLDNN_THREADPOOL
|
#ifndef ENABLE_ONEDNN_OPENMP
|
||||||
context_.src_mem->set_data_handle(
|
context_.src_mem->set_data_handle(
|
||||||
static_cast<void*>(const_cast<Tinput*>(src_data)), *fwd_stream);
|
static_cast<void*>(const_cast<Tinput*>(src_data)), *fwd_stream);
|
||||||
context_.weight_mem->set_data_handle(
|
context_.weight_mem->set_data_handle(
|
||||||
@ -112,7 +112,7 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
|
|||||||
context_.bias_mem->set_data_handle(
|
context_.bias_mem->set_data_handle(
|
||||||
static_cast<void*>(const_cast<Tbias*>(bias_data)));
|
static_cast<void*>(const_cast<Tbias*>(bias_data)));
|
||||||
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
|
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
|
||||||
#endif // ENABLE_MKLDNN_THREADPOOL
|
#endif // !ENABLE_ONEDNN_OPENMP
|
||||||
|
|
||||||
execute_primitives(context_.fwd_primitives, fwd_stream, context_.net_args);
|
execute_primitives(context_.fwd_primitives, fwd_stream, context_.net_args);
|
||||||
|
|
||||||
@ -534,7 +534,7 @@ class MklMatMulPrimitive : public MklPrimitive {
|
|||||||
|
|
||||||
void Execute(const T* a_data, const T* b_data, T* c_data,
|
void Execute(const T* a_data, const T* b_data, T* c_data,
|
||||||
std::shared_ptr<stream> stream) {
|
std::shared_ptr<stream> stream) {
|
||||||
#ifdef ENABLE_MKLDNN_THREADPOOL
|
#ifndef ENABLE_ONEDNN_OPENMP
|
||||||
context_.a_mem->set_data_handle(static_cast<void*>(const_cast<T*>(a_data)),
|
context_.a_mem->set_data_handle(static_cast<void*>(const_cast<T*>(a_data)),
|
||||||
*stream);
|
*stream);
|
||||||
context_.b_mem->set_data_handle(static_cast<void*>(const_cast<T*>(b_data)),
|
context_.b_mem->set_data_handle(static_cast<void*>(const_cast<T*>(b_data)),
|
||||||
@ -545,7 +545,7 @@ class MklMatMulPrimitive : public MklPrimitive {
|
|||||||
context_.a_mem->set_data_handle(static_cast<void*>(const_cast<T*>(a_data)));
|
context_.a_mem->set_data_handle(static_cast<void*>(const_cast<T*>(a_data)));
|
||||||
context_.b_mem->set_data_handle(static_cast<void*>(const_cast<T*>(b_data)));
|
context_.b_mem->set_data_handle(static_cast<void*>(const_cast<T*>(b_data)));
|
||||||
context_.c_mem->set_data_handle(static_cast<void*>(const_cast<T*>(c_data)));
|
context_.c_mem->set_data_handle(static_cast<void*>(const_cast<T*>(c_data)));
|
||||||
#endif // ENABLE_MKLDNN_THREADPOOL
|
#endif // !ENABLE_ONEDNN_OPENMP
|
||||||
execute_primitives(context_.matmul_primitives, stream, context_.net_args);
|
execute_primitives(context_.matmul_primitives, stream, context_.net_args);
|
||||||
|
|
||||||
// After execution, set data handle back
|
// After execution, set data handle back
|
||||||
|
@ -86,7 +86,7 @@ template <typename T>
|
|||||||
void MklPoolingFwdPrimitive<T>::Execute(const T* src_data, T* dst_data,
|
void MklPoolingFwdPrimitive<T>::Execute(const T* src_data, T* dst_data,
|
||||||
void* ws_data,
|
void* ws_data,
|
||||||
std::shared_ptr<stream> fwd_stream) {
|
std::shared_ptr<stream> fwd_stream) {
|
||||||
#ifdef ENABLE_MKLDNN_THREADPOOL
|
#ifndef ENABLE_ONEDNN_OPENMP
|
||||||
context_.src_mem->set_data_handle(
|
context_.src_mem->set_data_handle(
|
||||||
static_cast<void*>(const_cast<T*>(src_data)), *fwd_stream);
|
static_cast<void*>(const_cast<T*>(src_data)), *fwd_stream);
|
||||||
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data), *fwd_stream);
|
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data), *fwd_stream);
|
||||||
@ -106,7 +106,7 @@ void MklPoolingFwdPrimitive<T>::Execute(const T* src_data, T* dst_data,
|
|||||||
DCHECK(ws_data != nullptr);
|
DCHECK(ws_data != nullptr);
|
||||||
context_.ws_mem->set_data_handle(ws_data);
|
context_.ws_mem->set_data_handle(ws_data);
|
||||||
}
|
}
|
||||||
#endif // ENABLE_MKLDNN_THREADPOOL
|
#endif // !ENABLE_ONEDNN_OPENMP
|
||||||
execute_primitives(context_.fwd_primitives, fwd_stream, context_.net_args);
|
execute_primitives(context_.fwd_primitives, fwd_stream, context_.net_args);
|
||||||
|
|
||||||
// Set back data handle.
|
// Set back data handle.
|
||||||
@ -188,7 +188,7 @@ template <typename T>
|
|||||||
void MklPoolingBwdPrimitive<T>::Execute(const T* diff_dst_data,
|
void MklPoolingBwdPrimitive<T>::Execute(const T* diff_dst_data,
|
||||||
T* diff_src_data, const void* ws_data,
|
T* diff_src_data, const void* ws_data,
|
||||||
std::shared_ptr<stream> bwd_stream) {
|
std::shared_ptr<stream> bwd_stream) {
|
||||||
#ifdef ENABLE_MKLDNN_THREADPOOL
|
#ifndef ENABLE_ONEDNN_OPENMP
|
||||||
context_.diff_dst_mem->set_data_handle(
|
context_.diff_dst_mem->set_data_handle(
|
||||||
static_cast<void*>(const_cast<T*>(diff_dst_data)), *bwd_stream);
|
static_cast<void*>(const_cast<T*>(diff_dst_data)), *bwd_stream);
|
||||||
context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data),
|
context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data),
|
||||||
@ -205,7 +205,7 @@ void MklPoolingBwdPrimitive<T>::Execute(const T* diff_dst_data,
|
|||||||
DCHECK(ws_data != nullptr);
|
DCHECK(ws_data != nullptr);
|
||||||
context_.ws_mem->set_data_handle(const_cast<void*>(ws_data));
|
context_.ws_mem->set_data_handle(const_cast<void*>(ws_data));
|
||||||
}
|
}
|
||||||
#endif // ENABLE_MKLDNN_THREADPOOL
|
#endif // !ENABLE_ONEDNN_OPENMP
|
||||||
|
|
||||||
execute_primitives(context_.bwd_primitives, bwd_stream, context_.net_args);
|
execute_primitives(context_.bwd_primitives, bwd_stream, context_.net_args);
|
||||||
|
|
||||||
|
@ -431,7 +431,7 @@ class MklDnnQuantizedMatMulOp : public MklDnnMatMulOpBase<Tweight, Toutput> {
|
|||||||
((max_input - min_input) *
|
((max_input - min_input) *
|
||||||
std::max(std::abs(max_weight), std::abs(min_weight)));
|
std::max(std::abs(max_weight), std::abs(min_weight)));
|
||||||
|
|
||||||
#ifdef ENABLE_MKLDNN_THREADPOOL
|
#ifndef ENABLE_ONEDNN_OPENMP
|
||||||
auto parallel_func = [&](int64 start, int64 end) {
|
auto parallel_func = [&](int64 start, int64 end) {
|
||||||
for (int64 j = start; j < end; j++) {
|
for (int64 j = start; j < end; j++) {
|
||||||
int x = 0;
|
int x = 0;
|
||||||
@ -460,7 +460,7 @@ class MklDnnQuantizedMatMulOp : public MklDnnMatMulOpBase<Tweight, Toutput> {
|
|||||||
comp_bias[j] =
|
comp_bias[j] =
|
||||||
((bias_buf[j] * out_scale) + static_cast<float>(x * qa_amin));
|
((bias_buf[j] * out_scale) + static_cast<float>(x * qa_amin));
|
||||||
}
|
}
|
||||||
#endif // ENABLE_MKLDNN_THREADPOOL
|
#endif // !ENABLE_ONEDNN_OPENMP
|
||||||
return reinterpret_cast<Tbias*>(comp_bias_);
|
return reinterpret_cast<Tbias*>(comp_bias_);
|
||||||
|
|
||||||
} else if (mode_ == QUANTIZE_MODE_SCALED) {
|
} else if (mode_ == QUANTIZE_MODE_SCALED) {
|
||||||
|
@ -87,13 +87,13 @@ class MklReorderWithScalePrimitive : public MklPrimitive {
|
|||||||
|
|
||||||
void Execute(void* src_data, void* dst_data,
|
void Execute(void* src_data, void* dst_data,
|
||||||
std::shared_ptr<stream> reorder_stream) {
|
std::shared_ptr<stream> reorder_stream) {
|
||||||
#ifdef ENABLE_MKLDNN_THREADPOOL
|
#ifndef ENABLE_ONEDNN_OPENMP
|
||||||
context_.src_mem->set_data_handle(src_data, *reorder_stream);
|
context_.src_mem->set_data_handle(src_data, *reorder_stream);
|
||||||
context_.dst_mem->set_data_handle(dst_data, *reorder_stream);
|
context_.dst_mem->set_data_handle(dst_data, *reorder_stream);
|
||||||
#else
|
#else
|
||||||
context_.src_mem->set_data_handle(src_data);
|
context_.src_mem->set_data_handle(src_data);
|
||||||
context_.dst_mem->set_data_handle(dst_data);
|
context_.dst_mem->set_data_handle(dst_data);
|
||||||
#endif // ENABLE_MKLDNN_THREADPOOL
|
#endif // !ENABLE_ONEDNN_OPENMP
|
||||||
context_.reorder_prim->execute(*reorder_stream, context_.prim_args);
|
context_.reorder_prim->execute(*reorder_stream, context_.prim_args);
|
||||||
// After execution, set data handle back.
|
// After execution, set data handle back.
|
||||||
context_.src_mem->set_data_handle(DummyData);
|
context_.src_mem->set_data_handle(DummyData);
|
||||||
|
@ -70,9 +70,9 @@ void MklQuantizationRangeForMultiplication(float min_a, float max_a,
|
|||||||
float* min_c = (*min_c_vector)->flat<float>().data();
|
float* min_c = (*min_c_vector)->flat<float>().data();
|
||||||
float* max_c = (*max_c_vector)->flat<float>().data();
|
float* max_c = (*max_c_vector)->flat<float>().data();
|
||||||
|
|
||||||
#ifndef ENABLE_MKLDNN_THREADPOOL
|
#ifdef ENABLE_ONEDNN_OPENMP
|
||||||
#pragma omp parallel for
|
#pragma omp parallel for
|
||||||
#endif // !ENABLE_MKLDNN_THREADPOOL
|
#endif // ENABLE_ONEDNN_OPENMP
|
||||||
// TODO: Add eigen parallel_for
|
// TODO: Add eigen parallel_for
|
||||||
for (int64_t n = 0; n < n_channel; ++n) {
|
for (int64_t n = 0; n < n_channel; ++n) {
|
||||||
float a_float_for_one_quant_level =
|
float a_float_for_one_quant_level =
|
||||||
|
@ -74,7 +74,7 @@ class MklEltwiseFwdPrimitive : public MklPrimitive {
|
|||||||
// dst_data: output data buffer of dst
|
// dst_data: output data buffer of dst
|
||||||
void Execute(const T* src_data, T* dst_data,
|
void Execute(const T* src_data, T* dst_data,
|
||||||
std::shared_ptr<stream> fwd_stream) {
|
std::shared_ptr<stream> fwd_stream) {
|
||||||
#ifdef ENABLE_MKLDNN_THREADPOOL
|
#ifndef ENABLE_ONEDNN_OPENMP
|
||||||
context_.src_mem->set_data_handle(
|
context_.src_mem->set_data_handle(
|
||||||
static_cast<void*>(const_cast<T*>(src_data)), *fwd_stream);
|
static_cast<void*>(const_cast<T*>(src_data)), *fwd_stream);
|
||||||
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data),
|
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data),
|
||||||
@ -83,7 +83,7 @@ class MklEltwiseFwdPrimitive : public MklPrimitive {
|
|||||||
context_.src_mem->set_data_handle(
|
context_.src_mem->set_data_handle(
|
||||||
static_cast<void*>(const_cast<T*>(src_data)));
|
static_cast<void*>(const_cast<T*>(src_data)));
|
||||||
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
|
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
|
||||||
#endif // ENABLE_MKLDNN_THREADPOOL
|
#endif // !ENABLE_ONEDNN_OPENMP
|
||||||
DCHECK_EQ(context_.fwd_primitives.size(),
|
DCHECK_EQ(context_.fwd_primitives.size(),
|
||||||
context_.fwd_primitives_args.size());
|
context_.fwd_primitives_args.size());
|
||||||
execute_primitives(context_.fwd_primitives, fwd_stream,
|
execute_primitives(context_.fwd_primitives, fwd_stream,
|
||||||
@ -255,7 +255,7 @@ class MklEltwiseBwdPrimitive : public MklPrimitive {
|
|||||||
// diff_src_data: output data buffer of diff_src
|
// diff_src_data: output data buffer of diff_src
|
||||||
void Execute(const T* src_data, const T* diff_dst_data, T* diff_src_data,
|
void Execute(const T* src_data, const T* diff_dst_data, T* diff_src_data,
|
||||||
std::shared_ptr<stream> bwd_stream) {
|
std::shared_ptr<stream> bwd_stream) {
|
||||||
#ifdef ENABLE_MKLDNN_THREADPOOL
|
#ifndef ENABLE_ONEDNN_OPENMP
|
||||||
context_.src_mem->set_data_handle(
|
context_.src_mem->set_data_handle(
|
||||||
static_cast<void*>(const_cast<T*>(src_data)), *bwd_stream);
|
static_cast<void*>(const_cast<T*>(src_data)), *bwd_stream);
|
||||||
context_.diff_dst_mem->set_data_handle(
|
context_.diff_dst_mem->set_data_handle(
|
||||||
@ -268,7 +268,7 @@ class MklEltwiseBwdPrimitive : public MklPrimitive {
|
|||||||
context_.diff_dst_mem->set_data_handle(
|
context_.diff_dst_mem->set_data_handle(
|
||||||
static_cast<void*>(const_cast<T*>(diff_dst_data)));
|
static_cast<void*>(const_cast<T*>(diff_dst_data)));
|
||||||
context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data));
|
context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data));
|
||||||
#endif // ENABLE_MKLDNN_THREADPOOL
|
#endif // !ENABLE_ONEDNN_OPENMP
|
||||||
DCHECK_EQ(context_.bwd_primitives.size(),
|
DCHECK_EQ(context_.bwd_primitives.size(),
|
||||||
context_.bwd_primitives_args.size());
|
context_.bwd_primitives_args.size());
|
||||||
execute_primitives(context_.bwd_primitives, bwd_stream,
|
execute_primitives(context_.bwd_primitives, bwd_stream,
|
||||||
|
@ -76,13 +76,13 @@ class MklRequantizationRangePerChannelOp : public OpKernel {
|
|||||||
// Find the ranges of each channel in parallel.
|
// Find the ranges of each channel in parallel.
|
||||||
float out_min_max = std::numeric_limits<float>::min();
|
float out_min_max = std::numeric_limits<float>::min();
|
||||||
|
|
||||||
#ifndef ENABLE_MKLDNN_THREADPOOL
|
#ifdef ENABLE_ONEDNN_OPENMP
|
||||||
#ifdef _MSC_VER
|
#ifdef _MSC_VER
|
||||||
#pragma omp parallel for
|
#pragma omp parallel for
|
||||||
#else
|
#else
|
||||||
#pragma omp parallel for reduction(max : out_min_max)
|
#pragma omp parallel for reduction(max : out_min_max)
|
||||||
#endif
|
#endif
|
||||||
#endif // !ENABLE_MKLDNN_THREADPOOL
|
#endif // ENABLE_ONEDNN_OPENMP
|
||||||
// TODO: Add eigen parallel_for
|
// TODO: Add eigen parallel_for
|
||||||
for (int64_t i = 0; i < depth; ++i) {
|
for (int64_t i = 0; i < depth; ++i) {
|
||||||
Eigen::Tensor<qint32, 0, Eigen::RowMajor> min =
|
Eigen::Tensor<qint32, 0, Eigen::RowMajor> min =
|
||||||
|
@ -185,7 +185,7 @@ class MklSlicePrimitive : public MklPrimitive {
|
|||||||
|
|
||||||
void Execute(const MklSliceParams& sliceParams,
|
void Execute(const MklSliceParams& sliceParams,
|
||||||
std::shared_ptr<stream> slice_stream) {
|
std::shared_ptr<stream> slice_stream) {
|
||||||
#ifdef ENABLE_MKLDNN_THREADPOOL
|
#ifndef ENABLE_ONEDNN_OPENMP
|
||||||
context_.src_mem->set_data_handle(sliceParams.from->get_data_handle(),
|
context_.src_mem->set_data_handle(sliceParams.from->get_data_handle(),
|
||||||
*slice_stream);
|
*slice_stream);
|
||||||
context_.dst_mem->set_data_handle(sliceParams.to->get_data_handle(),
|
context_.dst_mem->set_data_handle(sliceParams.to->get_data_handle(),
|
||||||
@ -193,7 +193,7 @@ class MklSlicePrimitive : public MklPrimitive {
|
|||||||
#else
|
#else
|
||||||
context_.src_mem->set_data_handle(sliceParams.from->get_data_handle());
|
context_.src_mem->set_data_handle(sliceParams.from->get_data_handle());
|
||||||
context_.dst_mem->set_data_handle(sliceParams.to->get_data_handle());
|
context_.dst_mem->set_data_handle(sliceParams.to->get_data_handle());
|
||||||
#endif // ENABLE_MKLDNN_THREADPOOL
|
#endif // !ENABLE_ONEDNN_OPENMP
|
||||||
|
|
||||||
execute_primitives(context_.slice_primitives, slice_stream,
|
execute_primitives(context_.slice_primitives, slice_stream,
|
||||||
context_.slice_primitives_args);
|
context_.slice_primitives_args);
|
||||||
|
@ -58,7 +58,7 @@ class MklSoftmaxPrimitive : public MklPrimitive {
|
|||||||
// dst_data: output data buffer of dst
|
// dst_data: output data buffer of dst
|
||||||
void Execute(const T* src_data, T* dst_data,
|
void Execute(const T* src_data, T* dst_data,
|
||||||
std::shared_ptr<stream> fwd_cpu_stream) {
|
std::shared_ptr<stream> fwd_cpu_stream) {
|
||||||
#ifdef ENABLE_MKLDNN_THREADPOOL
|
#ifndef ENABLE_ONEDNN_OPENMP
|
||||||
context_.src_mem->set_data_handle(
|
context_.src_mem->set_data_handle(
|
||||||
static_cast<void*>(const_cast<T*>(src_data)), *fwd_cpu_stream);
|
static_cast<void*>(const_cast<T*>(src_data)), *fwd_cpu_stream);
|
||||||
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data),
|
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data),
|
||||||
@ -67,7 +67,7 @@ class MklSoftmaxPrimitive : public MklPrimitive {
|
|||||||
context_.src_mem->set_data_handle(
|
context_.src_mem->set_data_handle(
|
||||||
static_cast<void*>(const_cast<T*>(src_data)));
|
static_cast<void*>(const_cast<T*>(src_data)));
|
||||||
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
|
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
|
||||||
#endif // ENABLE_MKLDNN_THREADPOOL
|
#endif // !ENABLE_ONEDNN_OPENMP
|
||||||
|
|
||||||
DCHECK_EQ(context_.fwd_primitives.size(), context_.fwd_net_args.size());
|
DCHECK_EQ(context_.fwd_primitives.size(), context_.fwd_net_args.size());
|
||||||
execute_primitives(context_.fwd_primitives, fwd_cpu_stream,
|
execute_primitives(context_.fwd_primitives, fwd_cpu_stream,
|
||||||
|
@ -32,7 +32,7 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
#ifdef ENABLE_MKLDNN_THREADPOOL
|
#ifndef ENABLE_ONEDNN_OPENMP
|
||||||
using dnnl::stream_attr;
|
using dnnl::stream_attr;
|
||||||
using dnnl::threadpool_iface;
|
using dnnl::threadpool_iface;
|
||||||
|
|
||||||
@ -116,7 +116,7 @@ struct MklDnnThreadPool {
|
|||||||
MklDnnThreadPool(OpKernelContext* ctx) {}
|
MklDnnThreadPool(OpKernelContext* ctx) {}
|
||||||
};
|
};
|
||||||
|
|
||||||
#endif // ENABLE_MKLDNN_THREADPOOL
|
#endif // !ENABLE_ONEDNN_OPENMP
|
||||||
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
@ -224,7 +224,7 @@ inline bool array_cmp(const T* a1, const T* a2, size_t size) {
|
|||||||
|
|
||||||
inline mkldnn::stream* CreateStream(MklDnnThreadPool* eigen_tp,
|
inline mkldnn::stream* CreateStream(MklDnnThreadPool* eigen_tp,
|
||||||
const engine& engine) {
|
const engine& engine) {
|
||||||
#ifdef ENABLE_MKLDNN_THREADPOOL
|
#ifndef ENABLE_ONEDNN_OPENMP
|
||||||
stream_attr tp_stream_attr(engine::kind::cpu);
|
stream_attr tp_stream_attr(engine::kind::cpu);
|
||||||
if (eigen_tp != nullptr) {
|
if (eigen_tp != nullptr) {
|
||||||
tp_stream_attr.set_threadpool(eigen_tp);
|
tp_stream_attr.set_threadpool(eigen_tp);
|
||||||
@ -238,7 +238,7 @@ inline mkldnn::stream* CreateStream(MklDnnThreadPool* eigen_tp,
|
|||||||
#else
|
#else
|
||||||
stream* tp_stream = new stream(engine);
|
stream* tp_stream = new stream(engine);
|
||||||
return tp_stream;
|
return tp_stream;
|
||||||
#endif // ENABLE_MKLDNN_THREADPOOL
|
#endif // !ENABLE_ONEDNN_OPENMP
|
||||||
}
|
}
|
||||||
|
|
||||||
class MklDnnShape {
|
class MklDnnShape {
|
||||||
@ -1390,11 +1390,11 @@ class MklDnnData {
|
|||||||
std::shared_ptr<stream> t_stream = nullptr) {
|
std::shared_ptr<stream> t_stream = nullptr) {
|
||||||
CHECK_NOTNULL(user_memory_);
|
CHECK_NOTNULL(user_memory_);
|
||||||
CHECK_NOTNULL(data_buffer);
|
CHECK_NOTNULL(data_buffer);
|
||||||
#ifdef ENABLE_MKLDNN_THREADPOOL
|
#ifndef ENABLE_ONEDNN_OPENMP
|
||||||
user_memory_->set_data_handle(data_buffer, *t_stream);
|
user_memory_->set_data_handle(data_buffer, *t_stream);
|
||||||
#else
|
#else
|
||||||
user_memory_->set_data_handle(data_buffer);
|
user_memory_->set_data_handle(data_buffer);
|
||||||
#endif // ENABLE_MKLDNN_THREADPOOL
|
#endif // !ENABLE_ONEDNN_OPENMP
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Set function for data buffer of user memory primitive.
|
/// Set function for data buffer of user memory primitive.
|
||||||
|
@ -39,7 +39,7 @@ load(
|
|||||||
load(
|
load(
|
||||||
"//third_party/mkl_dnn:build_defs.bzl",
|
"//third_party/mkl_dnn:build_defs.bzl",
|
||||||
"if_mkl_open_source_only",
|
"if_mkl_open_source_only",
|
||||||
"if_mkldnn_threadpool",
|
"if_mkldnn_openmp",
|
||||||
)
|
)
|
||||||
load("@bazel_skylib//lib:new_sets.bzl", "sets")
|
load("@bazel_skylib//lib:new_sets.bzl", "sets")
|
||||||
load("@bazel_skylib//rules:common_settings.bzl", "BuildSettingInfo")
|
load("@bazel_skylib//rules:common_settings.bzl", "BuildSettingInfo")
|
||||||
@ -364,7 +364,7 @@ def tf_copts(
|
|||||||
if_xla_available(["-DTENSORFLOW_USE_XLA=1"]) +
|
if_xla_available(["-DTENSORFLOW_USE_XLA=1"]) +
|
||||||
if_tensorrt(["-DGOOGLE_TENSORRT=1"]) +
|
if_tensorrt(["-DGOOGLE_TENSORRT=1"]) +
|
||||||
if_mkl(["-DINTEL_MKL=1"]) +
|
if_mkl(["-DINTEL_MKL=1"]) +
|
||||||
if_mkldnn_threadpool(["-DENABLE_MKLDNN_THREADPOOL"]) +
|
if_mkldnn_openmp(["-DENABLE_ONEDNN_OPENMP"]) +
|
||||||
if_enable_mkl(["-DENABLE_MKL"]) +
|
if_enable_mkl(["-DENABLE_MKL"]) +
|
||||||
if_android_arm(["-mfpu=neon"]) +
|
if_android_arm(["-mfpu=neon"]) +
|
||||||
if_linux_x86_64(["-msse3"]) +
|
if_linux_x86_64(["-msse3"]) +
|
||||||
|
5
third_party/mkl_dnn/BUILD
vendored
5
third_party/mkl_dnn/BUILD
vendored
@ -19,11 +19,10 @@ config_setting(
|
|||||||
)
|
)
|
||||||
|
|
||||||
config_setting(
|
config_setting(
|
||||||
name = "build_with_mkldnn_threadpool",
|
name = "build_with_mkldnn_openmp",
|
||||||
define_values = {
|
define_values = {
|
||||||
"build_with_mkl": "true",
|
"build_with_mkl": "true",
|
||||||
"build_with_mkl_opensource": "true",
|
"build_with_openmp": "true",
|
||||||
"build_with_mkldnn_threadpool": "true",
|
|
||||||
},
|
},
|
||||||
visibility = ["//visibility:public"],
|
visibility = ["//visibility:public"],
|
||||||
)
|
)
|
||||||
|
10
third_party/mkl_dnn/build_defs.bzl
vendored
10
third_party/mkl_dnn/build_defs.bzl
vendored
@ -14,18 +14,18 @@ def if_mkl_open_source_only(if_true, if_false = []):
|
|||||||
"//conditions:default": if_false,
|
"//conditions:default": if_false,
|
||||||
})
|
})
|
||||||
|
|
||||||
def if_mkldnn_threadpool(if_true, if_false = []):
|
def if_mkldnn_openmp(if_true, if_false = []):
|
||||||
"""Returns `if_true` if MKL-DNN v1.x is used.
|
"""Returns `if_true` if OpenMP is used with oneDNN.
|
||||||
|
|
||||||
Shorthand for select()'ing on whether we're building with
|
Shorthand for select()'ing on whether we're building with
|
||||||
MKL-DNN v1.x open source library only with user specified threadpool, without depending on MKL binary form.
|
oneDNN open source library only with openmp
|
||||||
|
|
||||||
Returns a select statement which evaluates to if_true if we're building
|
Returns a select statement which evaluates to if_true if we're building
|
||||||
with MKL-DNN v1.x open source library only with user specified threadpool. Otherwise, the
|
with oneDNN open source library only with OpenMP. Otherwise, the
|
||||||
select statement evaluates to if_false.
|
select statement evaluates to if_false.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
return select({
|
return select({
|
||||||
"@org_tensorflow//third_party/mkl_dnn:build_with_mkldnn_threadpool": if_true,
|
"@org_tensorflow//third_party/mkl_dnn:build_with_mkldnn_openmp": if_true,
|
||||||
"//conditions:default": if_false,
|
"//conditions:default": if_false,
|
||||||
})
|
})
|
||||||
|
7
third_party/mkl_dnn/mkldnn_v1.BUILD
vendored
7
third_party/mkl_dnn/mkldnn_v1.BUILD
vendored
@ -10,8 +10,7 @@ load(
|
|||||||
)
|
)
|
||||||
load(
|
load(
|
||||||
"@org_tensorflow//third_party/mkl_dnn:build_defs.bzl",
|
"@org_tensorflow//third_party/mkl_dnn:build_defs.bzl",
|
||||||
"if_mkl_open_source_only",
|
"if_mkldnn_openmp",
|
||||||
"if_mkldnn_threadpool",
|
|
||||||
)
|
)
|
||||||
load(
|
load(
|
||||||
"@org_tensorflow//third_party/mkl:build_defs.bzl",
|
"@org_tensorflow//third_party/mkl:build_defs.bzl",
|
||||||
@ -45,8 +44,8 @@ template_rule(
|
|||||||
src = "include/dnnl_config.h.in",
|
src = "include/dnnl_config.h.in",
|
||||||
out = "include/dnnl_config.h",
|
out = "include/dnnl_config.h",
|
||||||
substitutions = select({
|
substitutions = select({
|
||||||
"@org_tensorflow//third_party/mkl_dnn:build_with_mkldnn_threadpool": _DNNL_RUNTIME_THREADPOOL,
|
"@org_tensorflow//third_party/mkl_dnn:build_with_mkldnn_openmp": _DNNL_RUNTIME_OMP,
|
||||||
"@org_tensorflow//third_party/mkl:build_with_mkl": _DNNL_RUNTIME_OMP,
|
"@org_tensorflow//third_party/mkl:build_with_mkl": _DNNL_RUNTIME_THREADPOOL,
|
||||||
"//conditions:default": _DNNL_RUNTIME_SEQ,
|
"//conditions:default": _DNNL_RUNTIME_SEQ,
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user