[Intel MKL] Adding DNNL ops (part 1) supporting threadpool work

This commit is contained in:
sshiddib 2020-06-15 13:14:03 -07:00 committed by Sharada Shiddibhavi
parent 7c694e10b5
commit 2fe65568b1
9 changed files with 143 additions and 9 deletions

View File

@ -114,6 +114,21 @@ class MklConvBwdFilterPrimitive : public MklPrimitive {
void Execute(const T* src_data, const T* diff_filter_data,
const T* diff_bias_data, const T* diff_dst_data,
std::shared_ptr<stream> bwd_filter_stream) {
// TODO: Factor the set_data_handle calls below into a common helper so the
// threadpool and non-threadpool branches do not duplicate each other.
#ifdef ENABLE_MKLDNN_THREADPOOL
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)), *bwd_filter_stream);
context_.diff_filter_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(diff_filter_data)),
*bwd_filter_stream);
if (diff_bias_data != nullptr) {
context_.diff_bias_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(diff_bias_data)),
*bwd_filter_stream);
}
context_.diff_dst_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(diff_dst_data)), *bwd_filter_stream);
#else
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)));
context_.diff_filter_mem->set_data_handle(
@ -124,7 +139,7 @@ class MklConvBwdFilterPrimitive : public MklPrimitive {
}
context_.diff_dst_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(diff_dst_data)));
#endif // ENABLE_MKLDNN_THREADPOOL
#ifdef ENABLE_MKLDNN_V1
execute_primitives(context_.bwd_filter_primitives, bwd_filter_stream,
context_.bwd_filter_primitives_args);

View File

