Merge pull request #38259 from Intel-tensorflow:sriniva2/dnnl_threadpool
PiperOrigin-RevId: 307948123 Change-Id: Ic1210d1b6d48c4f333a619c014928a588ec02ed3
This commit is contained in:
commit
a2f67253fe
@ -101,8 +101,7 @@ int32 NumIntraOpThreadsFromEnvironment() {
|
|||||||
const char* val = std::getenv("TF_NUM_INTRAOP_THREADS");
|
const char* val = std::getenv("TF_NUM_INTRAOP_THREADS");
|
||||||
return (val && strings::safe_strto32(val, &num)) ? num : 0;
|
return (val && strings::safe_strto32(val, &num)) ? num : 0;
|
||||||
}
|
}
|
||||||
|
#if !defined(ENABLE_MKLDNN_THREADPOOL) && defined(INTEL_MKL)
|
||||||
#ifdef INTEL_MKL
|
|
||||||
int32 OMPThreadsFromEnvironment() {
|
int32 OMPThreadsFromEnvironment() {
|
||||||
// 1) std::getenv is thread-safe (as long as no other function modifies the
|
// 1) std::getenv is thread-safe (as long as no other function modifies the
|
||||||
// host env) from C++11 onward. 2) Most of TF code (except tests and
|
// host env) from C++11 onward. 2) Most of TF code (except tests and
|
||||||
@ -122,14 +121,14 @@ int32 DefaultNumIntraOpThreads() {
|
|||||||
// Default to the maximum parallelism for the current process.
|
// Default to the maximum parallelism for the current process.
|
||||||
return port::MaxParallelism();
|
return port::MaxParallelism();
|
||||||
}
|
}
|
||||||
#endif // INTEL_MKL
|
#endif // !defined(ENABLE_MKLDNN_THREADPOOL) && defined(INTEL_MKL)
|
||||||
int32 NumInterOpThreadsFromSessionOptions(const SessionOptions& options) {
|
int32 NumInterOpThreadsFromSessionOptions(const SessionOptions& options) {
|
||||||
const int32 inter_op = options.config.inter_op_parallelism_threads();
|
const int32 inter_op = options.config.inter_op_parallelism_threads();
|
||||||
if (inter_op > 0) return inter_op;
|
if (inter_op > 0) return inter_op;
|
||||||
const int32 env_inter_op = GetEnvNumInterOpThreads();
|
const int32 env_inter_op = GetEnvNumInterOpThreads();
|
||||||
if (env_inter_op > 0) return env_inter_op;
|
if (env_inter_op > 0) return env_inter_op;
|
||||||
|
|
||||||
#ifdef INTEL_MKL
|
#if !defined(ENABLE_MKLDNN_THREADPOOL) && defined(INTEL_MKL)
|
||||||
if (!DisableMKL()) {
|
if (!DisableMKL()) {
|
||||||
// MKL library executes ops in parallel using OMP threads.
|
// MKL library executes ops in parallel using OMP threads.
|
||||||
// Setting inter_op conservatively to avoid thread oversubscription that
|
// Setting inter_op conservatively to avoid thread oversubscription that
|
||||||
@ -150,7 +149,7 @@ int32 NumInterOpThreadsFromSessionOptions(const SessionOptions& options) {
|
|||||||
<< ". Tune using inter_op_parallelism_threads for best performance.";
|
<< ". Tune using inter_op_parallelism_threads for best performance.";
|
||||||
return mkl_inter_op;
|
return mkl_inter_op;
|
||||||
}
|
}
|
||||||
#endif // INTEL_MKL
|
#endif // !defined(ENABLE_MKLDNN_THREADPOOL) && defined(INTEL_MKL)
|
||||||
return DefaultNumInterOpThreads();
|
return DefaultNumInterOpThreads();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -50,7 +50,7 @@ ThreadPoolDevice::ThreadPoolDevice(const SessionOptions& options,
|
|||||||
name, DEVICE_CPU, memory_limit, locality)),
|
name, DEVICE_CPU, memory_limit, locality)),
|
||||||
allocator_(allocator),
|
allocator_(allocator),
|
||||||
scoped_allocator_mgr_(new ScopedAllocatorMgr(name)) {
|
scoped_allocator_mgr_(new ScopedAllocatorMgr(name)) {
|
||||||
#ifdef INTEL_MKL
|
#if !defined(ENABLE_MKLDNN_THREADPOOL) && defined(INTEL_MKL)
|
||||||
// Early return when MKL is disabled
|
// Early return when MKL is disabled
|
||||||
if (DisableMKL()) return;
|
if (DisableMKL()) return;
|
||||||
#ifdef _OPENMP
|
#ifdef _OPENMP
|
||||||
@ -69,7 +69,7 @@ ThreadPoolDevice::ThreadPoolDevice(const SessionOptions& options,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif // _OPENMP
|
#endif // _OPENMP
|
||||||
#endif // INTEL_MKL
|
#endif // !defined(ENABLE_MKLDNN_THREADPOOL) && defined(INTEL_MKL)
|
||||||
}
|
}
|
||||||
|
|
||||||
ThreadPoolDevice::~ThreadPoolDevice() {}
|
ThreadPoolDevice::~ThreadPoolDevice() {}
|
||||||
|
@ -51,12 +51,10 @@ limitations under the License.
|
|||||||
using mkldnn::convolution_forward;
|
using mkldnn::convolution_forward;
|
||||||
using mkldnn::prop_kind;
|
using mkldnn::prop_kind;
|
||||||
using mkldnn::stream;
|
using mkldnn::stream;
|
||||||
|
|
||||||
using ConvFwdPd = mkldnn::convolution_forward::primitive_desc;
|
using ConvFwdPd = mkldnn::convolution_forward::primitive_desc;
|
||||||
using ReorderPd = mkldnn::reorder::primitive_desc;
|
using ReorderPd = mkldnn::reorder::primitive_desc;
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
// This structure aggregates multiple inputs to Conv2DFwd* methods.
|
// This structure aggregates multiple inputs to Conv2DFwd* methods.
|
||||||
struct MklConvFwdParams {
|
struct MklConvFwdParams {
|
||||||
memory::dims src_dims;
|
memory::dims src_dims;
|
||||||
@ -96,14 +94,12 @@ template <typename Tinput, typename Tfilter, typename Tbias, typename Toutput>
|
|||||||
class MklConvFwdPrimitive : public MklPrimitive {
|
class MklConvFwdPrimitive : public MklPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit MklConvFwdPrimitive(const MklConvFwdParams& convFwdDims)
|
explicit MklConvFwdPrimitive(const MklConvFwdParams& convFwdDims)
|
||||||
: cpu_engine_(ENGINE_CPU, 0) {
|
: MklPrimitive(engine(ENGINE_CPU, 0)) {
|
||||||
context_.fwd_stream.reset(new CPU_STREAM(cpu_engine_));
|
|
||||||
// Create convolution primitive
|
// Create convolution primitive
|
||||||
if (context_.conv_fwd == nullptr) {
|
if (context_.conv_fwd == nullptr) {
|
||||||
Setup(convFwdDims);
|
Setup(convFwdDims);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
~MklConvFwdPrimitive() {}
|
~MklConvFwdPrimitive() {}
|
||||||
|
|
||||||
// Convolution forward execute with bias
|
// Convolution forward execute with bias
|
||||||
@ -112,7 +108,8 @@ class MklConvFwdPrimitive : public MklPrimitive {
|
|||||||
// bias_data: input data buffer of bias
|
// bias_data: input data buffer of bias
|
||||||
// dst_data: output data buffer of dst
|
// dst_data: output data buffer of dst
|
||||||
void Execute(const Tinput* src_data, const Tfilter* filter_data,
|
void Execute(const Tinput* src_data, const Tfilter* filter_data,
|
||||||
const Tbias* bias_data, const Toutput* dst_data) {
|
const Tbias* bias_data, const Toutput* dst_data,
|
||||||
|
std::shared_ptr<stream> fwd_stream) {
|
||||||
context_.src_mem->set_data_handle(
|
context_.src_mem->set_data_handle(
|
||||||
static_cast<void*>(const_cast<Tinput*>(src_data)));
|
static_cast<void*>(const_cast<Tinput*>(src_data)));
|
||||||
context_.filter_mem->set_data_handle(
|
context_.filter_mem->set_data_handle(
|
||||||
@ -127,11 +124,11 @@ class MklConvFwdPrimitive : public MklPrimitive {
|
|||||||
DCHECK_EQ(context_.fwd_primitives.size(),
|
DCHECK_EQ(context_.fwd_primitives.size(),
|
||||||
context_.fwd_primitives_args.size());
|
context_.fwd_primitives_args.size());
|
||||||
for (size_t i = 0; i < context_.fwd_primitives.size(); ++i) {
|
for (size_t i = 0; i < context_.fwd_primitives.size(); ++i) {
|
||||||
context_.fwd_primitives.at(i).execute(*context_.fwd_stream,
|
context_.fwd_primitives.at(i).execute(*fwd_stream,
|
||||||
context_.fwd_primitives_args.at(i));
|
context_.fwd_primitives_args.at(i));
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
context_.fwd_stream->submit(context_.fwd_primitives);
|
fwd_stream->submit(context_.fwd_primitives);
|
||||||
#endif // ENABLE_MKLDNN_V1
|
#endif // ENABLE_MKLDNN_V1
|
||||||
|
|
||||||
// After execution, set data handle back
|
// After execution, set data handle back
|
||||||
@ -148,8 +145,8 @@ class MklConvFwdPrimitive : public MklPrimitive {
|
|||||||
// filter_data: input data buffer of filter (weights)
|
// filter_data: input data buffer of filter (weights)
|
||||||
// dst_data: output data buffer of dst
|
// dst_data: output data buffer of dst
|
||||||
void Execute(const Tinput* src_data, const Tfilter* filter_data,
|
void Execute(const Tinput* src_data, const Tfilter* filter_data,
|
||||||
const Toutput* dst_data) {
|
const Toutput* dst_data, std::shared_ptr<stream> fwd_stream) {
|
||||||
Execute(src_data, filter_data, nullptr, dst_data);
|
Execute(src_data, filter_data, nullptr, dst_data, fwd_stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifndef ENABLE_MKLDNN_V1
|
#ifndef ENABLE_MKLDNN_V1
|
||||||
@ -191,7 +188,6 @@ class MklConvFwdPrimitive : public MklPrimitive {
|
|||||||
std::shared_ptr<ConvFwdPd> fwd_pd;
|
std::shared_ptr<ConvFwdPd> fwd_pd;
|
||||||
std::shared_ptr<mkldnn::primitive> conv_fwd;
|
std::shared_ptr<mkldnn::primitive> conv_fwd;
|
||||||
|
|
||||||
std::shared_ptr<mkldnn::stream> fwd_stream;
|
|
||||||
std::vector<mkldnn::primitive> fwd_primitives;
|
std::vector<mkldnn::primitive> fwd_primitives;
|
||||||
|
|
||||||
#ifdef ENABLE_MKLDNN_V1
|
#ifdef ENABLE_MKLDNN_V1
|
||||||
@ -213,8 +209,7 @@ class MklConvFwdPrimitive : public MklPrimitive {
|
|||||||
filter_md(nullptr),
|
filter_md(nullptr),
|
||||||
bias_md(nullptr),
|
bias_md(nullptr),
|
||||||
fwd_pd(nullptr),
|
fwd_pd(nullptr),
|
||||||
conv_fwd(nullptr),
|
conv_fwd(nullptr) {
|
||||||
fwd_stream(nullptr) {
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -346,7 +341,6 @@ class MklConvFwdPrimitive : public MklPrimitive {
|
|||||||
}
|
}
|
||||||
|
|
||||||
struct ConvFwdContext context_;
|
struct ConvFwdContext context_;
|
||||||
engine cpu_engine_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// TODO(nhasabni): We should not require passing a type to MklPrimitiveFactory.
|
// TODO(nhasabni): We should not require passing a type to MklPrimitiveFactory.
|
||||||
@ -678,11 +672,9 @@ class MklConvOp : public OpKernel {
|
|||||||
|
|
||||||
// TODO(mdfaijul): Extend the basic parameters for data types and fusions
|
// TODO(mdfaijul): Extend the basic parameters for data types and fusions
|
||||||
this->ExtendConvFwdParams(context, convFwdDims);
|
this->ExtendConvFwdParams(context, convFwdDims);
|
||||||
|
|
||||||
conv_fwd =
|
conv_fwd =
|
||||||
MklConvFwdPrimitiveFactory<Tinput, Tfilter, Tbias, Ttemp_output>::Get(
|
MklConvFwdPrimitiveFactory<Tinput, Tfilter, Tbias, Ttemp_output>::Get(
|
||||||
convFwdDims, do_not_cache);
|
convFwdDims, do_not_cache);
|
||||||
|
|
||||||
// Allocate output tensors `output_tensor` and `filter_out_tensor`
|
// Allocate output tensors `output_tensor` and `filter_out_tensor`
|
||||||
MklDnnShape output_mkl_shape;
|
MklDnnShape output_mkl_shape;
|
||||||
std::shared_ptr<ConvFwdPd> conv_fwd_pd = conv_fwd->GetPrimitiveDesc();
|
std::shared_ptr<ConvFwdPd> conv_fwd_pd = conv_fwd->GetPrimitiveDesc();
|
||||||
@ -703,8 +695,10 @@ class MklConvOp : public OpKernel {
|
|||||||
Tinput* src_data = nullptr;
|
Tinput* src_data = nullptr;
|
||||||
if (IS_SRC_REORDER_NEEDED(src_md, conv_fwd_pd, conv_fwd)) {
|
if (IS_SRC_REORDER_NEEDED(src_md, conv_fwd_pd, conv_fwd)) {
|
||||||
src.SetUsrMem(src_md, &src_tensor);
|
src.SetUsrMem(src_md, &src_tensor);
|
||||||
src.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
|
src.CheckReorderToOpMem(
|
||||||
GET_SRC_DESC_FROM_OP_PD(conv_fwd_pd), cpu_engine_));
|
MEMORY_PD_WITHOUT_DATA(GET_SRC_DESC_FROM_OP_PD(conv_fwd_pd),
|
||||||
|
cpu_engine_),
|
||||||
|
context);
|
||||||
src_data = static_cast<Tinput*>(src.GetOpMem().get_data_handle());
|
src_data = static_cast<Tinput*>(src.GetOpMem().get_data_handle());
|
||||||
} else {
|
} else {
|
||||||
src_data = static_cast<Tinput*>(
|
src_data = static_cast<Tinput*>(
|
||||||
@ -735,13 +729,16 @@ class MklConvOp : public OpKernel {
|
|||||||
if (!is_filter_cached) {
|
if (!is_filter_cached) {
|
||||||
filter.SetUsrMem(filter_md, &filter_tensor);
|
filter.SetUsrMem(filter_md, &filter_tensor);
|
||||||
if (filter_out_tensor == nullptr) {
|
if (filter_out_tensor == nullptr) {
|
||||||
filter.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
|
filter.CheckReorderToOpMem(
|
||||||
GET_WEIGHTS_DESC_FROM_OP_PD(conv_fwd_pd), cpu_engine_));
|
MEMORY_PD_WITHOUT_DATA(GET_WEIGHTS_DESC_FROM_OP_PD(conv_fwd_pd),
|
||||||
|
cpu_engine_),
|
||||||
|
context);
|
||||||
} else {
|
} else {
|
||||||
filter.CheckReorderToOpMem(
|
filter.CheckReorderToOpMem(
|
||||||
GET_WEIGHTS_DESC_FROM_OP_PD(conv_fwd_pd),
|
GET_WEIGHTS_DESC_FROM_OP_PD(conv_fwd_pd),
|
||||||
DATA_WITH_ENGINE(filter.GetTensorBuffer(filter_out_tensor),
|
DATA_WITH_ENGINE(filter.GetTensorBuffer(filter_out_tensor),
|
||||||
cpu_engine_));
|
cpu_engine_),
|
||||||
|
context);
|
||||||
}
|
}
|
||||||
filter_data =
|
filter_data =
|
||||||
static_cast<Tfilter*>(filter.GetOpMem().get_data_handle());
|
static_cast<Tfilter*>(filter.GetOpMem().get_data_handle());
|
||||||
@ -752,20 +749,23 @@ class MklConvOp : public OpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Execute convolution
|
// Execute convolution
|
||||||
|
std::shared_ptr<stream> fwd_cpu_stream;
|
||||||
|
fwd_cpu_stream.reset(CreateStream(context, conv_fwd->GetEngine()));
|
||||||
if (fuse_biasadd_) {
|
if (fuse_biasadd_) {
|
||||||
const Tensor& bias_tensor = MklGetInput(context, kInputIndex_Bias);
|
const Tensor& bias_tensor = MklGetInput(context, kInputIndex_Bias);
|
||||||
Tbias* bias_data =
|
Tbias* bias_data =
|
||||||
this->GetBiasHandle(context, conv_fwd_pd, bias_tensor);
|
this->GetBiasHandle(context, conv_fwd_pd, bias_tensor);
|
||||||
conv_fwd->Execute(src_data, filter_data, bias_data, dst_data);
|
conv_fwd->Execute(src_data, filter_data, bias_data, dst_data,
|
||||||
|
fwd_cpu_stream);
|
||||||
} else {
|
} else {
|
||||||
if (!eager_mode) {
|
if (!eager_mode) {
|
||||||
conv_fwd->Execute(src_data, filter_data, dst_data);
|
conv_fwd->Execute(src_data, filter_data, dst_data, fwd_cpu_stream);
|
||||||
} else {
|
} else {
|
||||||
// In eager mode we first write the output to temporary
|
// In eager mode we first write the output to temporary
|
||||||
// buffer in MKL format. Then we convert the data to TF format.
|
// buffer in MKL format. Then we convert the data to TF format.
|
||||||
Ttemp_output* tmp_data = reinterpret_cast<Ttemp_output*>(
|
Ttemp_output* tmp_data = reinterpret_cast<Ttemp_output*>(
|
||||||
tmp_tensor.flat<Toutput>().data());
|
tmp_tensor.flat<Toutput>().data());
|
||||||
conv_fwd->Execute(src_data, filter_data, tmp_data);
|
conv_fwd->Execute(src_data, filter_data, tmp_data, fwd_cpu_stream);
|
||||||
|
|
||||||
// Now we need to convert the output to TF format.
|
// Now we need to convert the output to TF format.
|
||||||
auto output_tf_md = output_mkl_shape.GetTfLayout();
|
auto output_tf_md = output_mkl_shape.GetTfLayout();
|
||||||
@ -780,12 +780,13 @@ class MklConvOp : public OpKernel {
|
|||||||
memory* dst_data_mem =
|
memory* dst_data_mem =
|
||||||
new MEMORY_CONSTRUCTOR(OUTPUT_TF_MD, cpu_engine_, dst_data);
|
new MEMORY_CONSTRUCTOR(OUTPUT_TF_MD, cpu_engine_, dst_data);
|
||||||
CreateAndExecuteReorder(reorder_pd, *tmp_data_mem, *dst_data_mem,
|
CreateAndExecuteReorder(reorder_pd, *tmp_data_mem, *dst_data_mem,
|
||||||
cpu_engine_);
|
cpu_engine_, context);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete primitive since it is not cached.
|
// Delete primitive since it is not cached.
|
||||||
if (do_not_cache) delete conv_fwd;
|
if (do_not_cache) delete conv_fwd;
|
||||||
|
|
||||||
} catch (mkldnn::error& e) {
|
} catch (mkldnn::error& e) {
|
||||||
string error_msg = tensorflow::strings::StrCat(
|
string error_msg = tensorflow::strings::StrCat(
|
||||||
"Status: ", e.status, ", message: ", string(e.message), ", in file ",
|
"Status: ", e.status, ", message: ", string(e.message), ", in file ",
|
||||||
@ -970,8 +971,9 @@ class MklConvOp : public OpKernel {
|
|||||||
new MEMORY_CONSTRUCTOR(DST_MD, this->cpu_engine_, dst_buf));
|
new MEMORY_CONSTRUCTOR(DST_MD, this->cpu_engine_, dst_buf));
|
||||||
auto reorder_desc =
|
auto reorder_desc =
|
||||||
REORDER_PD_CONSTRUCTOR(ADD_MD, DST_MD, this->cpu_engine_);
|
REORDER_PD_CONSTRUCTOR(ADD_MD, DST_MD, this->cpu_engine_);
|
||||||
|
|
||||||
CreateAndExecuteReorder(reorder_desc, *fuse_add_src_, *fuse_add_dst_,
|
CreateAndExecuteReorder(reorder_desc, *fuse_add_src_, *fuse_add_dst_,
|
||||||
this->cpu_engine_);
|
this->cpu_engine_, context);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
AllocateOutputSetMklShape(context, kOutputIndex_Dst, output_tensor,
|
AllocateOutputSetMklShape(context, kOutputIndex_Dst, output_tensor,
|
||||||
@ -1097,6 +1099,7 @@ class MklConvOp : public OpKernel {
|
|||||||
filter_tf_shape, filter_mkl_shape);
|
filter_tf_shape, filter_mkl_shape);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(intel-mkl): This function does not seem to be called. Remove it.
|
||||||
// Prepare and execute net - checks for input and output reorders.
|
// Prepare and execute net - checks for input and output reorders.
|
||||||
void PrepareAndExecuteNet(const ConvFwdPd& conv_prim_desc,
|
void PrepareAndExecuteNet(const ConvFwdPd& conv_prim_desc,
|
||||||
MklDnnData<Tinput>* src,
|
MklDnnData<Tinput>* src,
|
||||||
@ -1185,7 +1188,7 @@ class MklConvOp : public OpKernel {
|
|||||||
// Otherwise, cache filter
|
// Otherwise, cache filter
|
||||||
filter.SetUsrMem(filter_md, &filter_tensor);
|
filter.SetUsrMem(filter_md, &filter_tensor);
|
||||||
filter.CheckReorderToOpMem(conv_fwd_pd.get()->weights_desc(),
|
filter.CheckReorderToOpMem(conv_fwd_pd.get()->weights_desc(),
|
||||||
this->cpu_engine_);
|
this->cpu_engine_, context);
|
||||||
filter_data = static_cast<Tfilter*>(filter.GetOpMem().get_data_handle());
|
filter_data = static_cast<Tfilter*>(filter.GetOpMem().get_data_handle());
|
||||||
|
|
||||||
Tensor* filter_tensor_ptr = nullptr;
|
Tensor* filter_tensor_ptr = nullptr;
|
||||||
@ -1251,9 +1254,9 @@ class MklConvOp : public OpKernel {
|
|||||||
const Tensor& cached_filter_md =
|
const Tensor& cached_filter_md =
|
||||||
*cached_filter_md_ptensor_.AccessTensor(context);
|
*cached_filter_md_ptensor_.AccessTensor(context);
|
||||||
|
|
||||||
// Check if the memory descriptor of the cached weights is same as
|
// Check if the memory descriptor of the cached weights is same as
|
||||||
// filter_md. If so, we can use the cached weights; otherwise
|
// filter_md. If so, we can use the cached weights; otherwise
|
||||||
// return nullptr.
|
// return nullptr.
|
||||||
#ifdef ENABLE_MKLDNN_V1
|
#ifdef ENABLE_MKLDNN_V1
|
||||||
if (filter_md == *static_cast<memory::desc*>(cached_filter_md.data())) {
|
if (filter_md == *static_cast<memory::desc*>(cached_filter_md.data())) {
|
||||||
#else
|
#else
|
||||||
@ -1652,7 +1655,7 @@ class MklQuantizedConv2DOp
|
|||||||
input_bias_->GET_DESC, scaled_bias_->GET_DESC, this->cpu_engine_,
|
input_bias_->GET_DESC, scaled_bias_->GET_DESC, this->cpu_engine_,
|
||||||
bias_attr);
|
bias_attr);
|
||||||
CreateAndExecuteReorder(reorder_desc, *input_bias_, *scaled_bias_,
|
CreateAndExecuteReorder(reorder_desc, *input_bias_, *scaled_bias_,
|
||||||
this->cpu_engine_);
|
this->cpu_engine_, context);
|
||||||
|
|
||||||
Tbias* bias_data =
|
Tbias* bias_data =
|
||||||
reinterpret_cast<Tbias*>(scaled_bias_->get_data_handle());
|
reinterpret_cast<Tbias*>(scaled_bias_->get_data_handle());
|
||||||
@ -1908,7 +1911,8 @@ class MklQuantizedConv2DSumReluOp
|
|||||||
auto reorder_desc = REORDER_PD_CONSTRUCTOR_WITH_ATTR(
|
auto reorder_desc = REORDER_PD_CONSTRUCTOR_WITH_ATTR(
|
||||||
SUMMAND_MD, conv_prim_desc.PRIMITIVE_DESC_DST, this->cpu_engine_,
|
SUMMAND_MD, conv_prim_desc.PRIMITIVE_DESC_DST, this->cpu_engine_,
|
||||||
reorder_attr);
|
reorder_attr);
|
||||||
CreateAndExecuteReorder(reorder_desc, *summand_, *dst_, this->cpu_engine_);
|
CreateAndExecuteReorder(reorder_desc, *summand_, *dst_, this->cpu_engine_,
|
||||||
|
context);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<mkldnn::memory> summand_;
|
std::shared_ptr<mkldnn::memory> summand_;
|
||||||
|
@ -165,7 +165,7 @@ class MklInputConversionOp : public OpKernel {
|
|||||||
input1_md, tensor_out, net, net_args, cpu_engine)),
|
input1_md, tensor_out, net, net_args, cpu_engine)),
|
||||||
errors::Internal(
|
errors::Internal(
|
||||||
"MklInputConversionOp: Failed to create reorder for input0"));
|
"MklInputConversionOp: Failed to create reorder for input0"));
|
||||||
ExecutePrimitive(net, NET_ARGS_PTR, cpu_engine);
|
ExecutePrimitive(net, NET_ARGS_PTR, cpu_engine, context);
|
||||||
// Input1 will be passed through
|
// Input1 will be passed through
|
||||||
ForwardMklTensorInToOut(context, kInputIndex_1, kInputIndex_1);
|
ForwardMklTensorInToOut(context, kInputIndex_1, kInputIndex_1);
|
||||||
return;
|
return;
|
||||||
@ -273,7 +273,7 @@ class MklInputConversionOp : public OpKernel {
|
|||||||
errors::Internal("MklInputConversionOp: Failed to forward "
|
errors::Internal("MklInputConversionOp: Failed to forward "
|
||||||
"input tensor to output"));
|
"input tensor to output"));
|
||||||
} else {
|
} else {
|
||||||
ExecutePrimitive(net, NET_ARGS_PTR, cpu_engine);
|
ExecutePrimitive(net, NET_ARGS_PTR, cpu_engine, context);
|
||||||
}
|
}
|
||||||
|
|
||||||
// -- The tensor in MKL format passes through --
|
// -- The tensor in MKL format passes through --
|
||||||
|
@ -172,7 +172,8 @@ class MklReshapeOp : public OpKernel {
|
|||||||
// shape_from != shape_to), then we just copy input tensor to
|
// shape_from != shape_to), then we just copy input tensor to
|
||||||
// output tensor with target shape (we cannot forward Mkl layout
|
// output tensor with target shape (we cannot forward Mkl layout
|
||||||
// in such case because shape has changed.)
|
// in such case because shape has changed.)
|
||||||
if (dnn_data_input.CheckReorderToOpMem(OUTPUT_TF_MD, output_tensor)) {
|
if (dnn_data_input.CheckReorderToOpMem(OUTPUT_TF_MD, output_tensor,
|
||||||
|
context)) {
|
||||||
} else {
|
} else {
|
||||||
OP_REQUIRES(context,
|
OP_REQUIRES(context,
|
||||||
output_tensor->CopyFrom(input_tensor, shape_to),
|
output_tensor->CopyFrom(input_tensor, shape_to),
|
||||||
|
@ -111,7 +111,8 @@ class MklToTfOp : public OpKernel {
|
|||||||
if (input.IsReorderNeeded(OUTPUT_TF_MD)) {
|
if (input.IsReorderNeeded(OUTPUT_TF_MD)) {
|
||||||
// Insert reorder between MKL layout and TensorFlow layout
|
// Insert reorder between MKL layout and TensorFlow layout
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
context, input.CheckReorderToOpMem(OUTPUT_TF_MD, output_tensor),
|
context,
|
||||||
|
input.CheckReorderToOpMem(OUTPUT_TF_MD, output_tensor, context),
|
||||||
errors::Internal("MklToTfOp: Failed to create input reorder"));
|
errors::Internal("MklToTfOp: Failed to create input reorder"));
|
||||||
} else {
|
} else {
|
||||||
// If not, just forward input tensor to output tensor.
|
// If not, just forward input tensor to output tensor.
|
||||||
|
@ -144,6 +144,7 @@ filegroup(
|
|||||||
"matmul_autotune.h",
|
"matmul_autotune.h",
|
||||||
"matmul_bcast.h",
|
"matmul_bcast.h",
|
||||||
"mirror_pad_mode.h",
|
"mirror_pad_mode.h",
|
||||||
|
"mkl_threadpool.h",
|
||||||
"mkl_types.h",
|
"mkl_types.h",
|
||||||
"mkl_util.h",
|
"mkl_util.h",
|
||||||
"overflow.h",
|
"overflow.h",
|
||||||
@ -273,6 +274,7 @@ filegroup(
|
|||||||
filegroup(
|
filegroup(
|
||||||
name = "mkl_util_hdrs",
|
name = "mkl_util_hdrs",
|
||||||
srcs = [
|
srcs = [
|
||||||
|
"mkl_threadpool.h",
|
||||||
"mkl_util.h",
|
"mkl_util.h",
|
||||||
],
|
],
|
||||||
visibility = ["//tensorflow/core:__pkg__"],
|
visibility = ["//tensorflow/core:__pkg__"],
|
||||||
|
138
tensorflow/core/util/mkl_threadpool.h
Normal file
138
tensorflow/core/util/mkl_threadpool.h
Normal file
@ -0,0 +1,138 @@
|
|||||||
|
|
||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef TENSORFLOW_CORE_UTIL_MKL_THREADPOOL_H_
|
||||||
|
#define TENSORFLOW_CORE_UTIL_MKL_THREADPOOL_H_
|
||||||
|
#ifdef INTEL_MKL
|
||||||
|
|
||||||
|
#include <algorithm>
#include <functional>
#include <list>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
|
||||||
|
|
||||||
|
#include "mkldnn.hpp"
|
||||||
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
|
#include "tensorflow/core/platform/threadpool.h"
|
||||||
|
#define EIGEN_USE_THREADS
|
||||||
|
#ifdef ENABLE_MKLDNN_THREADPOOL
|
||||||
|
using dnnl::stream_attr;
|
||||||
|
using dnnl::threadpool_iface;
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
// Divides 'n' units of work as evenly as possible among 'team' workers. If
// 'n' is not divisible by 'team' and leaves a remainder 'r', the first 'r'
// workers each receive one unit of work more than the rest. Computes the
// half-open range [*n_start, *n_end) that belongs to worker 'tid'.
//
// Parameters
//   n        Total number of jobs.
//   team     Number of workers.
//   tid      Index of the current worker (callers in this file pass values
//            in [0, team)).
//   n_start  Output: first job index handled by worker 'tid'.
//   n_end    Output: one past the last job index handled by worker 'tid'.
//
// When 'team' <= 1 or 'n' == 0, the whole range [0, n) is returned.
template <typename T, typename U>
inline void balance211(T n, U team, U tid, T* n_start, T* n_end) {
  if (team <= 1 || n == 0) {
    *n_start = 0;
    *n_end = n;
    return;
  }
  T min_per_team = n / team;
  T remainder = n - min_per_team * team;  // i.e., n % team.
  // 'tid' has type U while 'remainder' has type T; std::min needs both
  // arguments to share a type, so convert explicitly. Without the cast the
  // template only compiles when T and U are the same type.
  *n_start = tid * min_per_team + std::min(static_cast<T>(tid), remainder);
  // Workers with an index below 'remainder' take one extra unit.
  *n_end = *n_start + min_per_team + (tid < remainder);
}
|
||||||
|
|
||||||
|
struct MklDnnThreadPool : public dnnl::threadpool_iface {
|
||||||
|
MklDnnThreadPool() = default;
|
||||||
|
|
||||||
|
MklDnnThreadPool(OpKernelContext* ctx)
|
||||||
|
: eigen_interface_(ctx->device()
|
||||||
|
->tensorflow_cpu_worker_threads()
|
||||||
|
->workers->AsEigenThreadPool()) {}
|
||||||
|
virtual int get_num_threads() override {
|
||||||
|
return eigen_interface_->NumThreads();
|
||||||
|
}
|
||||||
|
virtual bool get_in_parallel() override {
|
||||||
|
return (eigen_interface_->CurrentThreadId() != -1) ? true : false;
|
||||||
|
}
|
||||||
|
virtual uint64_t get_flags() override { return ASYNCHRONOUS; }
|
||||||
|
virtual void parallel_for(int n,
|
||||||
|
const std::function<void(int, int)>& fn) override {
|
||||||
|
// Should never happen (handled by DNNL)
|
||||||
|
if (n == 0) return;
|
||||||
|
|
||||||
|
// Should never happen (handled by DNNL)
|
||||||
|
if (n == 1) {
|
||||||
|
fn(0, 1);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
int nthr = get_num_threads();
|
||||||
|
int njobs = std::min(n, nthr);
|
||||||
|
for (int i = 0; i < njobs; i++) {
|
||||||
|
eigen_interface_->ScheduleWithHint(
|
||||||
|
[i, n, njobs, fn]() {
|
||||||
|
int start, end;
|
||||||
|
balance211(n, njobs, i, &start, &end);
|
||||||
|
for (int j = start; j < end; j++) fn(j, n);
|
||||||
|
},
|
||||||
|
i, i + 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
~MklDnnThreadPool() {}
|
||||||
|
|
||||||
|
private:
|
||||||
|
Eigen::ThreadPoolInterface* eigen_interface_ = nullptr;
|
||||||
|
};
|
||||||
|
|
||||||
|
class MklDnnThreadPoolWrapper {
|
||||||
|
public:
|
||||||
|
static MklDnnThreadPoolWrapper& GetInstance() {
|
||||||
|
static MklDnnThreadPoolWrapper instance_;
|
||||||
|
return instance_;
|
||||||
|
}
|
||||||
|
MklDnnThreadPool* CreateThreadPoolPtr(OpKernelContext* ctx) {
|
||||||
|
if (threadpool_map_.empty() ||
|
||||||
|
threadpool_map_.find(ctx->device()) == threadpool_map_.end()) {
|
||||||
|
auto tp_iface = new MklDnnThreadPool(ctx);
|
||||||
|
threadpool_map_.emplace(std::make_pair(ctx->device(), tp_iface));
|
||||||
|
return tp_iface;
|
||||||
|
} else {
|
||||||
|
auto entry = threadpool_map_.find(ctx->device());
|
||||||
|
return entry->second;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::unordered_map<DeviceBase*, MklDnnThreadPool*> threadpool_map_;
|
||||||
|
MklDnnThreadPoolWrapper() {}
|
||||||
|
MklDnnThreadPoolWrapper(const MklDnnThreadPoolWrapper&) = delete;
|
||||||
|
MklDnnThreadPoolWrapper& operator=(const MklDnnThreadPoolWrapper&) = delete;
|
||||||
|
~MklDnnThreadPoolWrapper() {
|
||||||
|
for (auto& tp : threadpool_map_) {
|
||||||
|
delete tp.second;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace tensorflow
|
||||||
|
#endif // ENABLE_MKLDNN_THREADPOOL
|
||||||
|
#endif // INTEL_MKL
|
||||||
|
#endif // TENSORFLOW_CORE_UTIL_MKL_THREADPOOL_H_
|
@ -36,6 +36,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/platform/logging.h"
|
#include "tensorflow/core/platform/logging.h"
|
||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
||||||
#include "tensorflow/core/util/env_var.h"
|
#include "tensorflow/core/util/env_var.h"
|
||||||
|
#include "tensorflow/core/util/mkl_threadpool.h"
|
||||||
#include "tensorflow/core/util/mkl_types.h"
|
#include "tensorflow/core/util/mkl_types.h"
|
||||||
#include "tensorflow/core/util/padding.h"
|
#include "tensorflow/core/util/padding.h"
|
||||||
#include "tensorflow/core/util/tensor_format.h"
|
#include "tensorflow/core/util/tensor_format.h"
|
||||||
@ -48,7 +49,6 @@ using mkldnn::padding_kind;
|
|||||||
using mkldnn::primitive;
|
using mkldnn::primitive;
|
||||||
using mkldnn::reorder;
|
using mkldnn::reorder;
|
||||||
using mkldnn::stream;
|
using mkldnn::stream;
|
||||||
|
|
||||||
using CPUDevice = Eigen::ThreadPoolDevice;
|
using CPUDevice = Eigen::ThreadPoolDevice;
|
||||||
using MemoryArgsMap = std::unordered_map<int, memory>;
|
using MemoryArgsMap = std::unordered_map<int, memory>;
|
||||||
using ReorderPd = mkldnn::reorder::primitive_desc;
|
using ReorderPd = mkldnn::reorder::primitive_desc;
|
||||||
@ -232,6 +232,27 @@ inline bool array_cmp(const T* a1, const T* a2, size_t size) {
|
|||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Creates a heap-allocated mkldnn stream on `engine` and returns it as a raw
// pointer; the caller is responsible for releasing it.
//
// When ENABLE_MKLDNN_THREADPOOL is defined and `ctx` is non-null, the stream
// is created with a stream attribute whose threadpool is the adapter obtained
// from MklDnnThreadPoolWrapper for `ctx`. In every other case a plain
// CPU_STREAM is created.
inline mkldnn::stream* CreateStream(OpKernelContext* ctx,
                                    const engine& engine) {
#ifdef ENABLE_MKLDNN_THREADPOOL
  stream_attr tp_stream_attr(ENGINE_CPU);
  // No kernel context means there is no thread pool to attach.
  if (ctx == nullptr) {
    return new CPU_STREAM(engine);
  }
  auto pool_adapter =
      MklDnnThreadPoolWrapper::GetInstance().CreateThreadPoolPtr(ctx);
  tp_stream_attr.set_threadpool(pool_adapter);
  return new stream(engine, stream::flags::default_flags, tp_stream_attr);
#else
  return new CPU_STREAM(engine);
#endif  // ENABLE_MKLDNN_THREADPOOL
}
|
||||||
|
|
||||||
class MklDnnShape {
|
class MklDnnShape {
|
||||||
private:
|
private:
|
||||||
typedef struct {
|
typedef struct {
|
||||||
@ -679,20 +700,21 @@ class MklDnnData;
|
|||||||
// TODO merge with the execute_primitives.
|
// TODO merge with the execute_primitives.
|
||||||
inline void ExecutePrimitive(const std::vector<primitive>& net,
|
inline void ExecutePrimitive(const std::vector<primitive>& net,
|
||||||
const std::vector<MemoryArgsMap>* net_args,
|
const std::vector<MemoryArgsMap>* net_args,
|
||||||
const engine& cpu_engine) {
|
const engine& cpu_engine,
|
||||||
|
OpKernelContext* context = nullptr) {
|
||||||
#ifdef ENABLE_MKLDNN_V1
|
#ifdef ENABLE_MKLDNN_V1
|
||||||
DCHECK(net_args);
|
DCHECK(net_args);
|
||||||
DCHECK_EQ(net.size(), net_args->size());
|
DCHECK_EQ(net.size(), net_args->size());
|
||||||
stream cpu_stream(cpu_engine);
|
stream* cpu_stream = CreateStream(context, cpu_engine);
|
||||||
for (size_t i = 0; i < net.size(); ++i) {
|
for (size_t i = 0; i < net.size(); ++i) {
|
||||||
net.at(i).execute(cpu_stream, net_args->at(i));
|
net.at(i).execute(*cpu_stream, net_args->at(i));
|
||||||
}
|
}
|
||||||
cpu_stream.wait();
|
cpu_stream->wait();
|
||||||
|
delete cpu_stream;
|
||||||
#else
|
#else
|
||||||
stream(stream::kind::eager_nostore).submit(net).wait();
|
stream(stream::kind::eager_nostore).submit(net).wait();
|
||||||
#endif // ENABLE_MKLDNN_V1
|
#endif // ENABLE_MKLDNN_V1
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
inline Status ConvertMklToTF(OpKernelContext* context,
|
inline Status ConvertMklToTF(OpKernelContext* context,
|
||||||
const Tensor& input_mkl_tensor,
|
const Tensor& input_mkl_tensor,
|
||||||
@ -731,7 +753,7 @@ inline Status ConvertMklToTF(OpKernelContext* context,
|
|||||||
return Status(error::Code::INTERNAL,
|
return Status(error::Code::INTERNAL,
|
||||||
"ConvertMklToTF(): Failed to create reorder for input");
|
"ConvertMklToTF(): Failed to create reorder for input");
|
||||||
}
|
}
|
||||||
ExecutePrimitive(net, NET_ARGS_PTR, cpu_engine);
|
ExecutePrimitive(net, NET_ARGS_PTR, cpu_engine, context);
|
||||||
} else {
|
} else {
|
||||||
// If not, just forward input tensor to output tensor.
|
// If not, just forward input tensor to output tensor.
|
||||||
bool status =
|
bool status =
|
||||||
@ -1301,8 +1323,8 @@ inline Status CreateBlockedMemDescHelper(const memory::dims& dim,
|
|||||||
|
|
||||||
inline void CreateAndExecuteReorder(const ReorderPd& reorder_desc,
|
inline void CreateAndExecuteReorder(const ReorderPd& reorder_desc,
|
||||||
const memory& src_mem,
|
const memory& src_mem,
|
||||||
const memory& dst_mem,
|
const memory& dst_mem, const engine& engine,
|
||||||
const engine& engine) {
|
OpKernelContext* ctx = nullptr) {
|
||||||
std::vector<primitive> net;
|
std::vector<primitive> net;
|
||||||
#ifdef ENABLE_MKLDNN_V1
|
#ifdef ENABLE_MKLDNN_V1
|
||||||
net.push_back(mkldnn::reorder(reorder_desc));
|
net.push_back(mkldnn::reorder(reorder_desc));
|
||||||
@ -1311,7 +1333,7 @@ inline void CreateAndExecuteReorder(const ReorderPd& reorder_desc,
|
|||||||
#else
|
#else
|
||||||
net.push_back(mkldnn::reorder(reorder_desc, src_mem, dst_mem));
|
net.push_back(mkldnn::reorder(reorder_desc, src_mem, dst_mem));
|
||||||
#endif // ENABLE_MKLDNN_V1
|
#endif // ENABLE_MKLDNN_V1
|
||||||
ExecutePrimitive(net, NET_ARGS_PTR, engine);
|
ExecutePrimitive(net, NET_ARGS_PTR, engine, ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
class MklReorderPrimitive;
|
class MklReorderPrimitive;
|
||||||
@ -1629,22 +1651,26 @@ class MklDnnData {
|
|||||||
|
|
||||||
#ifdef ENABLE_MKLDNN_V1
|
#ifdef ENABLE_MKLDNN_V1
|
||||||
inline bool CheckReorderToOpMem(const memory::desc& op_md,
|
inline bool CheckReorderToOpMem(const memory::desc& op_md,
|
||||||
const engine& engine) {
|
const engine& engine,
|
||||||
|
OpKernelContext* context = nullptr) {
|
||||||
DCHECK(user_memory_);
|
DCHECK(user_memory_);
|
||||||
if (IsReorderNeeded(op_md)) {
|
if (IsReorderNeeded(op_md)) {
|
||||||
// TODO(nhasabni): can we remove dynamic memory allocation?
|
// TODO(nhasabni): can we remove dynamic memory allocation?
|
||||||
// primitive reuse don't allow two same reorder prim in
|
// primitive reuse don't allow two same reorder prim in
|
||||||
// one stream, so submit it immediately
|
// one stream, so submit it immediately
|
||||||
reorder_memory_ = new memory(op_md, engine);
|
reorder_memory_ = new memory(op_md, engine);
|
||||||
std::vector<primitive> net;
|
|
||||||
auto* prim = FindOrCreateReorder<T>(user_memory_, reorder_memory_);
|
auto* prim = FindOrCreateReorder<T>(user_memory_, reorder_memory_);
|
||||||
|
std::shared_ptr<stream> cpu_stream;
|
||||||
|
cpu_stream.reset(CreateStream(context, prim->GetEngine()));
|
||||||
|
std::vector<primitive> net;
|
||||||
net.push_back(*(prim->GetPrimitive()));
|
net.push_back(*(prim->GetPrimitive()));
|
||||||
std::vector<MemoryArgsMap> net_args;
|
std::vector<MemoryArgsMap> net_args;
|
||||||
net_args.push_back({{MKLDNN_ARG_FROM, *user_memory_},
|
net_args.push_back({{MKLDNN_ARG_FROM, *user_memory_},
|
||||||
{MKLDNN_ARG_TO, *reorder_memory_}});
|
{MKLDNN_ARG_TO, *reorder_memory_}});
|
||||||
execute_primitives(net, prim->GetStream(), net_args);
|
execute_primitives(net, cpu_stream, net_args);
|
||||||
#else
|
#else
|
||||||
inline bool CheckReorderToOpMem(const memory::primitive_desc& op_pd) {
|
inline bool CheckReorderToOpMem(const memory::primitive_desc& op_pd,
|
||||||
|
OpKernelContext* ctx = nullptr) {
|
||||||
CHECK_NOTNULL(user_memory_);
|
CHECK_NOTNULL(user_memory_);
|
||||||
if (IsReorderNeeded(op_pd)) {
|
if (IsReorderNeeded(op_pd)) {
|
||||||
reorder_memory_ = new memory(op_pd);
|
reorder_memory_ = new memory(op_pd);
|
||||||
@ -1708,7 +1734,8 @@ class MklDnnData {
|
|||||||
/// TODO(bhavanis): Need to use reorder cache here for better performance.
|
/// TODO(bhavanis): Need to use reorder cache here for better performance.
|
||||||
inline bool CheckReorderToOpMem(const memory::desc& op_md,
|
inline bool CheckReorderToOpMem(const memory::desc& op_md,
|
||||||
void* reorder_data_handle,
|
void* reorder_data_handle,
|
||||||
const engine& engine) {
|
const engine& engine,
|
||||||
|
OpKernelContext* context = nullptr) {
|
||||||
DCHECK(reorder_data_handle);
|
DCHECK(reorder_data_handle);
|
||||||
DCHECK(user_memory_);
|
DCHECK(user_memory_);
|
||||||
if (IsReorderNeeded(op_md)) {
|
if (IsReorderNeeded(op_md)) {
|
||||||
@ -1716,16 +1743,19 @@ class MklDnnData {
|
|||||||
// primitive reuse don't allow two same reorder prim in
|
// primitive reuse don't allow two same reorder prim in
|
||||||
// one stream, so submit it immediately
|
// one stream, so submit it immediately
|
||||||
reorder_memory_ = new memory(op_md, engine, reorder_data_handle);
|
reorder_memory_ = new memory(op_md, engine, reorder_data_handle);
|
||||||
std::vector<primitive> net;
|
|
||||||
auto* prim = FindOrCreateReorder<T>(user_memory_, reorder_memory_);
|
auto* prim = FindOrCreateReorder<T>(user_memory_, reorder_memory_);
|
||||||
|
std::shared_ptr<stream> cpu_stream;
|
||||||
|
cpu_stream.reset(CreateStream(context, prim->GetEngine()));
|
||||||
|
std::vector<primitive> net;
|
||||||
net.push_back(*(prim->GetPrimitive()));
|
net.push_back(*(prim->GetPrimitive()));
|
||||||
std::vector<MemoryArgsMap> net_args;
|
std::vector<MemoryArgsMap> net_args;
|
||||||
net_args.push_back({{MKLDNN_ARG_FROM, *user_memory_},
|
net_args.push_back({{MKLDNN_ARG_FROM, *user_memory_},
|
||||||
{MKLDNN_ARG_TO, *reorder_memory_}});
|
{MKLDNN_ARG_TO, *reorder_memory_}});
|
||||||
execute_primitives(net, prim->GetStream(), net_args);
|
execute_primitives(net, cpu_stream, net_args);
|
||||||
#else
|
#else
|
||||||
inline bool CheckReorderToOpMem(const memory::primitive_desc& op_pd,
|
inline bool CheckReorderToOpMem(const memory::primitive_desc& op_pd,
|
||||||
void* reorder_data_handle) {
|
void* reorder_data_handle,
|
||||||
|
OpKernelContext* context = nullptr) {
|
||||||
CHECK_NOTNULL(reorder_data_handle);
|
CHECK_NOTNULL(reorder_data_handle);
|
||||||
CHECK_NOTNULL(user_memory_);
|
CHECK_NOTNULL(user_memory_);
|
||||||
if (IsReorderNeeded(op_pd)) {
|
if (IsReorderNeeded(op_pd)) {
|
||||||
@ -1778,13 +1808,14 @@ class MklDnnData {
|
|||||||
/// remove
|
/// remove
|
||||||
/// slow path in the future
|
/// slow path in the future
|
||||||
inline bool CheckReorderToOpMem(const MEMORY_PRIMITIVE_DESC& op_pd,
|
inline bool CheckReorderToOpMem(const MEMORY_PRIMITIVE_DESC& op_pd,
|
||||||
Tensor* reorder_tensor) {
|
Tensor* reorder_tensor,
|
||||||
|
OpKernelContext* ctx = nullptr) {
|
||||||
DCHECK(reorder_tensor);
|
DCHECK(reorder_tensor);
|
||||||
#ifdef ENABLE_MKLDNN_V1
|
#ifdef ENABLE_MKLDNN_V1
|
||||||
return CheckReorderToOpMem(op_pd, GetTensorBuffer(reorder_tensor),
|
return CheckReorderToOpMem(op_pd, GetTensorBuffer(reorder_tensor),
|
||||||
*cpu_engine_);
|
*cpu_engine_, ctx);
|
||||||
#else
|
#else
|
||||||
return CheckReorderToOpMem(op_pd, GetTensorBuffer(reorder_tensor));
|
return CheckReorderToOpMem(op_pd, GetTensorBuffer(reorder_tensor), ctx);
|
||||||
#endif // ENABLE_MKLDNN_V1
|
#endif // ENABLE_MKLDNN_V1
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1843,7 +1874,7 @@ class MklDnnData {
|
|||||||
/// TODO: this is a faster path with reorder primitive cache compared with
|
/// TODO: this is a faster path with reorder primitive cache compared with
|
||||||
/// InsertReorderToUserMem(net, net_args), will remove
|
/// InsertReorderToUserMem(net, net_args), will remove
|
||||||
/// slow path in the future
|
/// slow path in the future
|
||||||
inline void InsertReorderToUserMem() {
|
inline void InsertReorderToUserMem(OpKernelContext* ctx = nullptr) {
|
||||||
DCHECK(user_memory_);
|
DCHECK(user_memory_);
|
||||||
DCHECK(reorder_memory_);
|
DCHECK(reorder_memory_);
|
||||||
DCHECK(cpu_engine_);
|
DCHECK(cpu_engine_);
|
||||||
@ -1857,8 +1888,8 @@ class MklDnnData {
|
|||||||
net_args.push_back(
|
net_args.push_back(
|
||||||
{{MKLDNN_ARG_FROM, *reorder_memory_}, {MKLDNN_ARG_TO, *user_memory_}});
|
{{MKLDNN_ARG_FROM, *reorder_memory_}, {MKLDNN_ARG_TO, *user_memory_}});
|
||||||
std::shared_ptr<stream> cpu_stream;
|
std::shared_ptr<stream> cpu_stream;
|
||||||
cpu_stream.reset(new stream(*cpu_engine_));
|
cpu_stream.reset(CreateStream(ctx, prim->GetEngine()));
|
||||||
execute_primitives(net, prim->GetStream(), net_args);
|
execute_primitives(net, cpu_stream, net_args);
|
||||||
#else
|
#else
|
||||||
net.push_back(FindOrCreateReorder<T>(reorder_memory_, user_memory_));
|
net.push_back(FindOrCreateReorder<T>(reorder_memory_, user_memory_));
|
||||||
ExecutePrimitive(net, NET_ARGS_PTR, *cpu_engine_);
|
ExecutePrimitive(net, NET_ARGS_PTR, *cpu_engine_);
|
||||||
@ -1870,9 +1901,12 @@ class MklDnnData {
|
|||||||
class MklPrimitive {
|
class MklPrimitive {
|
||||||
public:
|
public:
|
||||||
virtual ~MklPrimitive() {}
|
virtual ~MklPrimitive() {}
|
||||||
|
MklPrimitive() {}
|
||||||
|
MklPrimitive(const engine& cpu_engine) { cpu_engine_ = cpu_engine; }
|
||||||
// Dummy data which MKL DNN never operates on
|
// Dummy data which MKL DNN never operates on
|
||||||
unsigned char* DummyData = nullptr;
|
unsigned char* DummyData = nullptr;
|
||||||
|
engine cpu_engine_ = engine(ENGINE_CPU, 0);
|
||||||
|
const engine& GetEngine() { return cpu_engine_; }
|
||||||
};
|
};
|
||||||
|
|
||||||
const mkldnn::memory::dims NONE_DIMS = {};
|
const mkldnn::memory::dims NONE_DIMS = {};
|
||||||
@ -2058,7 +2092,8 @@ class FactoryKeyCreator {
|
|||||||
|
|
||||||
class MklReorderPrimitive : public MklPrimitive {
|
class MklReorderPrimitive : public MklPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit MklReorderPrimitive(const memory* from, const memory* to) {
|
explicit MklReorderPrimitive(const memory* from, const memory* to)
|
||||||
|
: MklPrimitive(engine(ENGINE_CPU, 0)) {
|
||||||
Setup(from, to);
|
Setup(from, to);
|
||||||
}
|
}
|
||||||
~MklReorderPrimitive() {}
|
~MklReorderPrimitive() {}
|
||||||
@ -2081,7 +2116,6 @@ class MklReorderPrimitive : public MklPrimitive {
|
|||||||
: src_mem(nullptr), dst_mem(nullptr), reorder_prim(nullptr) {}
|
: src_mem(nullptr), dst_mem(nullptr), reorder_prim(nullptr) {}
|
||||||
} context_;
|
} context_;
|
||||||
|
|
||||||
engine cpu_engine_ = engine(ENGINE_CPU, 0);
|
|
||||||
std::shared_ptr<mkldnn::stream> stream_;
|
std::shared_ptr<mkldnn::stream> stream_;
|
||||||
|
|
||||||
void Setup(const memory* from, const memory* to) {
|
void Setup(const memory* from, const memory* to) {
|
||||||
|
Loading…
Reference in New Issue
Block a user