Merge pull request #38259 from Intel-tensorflow:sriniva2/dnnl_threadpool

PiperOrigin-RevId: 307948123
Change-Id: Ic1210d1b6d48c4f333a619c014928a588ec02ed3
TensorFlower Gardener 2020-04-22 18:36:15 -07:00
commit a2f67253fe
9 changed files with 249 additions and 70 deletions

View File

@@ -101,8 +101,7 @@ int32 NumIntraOpThreadsFromEnvironment() {
   const char* val = std::getenv("TF_NUM_INTRAOP_THREADS");
   return (val && strings::safe_strto32(val, &num)) ? num : 0;
 }
-#ifdef INTEL_MKL
+#if !defined(ENABLE_MKLDNN_THREADPOOL) && defined(INTEL_MKL)
 int32 OMPThreadsFromEnvironment() {
   // 1) std::getenv is thread-safe (as long as no other function modifies the
   // host env) from C++11 onward. 2) Most of TF code (except tests and
@@ -122,14 +121,14 @@ int32 DefaultNumIntraOpThreads() {
   // Default to the maximum parallelism for the current process.
   return port::MaxParallelism();
 }
-#endif  // INTEL_MKL
+#endif  // !defined(ENABLE_MKLDNN_THREADPOOL) && defined(INTEL_MKL)
 int32 NumInterOpThreadsFromSessionOptions(const SessionOptions& options) {
   const int32 inter_op = options.config.inter_op_parallelism_threads();
   if (inter_op > 0) return inter_op;
   const int32 env_inter_op = GetEnvNumInterOpThreads();
   if (env_inter_op > 0) return env_inter_op;
-#ifdef INTEL_MKL
+#if !defined(ENABLE_MKLDNN_THREADPOOL) && defined(INTEL_MKL)
   if (!DisableMKL()) {
     // MKL library executes ops in parallel using OMP threads.
     // Setting inter_op conservatively to avoid thread oversubscription that
@@ -150,7 +149,7 @@ int32 NumInterOpThreadsFromSessionOptions(const SessionOptions& options) {
             << ". Tune using inter_op_parallelism_threads for best performance.";
     return mkl_inter_op;
   }
-#endif  // INTEL_MKL
+#endif  // !defined(ENABLE_MKLDNN_THREADPOOL) && defined(INTEL_MKL)
   return DefaultNumInterOpThreads();
 }

View File

@@ -50,7 +50,7 @@ ThreadPoolDevice::ThreadPoolDevice(const SessionOptions& options,
                 name, DEVICE_CPU, memory_limit, locality)),
       allocator_(allocator),
       scoped_allocator_mgr_(new ScopedAllocatorMgr(name)) {
-#ifdef INTEL_MKL
+#if !defined(ENABLE_MKLDNN_THREADPOOL) && defined(INTEL_MKL)
   // Early return when MKL is disabled
   if (DisableMKL()) return;
 #ifdef _OPENMP
@@ -69,7 +69,7 @@ ThreadPoolDevice::ThreadPoolDevice(const SessionOptions& options,
     }
   }
 #endif  // _OPENMP
-#endif  // INTEL_MKL
+#endif  // !defined(ENABLE_MKLDNN_THREADPOOL) && defined(INTEL_MKL)
 }
 ThreadPoolDevice::~ThreadPoolDevice() {}

View File

@@ -51,12 +51,10 @@ limitations under the License.
 using mkldnn::convolution_forward;
 using mkldnn::prop_kind;
 using mkldnn::stream;
 using ConvFwdPd = mkldnn::convolution_forward::primitive_desc;
 using ReorderPd = mkldnn::reorder::primitive_desc;
 namespace tensorflow {
 // This structure aggregates multiple inputs to Conv2DFwd* methods.
 struct MklConvFwdParams {
   memory::dims src_dims;
@@ -96,14 +94,12 @@ template <typename Tinput, typename Tfilter, typename Tbias, typename Toutput>
 class MklConvFwdPrimitive : public MklPrimitive {
  public:
   explicit MklConvFwdPrimitive(const MklConvFwdParams& convFwdDims)
-      : cpu_engine_(ENGINE_CPU, 0) {
-    context_.fwd_stream.reset(new CPU_STREAM(cpu_engine_));
+      : MklPrimitive(engine(ENGINE_CPU, 0)) {
     // Create convolution primitive
     if (context_.conv_fwd == nullptr) {
       Setup(convFwdDims);
     }
   }
   ~MklConvFwdPrimitive() {}
   // Convolution forward execute with bias
@@ -112,7 +108,8 @@ class MklConvFwdPrimitive : public MklPrimitive {
   //   bias_data:   input data buffer of bias
   //   dst_data:    output data buffer of dst
   void Execute(const Tinput* src_data, const Tfilter* filter_data,
-               const Tbias* bias_data, const Toutput* dst_data) {
+               const Tbias* bias_data, const Toutput* dst_data,
+               std::shared_ptr<stream> fwd_stream) {
     context_.src_mem->set_data_handle(
         static_cast<void*>(const_cast<Tinput*>(src_data)));
     context_.filter_mem->set_data_handle(
@@ -127,11 +124,11 @@ class MklConvFwdPrimitive : public MklPrimitive {
     DCHECK_EQ(context_.fwd_primitives.size(),
               context_.fwd_primitives_args.size());
     for (size_t i = 0; i < context_.fwd_primitives.size(); ++i) {
-      context_.fwd_primitives.at(i).execute(*context_.fwd_stream,
+      context_.fwd_primitives.at(i).execute(*fwd_stream,
                                             context_.fwd_primitives_args.at(i));
     }
 #else
-    context_.fwd_stream->submit(context_.fwd_primitives);
+    fwd_stream->submit(context_.fwd_primitives);
 #endif  // ENABLE_MKLDNN_V1
     // After execution, set data handle back
@@ -148,8 +145,8 @@ class MklConvFwdPrimitive : public MklPrimitive {
   //   filter_data: input data buffer of filter (weights)
   //   dst_data:    output data buffer of dst
   void Execute(const Tinput* src_data, const Tfilter* filter_data,
-               const Toutput* dst_data) {
-    Execute(src_data, filter_data, nullptr, dst_data);
+               const Toutput* dst_data, std::shared_ptr<stream> fwd_stream) {
+    Execute(src_data, filter_data, nullptr, dst_data, fwd_stream);
   }
 #ifndef ENABLE_MKLDNN_V1
@@ -191,7 +188,6 @@ class MklConvFwdPrimitive : public MklPrimitive {
     std::shared_ptr<ConvFwdPd> fwd_pd;
     std::shared_ptr<mkldnn::primitive> conv_fwd;
-    std::shared_ptr<mkldnn::stream> fwd_stream;
     std::vector<mkldnn::primitive> fwd_primitives;
 #ifdef ENABLE_MKLDNN_V1
@@ -213,8 +209,7 @@ class MklConvFwdPrimitive : public MklPrimitive {
           filter_md(nullptr),
           bias_md(nullptr),
           fwd_pd(nullptr),
-          conv_fwd(nullptr),
-          fwd_stream(nullptr) {
+          conv_fwd(nullptr) {
     }
   };
@@ -346,7 +341,6 @@ class MklConvFwdPrimitive : public MklPrimitive {
   }
   struct ConvFwdContext context_;
-  engine cpu_engine_;
 };
 // TODO(nhasabni): We should not require passing a type to MklPrimitiveFactory.
@@ -678,11 +672,9 @@ class MklConvOp : public OpKernel {
       // TODO(mdfaijul): Extend the basic parameters for data types and fusions
       this->ExtendConvFwdParams(context, convFwdDims);
       conv_fwd =
           MklConvFwdPrimitiveFactory<Tinput, Tfilter, Tbias, Ttemp_output>::Get(
               convFwdDims, do_not_cache);
       // Allocate output tensors `output_tensor` and `filter_out_tensor`
       MklDnnShape output_mkl_shape;
       std::shared_ptr<ConvFwdPd> conv_fwd_pd = conv_fwd->GetPrimitiveDesc();
@@ -703,8 +695,10 @@ class MklConvOp : public OpKernel {
       Tinput* src_data = nullptr;
       if (IS_SRC_REORDER_NEEDED(src_md, conv_fwd_pd, conv_fwd)) {
         src.SetUsrMem(src_md, &src_tensor);
-        src.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
-            GET_SRC_DESC_FROM_OP_PD(conv_fwd_pd), cpu_engine_));
+        src.CheckReorderToOpMem(
+            MEMORY_PD_WITHOUT_DATA(GET_SRC_DESC_FROM_OP_PD(conv_fwd_pd),
+                                   cpu_engine_),
+            context);
         src_data = static_cast<Tinput*>(src.GetOpMem().get_data_handle());
       } else {
         src_data = static_cast<Tinput*>(
@@ -735,13 +729,16 @@ class MklConvOp : public OpKernel {
       if (!is_filter_cached) {
         filter.SetUsrMem(filter_md, &filter_tensor);
         if (filter_out_tensor == nullptr) {
-          filter.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
-              GET_WEIGHTS_DESC_FROM_OP_PD(conv_fwd_pd), cpu_engine_));
+          filter.CheckReorderToOpMem(
+              MEMORY_PD_WITHOUT_DATA(GET_WEIGHTS_DESC_FROM_OP_PD(conv_fwd_pd),
+                                     cpu_engine_),
+              context);
         } else {
           filter.CheckReorderToOpMem(
               GET_WEIGHTS_DESC_FROM_OP_PD(conv_fwd_pd),
               DATA_WITH_ENGINE(filter.GetTensorBuffer(filter_out_tensor),
-                               cpu_engine_));
+                               cpu_engine_),
+              context);
         }
         filter_data =
             static_cast<Tfilter*>(filter.GetOpMem().get_data_handle());
@@ -752,20 +749,23 @@ class MklConvOp : public OpKernel {
       }
       // Execute convolution
+      std::shared_ptr<stream> fwd_cpu_stream;
+      fwd_cpu_stream.reset(CreateStream(context, conv_fwd->GetEngine()));
       if (fuse_biasadd_) {
         const Tensor& bias_tensor = MklGetInput(context, kInputIndex_Bias);
         Tbias* bias_data =
             this->GetBiasHandle(context, conv_fwd_pd, bias_tensor);
-        conv_fwd->Execute(src_data, filter_data, bias_data, dst_data);
+        conv_fwd->Execute(src_data, filter_data, bias_data, dst_data,
+                          fwd_cpu_stream);
       } else {
         if (!eager_mode) {
-          conv_fwd->Execute(src_data, filter_data, dst_data);
+          conv_fwd->Execute(src_data, filter_data, dst_data, fwd_cpu_stream);
         } else {
           // In eager mode we first write the output to temporary
           // buffer in MKL format. Then we convert the data to TF format.
           Ttemp_output* tmp_data = reinterpret_cast<Ttemp_output*>(
              tmp_tensor.flat<Toutput>().data());
-          conv_fwd->Execute(src_data, filter_data, tmp_data);
+          conv_fwd->Execute(src_data, filter_data, tmp_data, fwd_cpu_stream);
           // Now we need to convert the output to TF format.
           auto output_tf_md = output_mkl_shape.GetTfLayout();
@@ -780,12 +780,13 @@ class MklConvOp : public OpKernel {
           memory* dst_data_mem =
               new MEMORY_CONSTRUCTOR(OUTPUT_TF_MD, cpu_engine_, dst_data);
           CreateAndExecuteReorder(reorder_pd, *tmp_data_mem, *dst_data_mem,
-                                  cpu_engine_);
+                                  cpu_engine_, context);
         }
       }
       // Delete primitive since it is not cached.
       if (do_not_cache) delete conv_fwd;
     } catch (mkldnn::error& e) {
       string error_msg = tensorflow::strings::StrCat(
           "Status: ", e.status, ", message: ", string(e.message), ", in file ",
@@ -970,8 +971,9 @@ class MklConvOp : public OpKernel {
            new MEMORY_CONSTRUCTOR(DST_MD, this->cpu_engine_, dst_buf));
        auto reorder_desc =
            REORDER_PD_CONSTRUCTOR(ADD_MD, DST_MD, this->cpu_engine_);
        CreateAndExecuteReorder(reorder_desc, *fuse_add_src_, *fuse_add_dst_,
-                               this->cpu_engine_);
+                               this->cpu_engine_, context);
      }
    } else {
      AllocateOutputSetMklShape(context, kOutputIndex_Dst, output_tensor,
@@ -1097,6 +1099,7 @@ class MklConvOp : public OpKernel {
                              filter_tf_shape, filter_mkl_shape);
   }
+  // TODO(intel-mkl): This function does not seem to be called. Remove it.
   // Prepare and execute net - checks for input and output reorders.
   void PrepareAndExecuteNet(const ConvFwdPd& conv_prim_desc,
                             MklDnnData<Tinput>* src,
@@ -1185,7 +1188,7 @@ class MklConvOp : public OpKernel {
     // Otherwise, cache filter
     filter.SetUsrMem(filter_md, &filter_tensor);
     filter.CheckReorderToOpMem(conv_fwd_pd.get()->weights_desc(),
-                               this->cpu_engine_);
+                               this->cpu_engine_, context);
     filter_data = static_cast<Tfilter*>(filter.GetOpMem().get_data_handle());
     Tensor* filter_tensor_ptr = nullptr;
@@ -1652,7 +1655,7 @@ class MklQuantizedConv2DOp
         input_bias_->GET_DESC, scaled_bias_->GET_DESC, this->cpu_engine_,
         bias_attr);
     CreateAndExecuteReorder(reorder_desc, *input_bias_, *scaled_bias_,
-                            this->cpu_engine_);
+                            this->cpu_engine_, context);
     Tbias* bias_data =
         reinterpret_cast<Tbias*>(scaled_bias_->get_data_handle());
@@ -1908,7 +1911,8 @@ class MklQuantizedConv2DSumReluOp
     auto reorder_desc = REORDER_PD_CONSTRUCTOR_WITH_ATTR(
         SUMMAND_MD, conv_prim_desc.PRIMITIVE_DESC_DST, this->cpu_engine_,
         reorder_attr);
-    CreateAndExecuteReorder(reorder_desc, *summand_, *dst_, this->cpu_engine_);
+    CreateAndExecuteReorder(reorder_desc, *summand_, *dst_, this->cpu_engine_,
+                            context);
   }
   std::shared_ptr<mkldnn::memory> summand_;

View File

@@ -165,7 +165,7 @@ class MklInputConversionOp : public OpKernel {
                           input1_md, tensor_out, net, net_args, cpu_engine)),
                       errors::Internal(
                           "MklInputConversionOp: Failed to create reorder for input0"));
-      ExecutePrimitive(net, NET_ARGS_PTR, cpu_engine);
+      ExecutePrimitive(net, NET_ARGS_PTR, cpu_engine, context);
       // Input1 will be passed through
       ForwardMklTensorInToOut(context, kInputIndex_1, kInputIndex_1);
       return;
@@ -273,7 +273,7 @@ class MklInputConversionOp : public OpKernel {
                     errors::Internal("MklInputConversionOp: Failed to forward "
                                      "input tensor to output"));
       } else {
-        ExecutePrimitive(net, NET_ARGS_PTR, cpu_engine);
+        ExecutePrimitive(net, NET_ARGS_PTR, cpu_engine, context);
       }
       // -- The tensor in MKL format passes through --

View File

@@ -172,7 +172,8 @@ class MklReshapeOp : public OpKernel {
         // shape_from != shape_to), then we just copy input tensor to
         // output tensor with target shape (we cannot forward Mkl layout
         // in such case because shape has changed.)
-        if (dnn_data_input.CheckReorderToOpMem(OUTPUT_TF_MD, output_tensor)) {
+        if (dnn_data_input.CheckReorderToOpMem(OUTPUT_TF_MD, output_tensor,
+                                               context)) {
         } else {
           OP_REQUIRES(context,
                       output_tensor->CopyFrom(input_tensor, shape_to),

View File

@@ -111,7 +111,8 @@ class MklToTfOp : public OpKernel {
     if (input.IsReorderNeeded(OUTPUT_TF_MD)) {
       // Insert reorder between MKL layout and TensorFlow layout
       OP_REQUIRES(
-          context, input.CheckReorderToOpMem(OUTPUT_TF_MD, output_tensor),
+          context,
+          input.CheckReorderToOpMem(OUTPUT_TF_MD, output_tensor, context),
           errors::Internal("MklToTfOp: Failed to create input reorder"));
     } else {
       // If not, just forward input tensor to output tensor.

View File

@@ -144,6 +144,7 @@ filegroup(
         "matmul_autotune.h",
         "matmul_bcast.h",
         "mirror_pad_mode.h",
+        "mkl_threadpool.h",
         "mkl_types.h",
         "mkl_util.h",
         "overflow.h",
@@ -273,6 +274,7 @@ filegroup(
 filegroup(
     name = "mkl_util_hdrs",
     srcs = [
+        "mkl_threadpool.h",
         "mkl_util.h",
     ],
     visibility = ["//tensorflow/core:__pkg__"],

View File

@@ -0,0 +1,138 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_UTIL_MKL_THREADPOOL_H_
#define TENSORFLOW_CORE_UTIL_MKL_THREADPOOL_H_
#ifdef INTEL_MKL

#include <list>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "mkldnn.hpp"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/threadpool.h"
#define EIGEN_USE_THREADS

#ifdef ENABLE_MKLDNN_THREADPOOL
using dnnl::stream_attr;
using dnnl::threadpool_iface;

namespace tensorflow {

// Divide 'n' units of work equally among 'teams' threads. If 'n' is not
// divisible by 'teams' and has a remainder 'r', the first 'r' teams have one
// unit of work more than the rest. Returns the range of work that belongs to
// the team 'tid'.
// Parameters
//   n        Total number of jobs.
//   team     Number of workers.
//   tid      Current thread_id.
//   n_start  Start of the range operated on by the thread.
//   n_end    End of the range operated on by the thread.
template <typename T, typename U>
inline void balance211(T n, U team, U tid, T* n_start, T* n_end) {
  if (team <= 1 || n == 0) {
    *n_start = 0;
    *n_end = n;
    return;
  }
  T min_per_team = n / team;
  T remainder = n - min_per_team * team;  // i.e., n % teams.
  *n_start = tid * min_per_team + std::min(tid, remainder);
  *n_end = *n_start + min_per_team + (tid < remainder);
}

struct MklDnnThreadPool : public dnnl::threadpool_iface {
  MklDnnThreadPool() = default;
  MklDnnThreadPool(OpKernelContext* ctx)
      : eigen_interface_(ctx->device()
                             ->tensorflow_cpu_worker_threads()
                             ->workers->AsEigenThreadPool()) {}
  virtual int get_num_threads() override {
    return eigen_interface_->NumThreads();
  }
  virtual bool get_in_parallel() override {
    return (eigen_interface_->CurrentThreadId() != -1) ? true : false;
  }
  virtual uint64_t get_flags() override { return ASYNCHRONOUS; }
  virtual void parallel_for(int n,
                            const std::function<void(int, int)>& fn) override {
    // Should never happen (handled by DNNL)
    if (n == 0) return;

    // Should never happen (handled by DNNL)
    if (n == 1) {
      fn(0, 1);
      return;
    }

    int nthr = get_num_threads();
    int njobs = std::min(n, nthr);
    for (int i = 0; i < njobs; i++) {
      eigen_interface_->ScheduleWithHint(
          [i, n, njobs, fn]() {
            int start, end;
            balance211(n, njobs, i, &start, &end);
            for (int j = start; j < end; j++) fn(j, n);
          },
          i, i + 1);
    }
  }
  ~MklDnnThreadPool() {}

 private:
  Eigen::ThreadPoolInterface* eigen_interface_ = nullptr;
};

class MklDnnThreadPoolWrapper {
 public:
  static MklDnnThreadPoolWrapper& GetInstance() {
    static MklDnnThreadPoolWrapper instance_;
    return instance_;
  }
  MklDnnThreadPool* CreateThreadPoolPtr(OpKernelContext* ctx) {
    if (threadpool_map_.empty() ||
        threadpool_map_.find(ctx->device()) == threadpool_map_.end()) {
      auto tp_iface = new MklDnnThreadPool(ctx);
      threadpool_map_.emplace(std::make_pair(ctx->device(), tp_iface));
      return tp_iface;
    } else {
      auto entry = threadpool_map_.find(ctx->device());
      return entry->second;
    }
  }

 private:
  std::unordered_map<DeviceBase*, MklDnnThreadPool*> threadpool_map_;
  MklDnnThreadPoolWrapper() {}
  MklDnnThreadPoolWrapper(const MklDnnThreadPoolWrapper&) = delete;
  MklDnnThreadPoolWrapper& operator=(const MklDnnThreadPoolWrapper&) = delete;
  ~MklDnnThreadPoolWrapper() {
    for (auto& tp : threadpool_map_) {
      delete tp.second;
    }
  }
};

}  // namespace tensorflow

#endif  // ENABLE_MKLDNN_THREADPOOL
#endif  // INTEL_MKL
#endif  // TENSORFLOW_CORE_UTIL_MKL_THREADPOOL_H_
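
The comment on balance211 above describes the partitioning in prose; the following standalone sketch (not part of the patch; the main() driver is added here purely for illustration) prints the ranges it produces for n = 10 units of work on a team of 4, which is exactly how parallel_for splits jobs before handing them to ScheduleWithHint.

// Standalone illustration: the first n % team workers get one extra unit.
#include <algorithm>
#include <cstdio>

template <typename T, typename U>
void balance211(T n, U team, U tid, T* n_start, T* n_end) {
  if (team <= 1 || n == 0) {
    *n_start = 0;
    *n_end = n;
    return;
  }
  T min_per_team = n / team;
  T remainder = n - min_per_team * team;  // i.e., n % team
  *n_start = tid * min_per_team + std::min(tid, remainder);
  *n_end = *n_start + min_per_team + (tid < remainder);
}

int main() {
  const int n = 10, team = 4;
  for (int tid = 0; tid < team; ++tid) {
    int start, end;
    balance211(n, team, tid, &start, &end);
    std::printf("tid %d -> [%d, %d)\n", tid, start, end);
  }
  // Prints: [0, 3), [3, 6), [6, 8), [8, 10)
  return 0;
}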

View File

@@ -36,6 +36,7 @@ limitations under the License.
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/util/env_var.h"
+#include "tensorflow/core/util/mkl_threadpool.h"
 #include "tensorflow/core/util/mkl_types.h"
 #include "tensorflow/core/util/padding.h"
 #include "tensorflow/core/util/tensor_format.h"
@@ -48,7 +49,6 @@ using mkldnn::padding_kind;
 using mkldnn::primitive;
 using mkldnn::reorder;
 using mkldnn::stream;
 using CPUDevice = Eigen::ThreadPoolDevice;
 using MemoryArgsMap = std::unordered_map<int, memory>;
 using ReorderPd = mkldnn::reorder::primitive_desc;
@@ -232,6 +232,27 @@ inline bool array_cmp(const T* a1, const T* a2, size_t size) {
   return true;
 }
+inline mkldnn::stream* CreateStream(OpKernelContext* ctx,
+                                    const engine& engine) {
+#ifdef ENABLE_MKLDNN_THREADPOOL
+  stream_attr tp_stream_attr(ENGINE_CPU);
+  if (ctx != nullptr) {
+    auto eigen_tp =
+        MklDnnThreadPoolWrapper::GetInstance().CreateThreadPoolPtr(ctx);
+    tp_stream_attr.set_threadpool(eigen_tp);
+    stream* tp_stream =
+        new stream(engine, stream::flags::default_flags, tp_stream_attr);
+    return tp_stream;
+  } else {
+    stream* tp_stream = new CPU_STREAM(engine);
+    return tp_stream;
+  }
+#else
+  stream* tp_stream = new CPU_STREAM(engine);
+  return tp_stream;
+#endif  // ENABLE_MKLDNN_THREADPOOL
+}
 class MklDnnShape {
  private:
  typedef struct {
@@ -679,20 +700,21 @@ class MklDnnData;
 // TODO merge with the execute_primitives.
 inline void ExecutePrimitive(const std::vector<primitive>& net,
                              const std::vector<MemoryArgsMap>* net_args,
-                             const engine& cpu_engine) {
+                             const engine& cpu_engine,
+                             OpKernelContext* context = nullptr) {
 #ifdef ENABLE_MKLDNN_V1
   DCHECK(net_args);
   DCHECK_EQ(net.size(), net_args->size());
-  stream cpu_stream(cpu_engine);
+  stream* cpu_stream = CreateStream(context, cpu_engine);
   for (size_t i = 0; i < net.size(); ++i) {
-    net.at(i).execute(cpu_stream, net_args->at(i));
+    net.at(i).execute(*cpu_stream, net_args->at(i));
   }
-  cpu_stream.wait();
+  cpu_stream->wait();
+  delete cpu_stream;
 #else
   stream(stream::kind::eager_nostore).submit(net).wait();
 #endif  // ENABLE_MKLDNN_V1
 }
 template <typename T>
 inline Status ConvertMklToTF(OpKernelContext* context,
                              const Tensor& input_mkl_tensor,
@@ -731,7 +753,7 @@ inline Status ConvertMklToTF(OpKernelContext* context,
       return Status(error::Code::INTERNAL,
                     "ConvertMklToTF(): Failed to create reorder for input");
     }
-    ExecutePrimitive(net, NET_ARGS_PTR, cpu_engine);
+    ExecutePrimitive(net, NET_ARGS_PTR, cpu_engine, context);
   } else {
     // If not, just forward input tensor to output tensor.
     bool status =
@@ -1301,8 +1323,8 @@ inline Status CreateBlockedMemDescHelper(const memory::dims& dim,
 inline void CreateAndExecuteReorder(const ReorderPd& reorder_desc,
                                     const memory& src_mem,
-                                    const memory& dst_mem,
-                                    const engine& engine) {
+                                    const memory& dst_mem, const engine& engine,
+                                    OpKernelContext* ctx = nullptr) {
   std::vector<primitive> net;
 #ifdef ENABLE_MKLDNN_V1
   net.push_back(mkldnn::reorder(reorder_desc));
@@ -1311,7 +1333,7 @@ inline void CreateAndExecuteReorder(const ReorderPd& reorder_desc,
 #else
   net.push_back(mkldnn::reorder(reorder_desc, src_mem, dst_mem));
 #endif  // ENABLE_MKLDNN_V1
-  ExecutePrimitive(net, NET_ARGS_PTR, engine);
+  ExecutePrimitive(net, NET_ARGS_PTR, engine, ctx);
 }
 class MklReorderPrimitive;
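
Taken together, the CreateStream() helper added above and the OpKernelContext parameter threaded through ExecutePrimitive() and CreateAndExecuteReorder() mean a kernel only has to forward its context for DNNL work to land on the op's Eigen threadpool under ENABLE_MKLDNN_THREADPOOL. A minimal sketch of a hypothetical caller follows; the helper name and its arguments are assumptions for illustration, not part of the patch, and it assumes the MKL-DNN v1 build (memory-based ReorderPd constructor).

// Hypothetical helper: forward the kernel's OpKernelContext so the reorder
// runs on the op's threadpool when ENABLE_MKLDNN_THREADPOOL is defined;
// with ctx == nullptr (or without the threadpool build) CreateStream()
// falls back to a plain CPU stream.
inline void ReorderWithContext(const memory& src_mem, const memory& dst_mem,
                               const engine& cpu_engine,
                               OpKernelContext* ctx = nullptr) {
  ReorderPd reorder_pd(src_mem, dst_mem);  // descriptor inferred from memories
  CreateAndExecuteReorder(reorder_pd, src_mem, dst_mem, cpu_engine, ctx);
}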
@@ -1629,22 +1651,26 @@ class MklDnnData {
 #ifdef ENABLE_MKLDNN_V1
   inline bool CheckReorderToOpMem(const memory::desc& op_md,
-                                  const engine& engine) {
+                                  const engine& engine,
+                                  OpKernelContext* context = nullptr) {
     DCHECK(user_memory_);
     if (IsReorderNeeded(op_md)) {
       // TODO(nhasabni): can we remove dynamic memory allocation?
       // primitive reuse don't allow two same reorder prim in
       // one stream, so submit it immediately
       reorder_memory_ = new memory(op_md, engine);
-      std::vector<primitive> net;
       auto* prim = FindOrCreateReorder<T>(user_memory_, reorder_memory_);
+      std::shared_ptr<stream> cpu_stream;
+      cpu_stream.reset(CreateStream(context, prim->GetEngine()));
+      std::vector<primitive> net;
       net.push_back(*(prim->GetPrimitive()));
       std::vector<MemoryArgsMap> net_args;
       net_args.push_back({{MKLDNN_ARG_FROM, *user_memory_},
                           {MKLDNN_ARG_TO, *reorder_memory_}});
-      execute_primitives(net, prim->GetStream(), net_args);
+      execute_primitives(net, cpu_stream, net_args);
 #else
-  inline bool CheckReorderToOpMem(const memory::primitive_desc& op_pd) {
+  inline bool CheckReorderToOpMem(const memory::primitive_desc& op_pd,
+                                  OpKernelContext* ctx = nullptr) {
     CHECK_NOTNULL(user_memory_);
     if (IsReorderNeeded(op_pd)) {
       reorder_memory_ = new memory(op_pd);
@@ -1708,7 +1734,8 @@ class MklDnnData {
   /// TODO(bhavanis): Need to use reorder cache here for better performance.
   inline bool CheckReorderToOpMem(const memory::desc& op_md,
                                   void* reorder_data_handle,
-                                  const engine& engine) {
+                                  const engine& engine,
+                                  OpKernelContext* context = nullptr) {
     DCHECK(reorder_data_handle);
     DCHECK(user_memory_);
     if (IsReorderNeeded(op_md)) {
@@ -1716,16 +1743,19 @@ class MklDnnData {
       // primitive reuse don't allow two same reorder prim in
       // one stream, so submit it immediately
       reorder_memory_ = new memory(op_md, engine, reorder_data_handle);
-      std::vector<primitive> net;
       auto* prim = FindOrCreateReorder<T>(user_memory_, reorder_memory_);
+      std::shared_ptr<stream> cpu_stream;
+      cpu_stream.reset(CreateStream(context, prim->GetEngine()));
+      std::vector<primitive> net;
       net.push_back(*(prim->GetPrimitive()));
       std::vector<MemoryArgsMap> net_args;
       net_args.push_back({{MKLDNN_ARG_FROM, *user_memory_},
                           {MKLDNN_ARG_TO, *reorder_memory_}});
-      execute_primitives(net, prim->GetStream(), net_args);
+      execute_primitives(net, cpu_stream, net_args);
 #else
   inline bool CheckReorderToOpMem(const memory::primitive_desc& op_pd,
-                                  void* reorder_data_handle) {
+                                  void* reorder_data_handle,
+                                  OpKernelContext* context = nullptr) {
     CHECK_NOTNULL(reorder_data_handle);
     CHECK_NOTNULL(user_memory_);
     if (IsReorderNeeded(op_pd)) {
@@ -1778,13 +1808,14 @@ class MklDnnData {
   /// remove
   /// slow path in the future
   inline bool CheckReorderToOpMem(const MEMORY_PRIMITIVE_DESC& op_pd,
-                                  Tensor* reorder_tensor) {
+                                  Tensor* reorder_tensor,
+                                  OpKernelContext* ctx = nullptr) {
     DCHECK(reorder_tensor);
 #ifdef ENABLE_MKLDNN_V1
     return CheckReorderToOpMem(op_pd, GetTensorBuffer(reorder_tensor),
-                               *cpu_engine_);
+                               *cpu_engine_, ctx);
 #else
-    return CheckReorderToOpMem(op_pd, GetTensorBuffer(reorder_tensor));
+    return CheckReorderToOpMem(op_pd, GetTensorBuffer(reorder_tensor), ctx);
 #endif  // ENABLE_MKLDNN_V1
   }
@@ -1843,7 +1874,7 @@ class MklDnnData {
   /// TODO: this is a faster path with reorder primitive cache compared with
   /// InsertReorderToUserMem(net, net_args), will remove
   /// slow path in the future
-  inline void InsertReorderToUserMem() {
+  inline void InsertReorderToUserMem(OpKernelContext* ctx = nullptr) {
     DCHECK(user_memory_);
     DCHECK(reorder_memory_);
     DCHECK(cpu_engine_);
@@ -1857,8 +1888,8 @@ class MklDnnData {
     net_args.push_back(
         {{MKLDNN_ARG_FROM, *reorder_memory_}, {MKLDNN_ARG_TO, *user_memory_}});
     std::shared_ptr<stream> cpu_stream;
-    cpu_stream.reset(new stream(*cpu_engine_));
-    execute_primitives(net, prim->GetStream(), net_args);
+    cpu_stream.reset(CreateStream(ctx, prim->GetEngine()));
+    execute_primitives(net, cpu_stream, net_args);
 #else
     net.push_back(FindOrCreateReorder<T>(reorder_memory_, user_memory_));
     ExecutePrimitive(net, NET_ARGS_PTR, *cpu_engine_);
@@ -1870,9 +1901,12 @@ class MklDnnData {
 class MklPrimitive {
  public:
   virtual ~MklPrimitive() {}
+  MklPrimitive() {}
+  MklPrimitive(const engine& cpu_engine) { cpu_engine_ = cpu_engine; }
   // Dummy data which MKL DNN never operates on
   unsigned char* DummyData = nullptr;
+  engine cpu_engine_ = engine(ENGINE_CPU, 0);
+  const engine& GetEngine() { return cpu_engine_; }
 };
 const mkldnn::memory::dims NONE_DIMS = {};
@@ -2058,7 +2092,8 @@ class FactoryKeyCreator {
 class MklReorderPrimitive : public MklPrimitive {
  public:
-  explicit MklReorderPrimitive(const memory* from, const memory* to) {
+  explicit MklReorderPrimitive(const memory* from, const memory* to)
+      : MklPrimitive(engine(ENGINE_CPU, 0)) {
     Setup(from, to);
   }
   ~MklReorderPrimitive() {}
@@ -2081,7 +2116,6 @@ class MklReorderPrimitive : public MklPrimitive {
       : src_mem(nullptr), dst_mem(nullptr), reorder_prim(nullptr) {}
   } context_;
-  engine cpu_engine_ = engine(ENGINE_CPU, 0);
   std::shared_ptr<mkldnn::stream> stream_;
   void Setup(const memory* from, const memory* to) {