@ -116,13 +116,22 @@ class MklConvBwdInputPrimitive : public MklPrimitive {
void Execute(const T* diff_src_data, const T* filter_data,
const T* diff_dst_data,
std::shared_ptr<stream> bwd_input_stream) {
// TODO: Factor the set_data_handle calls below into a common helper so the
// threadpool and non-threadpool branches do not duplicate each other.
#ifdef ENABLE_MKLDNN_THREADPOOL
context_.diff_src_mem->set_data_handle(
static_cast<T*>(const_cast<T*>(diff_src_data)), *bwd_input_stream);
context_.filter_mem->set_data_handle(
static_cast<T*>(const_cast<T*>(filter_data)), *bwd_input_stream);
context_.diff_dst_mem->set_data_handle(
static_cast<T*>(const_cast<T*>(diff_dst_data)), *bwd_input_stream);
#else
context_.diff_src_mem->set_data_handle(
static_cast<T*>(const_cast<T*>(diff_src_data)));
context_.filter_mem->set_data_handle(
static_cast<T*>(const_cast<T*>(filter_data)));
context_.diff_dst_mem->set_data_handle(
static_cast<T*>(const_cast<T*>(diff_dst_data)));
#endif // ENABLE_MKLDNN_THREADPOOL
#ifdef ENABLE_MKLDNN_V1
execute_primitives(context_.bwd_input_primitives, bwd_input_stream,
context_.bwd_input_primitives_args);

View File

@ -110,6 +110,19 @@ class MklConvFwdPrimitive : public MklPrimitive {
void Execute(const Tinput* src_data, const Tfilter* filter_data,
const Tbias* bias_data, const Toutput* dst_data,
std::shared_ptr<stream> fwd_stream) {
// TODO: Factor the set_data_handle calls below into a common helper so the
// threadpool and non-threadpool branches do not duplicate each other.
#ifdef ENABLE_MKLDNN_THREADPOOL
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<Tinput*>(src_data)), *fwd_stream);
context_.filter_mem->set_data_handle(
static_cast<void*>(const_cast<Tfilter*>(filter_data)), *fwd_stream);
if (bias_data != nullptr) {
context_.bias_mem->set_data_handle(
static_cast<void*>(const_cast<Tbias*>(bias_data)), *fwd_stream);
}
context_.dst_mem->set_data_handle(
static_cast<void*>(const_cast<Toutput*>(dst_data)), *fwd_stream);
#else
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<Tinput*>(src_data)));
context_.filter_mem->set_data_handle(
@ -120,6 +133,7 @@ class MklConvFwdPrimitive : public MklPrimitive {
}
context_.dst_mem->set_data_handle(
static_cast<void*>(const_cast<Toutput*>(dst_data)));
#endif // ENABLE_MKLDNN_THREADPOOL
#ifdef ENABLE_MKLDNN_V1
DCHECK_EQ(context_.fwd_primitives.size(),
context_.fwd_primitives_args.size());

View File

@ -94,6 +94,28 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
void Execute(const T* src_data, const U* weights_data, T* dst_data,
U* mean_data, U* variance_data,
std::shared_ptr<stream> fwd_stream, U* workspace_data) {
// TODO: Factor the set_data_handle calls below into a common helper so the
// threadpool and non-threadpool branches do not duplicate each other.
#ifdef ENABLE_MKLDNN_THREADPOOL
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)), *fwd_stream);
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data),
*fwd_stream);
if (IS_SET(use_scale_shift))
context_.weights_mem->set_data_handle(
static_cast<void*>(const_cast<U*>(weights_data)), *fwd_stream);
if ((context_.pkind == prop_kind::forward_training) ||
(IS_SET(use_global_stats))) {
context_.mean_mem->set_data_handle(static_cast<void*>(mean_data),
*fwd_stream);
context_.variance_mem->set_data_handle(static_cast<void*>(variance_data),
*fwd_stream);
}
if (workspace_data != nullptr) {
context_.ws_mem->set_data_handle(workspace_data, *fwd_stream);
}
#else
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)));
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
@ -110,6 +132,7 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
if (workspace_data != nullptr) {
context_.ws_mem->set_data_handle(workspace_data);
}
#endif // ENABLE_MKLDNN_THREADPOOL
#ifdef ENABLE_MKLDNN_V1
// Execute batch-normalization forward primitives.
execute_primitives(context_.fwd_primitives, fwd_stream, context_.net_args);
@ -503,6 +526,27 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive {
const T* diff_dst_data, const U* weights_data, T* diff_src_data,
U* diff_weights_data, U* res_space_data,
std::shared_ptr<stream> bwd_stream) {
// TODO: Factor the set_data_handle calls below into a common helper so the
// threadpool and non-threadpool branches do not duplicate each other.
#ifdef ENABLE_MKLDNN_THREADPOOL
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)), *bwd_stream);
context_.mean_mem->set_data_handle(
static_cast<void*>(const_cast<U*>(mean_data)), *bwd_stream);
context_.variance_mem->set_data_handle(
static_cast<void*>(const_cast<U*>(variance_data)), *bwd_stream);
context_.diff_dst_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(diff_dst_data)), *bwd_stream);
if (IS_SET(use_scale_shift)) {
context_.weights_mem->set_data_handle(
static_cast<void*>(const_cast<U*>(weights_data)), *bwd_stream);
context_.diff_weights_mem->set_data_handle(
static_cast<void*>(diff_weights_data), *bwd_stream);
}
context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data),
*bwd_stream);
#else
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)));
context_.mean_mem->set_data_handle(
@ -520,7 +564,7 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive {
}
context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data));
#endif // ENABLE_MKLDNN_THREADPOOL
#ifdef ENABLE_MKLDNN_V1
// Execute backward batch-normalization primitives.
DCHECK_EQ(context_.bwd_primitives.size(), context_.net_args.size());

View File

@ -127,6 +127,17 @@ template <typename T>
void MklPoolingFwdPrimitive<T>::Execute(const T* src_data, T* dst_data,
void* ws_data,
std::shared_ptr<stream> fwd_stream) {
#ifdef ENABLE_MKLDNN_THREADPOOL
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)), *fwd_stream);
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data), *fwd_stream);
if (context_.alg_kind == ALGORITHM::pooling_max &&
context_.prop_kind ==
prop_kind::forward_training) { // Max pooling must have workspace.
DCHECK(ws_data != nullptr);
context_.ws_mem->set_data_handle(ws_data, *fwd_stream);
}
#else
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)));
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
@ -136,7 +147,7 @@ void MklPoolingFwdPrimitive<T>::Execute(const T* src_data, T* dst_data,
DCHECK(ws_data != nullptr);
context_.ws_mem->set_data_handle(ws_data);
}
#endif // ENABLE_MKLDNN_THREADPOOL
#ifdef ENABLE_MKLDNN_V1
execute_primitives(context_.fwd_primitives, fwd_stream, context_.net_args);
#else
@ -269,6 +280,16 @@ template <typename T>
void MklPoolingBwdPrimitive<T>::Execute(const T* diff_dst_data,
T* diff_src_data, const void* ws_data,
std::shared_ptr<stream> bwd_stream) {
#ifdef ENABLE_MKLDNN_THREADPOOL
context_.diff_dst_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(diff_dst_data)), *bwd_stream);
context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data),
*bwd_stream);
if (context_.alg_kind == ALGORITHM::pooling_max) {
DCHECK(ws_data != nullptr);
context_.ws_mem->set_data_handle(const_cast<void*>(ws_data), *bwd_stream);
}
#else
context_.diff_dst_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(diff_dst_data)));
context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data));
@ -276,7 +297,7 @@ void MklPoolingBwdPrimitive<T>::Execute(const T* diff_dst_data,
DCHECK(ws_data != nullptr);
context_.ws_mem->set_data_handle(const_cast<void*>(ws_data));
}
#endif // ENABLE_MKLDNN_THREADPOOL
#ifdef ENABLE_MKLDNN_V1
execute_primitives(context_.bwd_primitives, bwd_stream, context_.net_args);
#else

View File

@ -88,8 +88,13 @@ class MklReorderWithScalePrimitive : public MklPrimitive {
void Execute(void* src_data, void* dst_data,
std::shared_ptr<stream> reorder_stream) {
#ifdef ENABLE_MKLDNN_THREADPOOL
context_.src_mem->set_data_handle(src_data, *reorder_stream);
context_.dst_mem->set_data_handle(dst_data, *reorder_stream);
#else
context_.src_mem->set_data_handle(src_data);
context_.dst_mem->set_data_handle(dst_data);
#endif // ENABLE_MKLDNN_THREADPOOL
#ifndef ENABLE_MKLDNN_V1
reorder_stream->submit(context_.net);
#else

View File

@ -79,10 +79,16 @@ class MklEltwiseFwdPrimitive : public MklPrimitive {
// dst_data: output data buffer of dst
void Execute(const T* src_data, T* dst_data,
std::shared_ptr<stream> fwd_stream) {
#ifdef ENABLE_MKLDNN_THREADPOOL
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)), *fwd_stream);
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data),
*fwd_stream);
#else
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)));
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
#endif // ENABLE_MKLDNN_THREADPOOL
#ifdef ENABLE_MKLDNN_V1
DCHECK_EQ(context_.fwd_primitives.size(),
context_.fwd_primitives_args.size());
@ -293,12 +299,20 @@ class MklEltwiseBwdPrimitive : public MklPrimitive {
// diff_src_data: output data buffer of diff_src
void Execute(const T* src_data, const T* diff_dst_data, T* diff_src_data,
std::shared_ptr<stream> bwd_stream) {
#ifdef ENABLE_MKLDNN_THREADPOOL
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)), *bwd_stream);
context_.diff_dst_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(diff_dst_data)), *bwd_stream);
context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data),
*bwd_stream);
#else
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)));
context_.diff_dst_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(diff_dst_data)));
context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data));
#endif // ENABLE_MKLDNN_THREADPOOL
#ifdef ENABLE_MKLDNN_V1
DCHECK_EQ(context_.bwd_primitives.size(),
context_.bwd_primitives_args.size());

View File

@ -189,9 +189,15 @@ class MklSlicePrimitive : public MklPrimitive {
void Execute(const MklSliceParams& sliceParams,
std::shared_ptr<stream> slice_stream) {
#ifdef ENABLE_MKLDNN_THREADPOOL
context_.src_mem->set_data_handle(sliceParams.from->get_data_handle(),
*slice_stream);
context_.dst_mem->set_data_handle(sliceParams.to->get_data_handle(),
*slice_stream);
#else
context_.src_mem->set_data_handle(sliceParams.from->get_data_handle());
context_.dst_mem->set_data_handle(sliceParams.to->get_data_handle());
#endif // ENABLE_MKLDNN_THREADPOOL
#ifdef ENABLE_MKLDNN_V1
execute_primitives(context_.slice_primitives, slice_stream,
context_.slice_primitives_args);

View File

@ -59,10 +59,16 @@ class MklSoftmaxPrimitive : public MklPrimitive {
// dst_data: output data buffer of dst
void Execute(const T* src_data, T* dst_data,
std::shared_ptr<stream> fwd_cpu_stream) {
#ifdef ENABLE_MKLDNN_THREADPOOL
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)), *fwd_cpu_stream);
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data),
*fwd_cpu_stream);
#else
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)));
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
#endif // ENABLE_MKLDNN_THREADPOOL
#ifdef ENABLE_MKLDNN_V1
DCHECK_EQ(context_.fwd_primitives.size(), context_.fwd_net_args.size());
execute_primitives(context_.fwd_primitives, fwd_cpu_stream,