[Intel MKL] Adding DNNL ops (part 1) supporting threadpool work
This commit is contained in:
parent
7c694e10b5
commit
2fe65568b1
@ -114,6 +114,21 @@ class MklConvBwdFilterPrimitive : public MklPrimitive {
|
||||
void Execute(const T* src_data, const T* diff_filter_data,
|
||||
const T* diff_bias_data, const T* diff_dst_data,
|
||||
std::shared_ptr<stream> bwd_filter_stream) {
|
||||
// TODO: Create a common function and avoid the duplicate code
|
||||
#ifdef ENABLE_MKLDNN_THREADPOOL
|
||||
context_.src_mem->set_data_handle(
|
||||
static_cast<void*>(const_cast<T*>(src_data)), *bwd_filter_stream);
|
||||
context_.diff_filter_mem->set_data_handle(
|
||||
static_cast<void*>(const_cast<T*>(diff_filter_data)),
|
||||
*bwd_filter_stream);
|
||||
if (diff_bias_data != nullptr) {
|
||||
context_.diff_bias_mem->set_data_handle(
|
||||
static_cast<void*>(const_cast<T*>(diff_bias_data)),
|
||||
*bwd_filter_stream);
|
||||
}
|
||||
context_.diff_dst_mem->set_data_handle(
|
||||
static_cast<void*>(const_cast<T*>(diff_dst_data)), *bwd_filter_stream);
|
||||
#else
|
||||
context_.src_mem->set_data_handle(
|
||||
static_cast<void*>(const_cast<T*>(src_data)));
|
||||
context_.diff_filter_mem->set_data_handle(
|
||||
@ -124,7 +139,7 @@ class MklConvBwdFilterPrimitive : public MklPrimitive {
|
||||
}
|
||||
context_.diff_dst_mem->set_data_handle(
|
||||
static_cast<void*>(const_cast<T*>(diff_dst_data)));
|
||||
|
||||
#endif // ENABLE_MKLDNN_THREADPOOL
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
execute_primitives(context_.bwd_filter_primitives, bwd_filter_stream,
|
||||
context_.bwd_filter_primitives_args);
|
||||
|
@ -116,13 +116,22 @@ class MklConvBwdInputPrimitive : public MklPrimitive {
|
||||
void Execute(const T* diff_src_data, const T* filter_data,
|
||||
const T* diff_dst_data,
|
||||
std::shared_ptr<stream> bwd_input_stream) {
|
||||
// TODO: Create a common function and avoid the duplicate code
|
||||
#ifdef ENABLE_MKLDNN_THREADPOOL
|
||||
context_.diff_src_mem->set_data_handle(
|
||||
static_cast<T*>(const_cast<T*>(diff_src_data)), *bwd_input_stream);
|
||||
context_.filter_mem->set_data_handle(
|
||||
static_cast<T*>(const_cast<T*>(filter_data)), *bwd_input_stream);
|
||||
context_.diff_dst_mem->set_data_handle(
|
||||
static_cast<T*>(const_cast<T*>(diff_dst_data)), *bwd_input_stream);
|
||||
#else
|
||||
context_.diff_src_mem->set_data_handle(
|
||||
static_cast<T*>(const_cast<T*>(diff_src_data)));
|
||||
context_.filter_mem->set_data_handle(
|
||||
static_cast<T*>(const_cast<T*>(filter_data)));
|
||||
context_.diff_dst_mem->set_data_handle(
|
||||
static_cast<T*>(const_cast<T*>(diff_dst_data)));
|
||||
|
||||
#endif // ENABLE_MKLDNN_THREADPOOL
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
execute_primitives(context_.bwd_input_primitives, bwd_input_stream,
|
||||
context_.bwd_input_primitives_args);
|
||||
|
@ -110,6 +110,19 @@ class MklConvFwdPrimitive : public MklPrimitive {
|
||||
void Execute(const Tinput* src_data, const Tfilter* filter_data,
|
||||
const Tbias* bias_data, const Toutput* dst_data,
|
||||
std::shared_ptr<stream> fwd_stream) {
|
||||
// TODO: Create a common function and avoid the duplicate code
|
||||
#ifdef ENABLE_MKLDNN_THREADPOOL
|
||||
context_.src_mem->set_data_handle(
|
||||
static_cast<void*>(const_cast<Tinput*>(src_data)), *fwd_stream);
|
||||
context_.filter_mem->set_data_handle(
|
||||
static_cast<void*>(const_cast<Tfilter*>(filter_data)), *fwd_stream);
|
||||
if (bias_data != nullptr) {
|
||||
context_.bias_mem->set_data_handle(
|
||||
static_cast<void*>(const_cast<Tbias*>(bias_data)), *fwd_stream);
|
||||
}
|
||||
context_.dst_mem->set_data_handle(
|
||||
static_cast<void*>(const_cast<Toutput*>(dst_data)), *fwd_stream);
|
||||
#else
|
||||
context_.src_mem->set_data_handle(
|
||||
static_cast<void*>(const_cast<Tinput*>(src_data)));
|
||||
context_.filter_mem->set_data_handle(
|
||||
@ -120,6 +133,7 @@ class MklConvFwdPrimitive : public MklPrimitive {
|
||||
}
|
||||
context_.dst_mem->set_data_handle(
|
||||
static_cast<void*>(const_cast<Toutput*>(dst_data)));
|
||||
#endif // ENABLE_MKLDNN_THREADPOOL
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
DCHECK_EQ(context_.fwd_primitives.size(),
|
||||
context_.fwd_primitives_args.size());
|
||||
|
@ -94,6 +94,28 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
|
||||
void Execute(const T* src_data, const U* weights_data, T* dst_data,
|
||||
U* mean_data, U* variance_data,
|
||||
std::shared_ptr<stream> fwd_stream, U* workspace_data) {
|
||||
// TODO: Create a common function and avoid the duplicate code
|
||||
#ifdef ENABLE_MKLDNN_THREADPOOL
|
||||
context_.src_mem->set_data_handle(
|
||||
static_cast<void*>(const_cast<T*>(src_data)), *fwd_stream);
|
||||
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data),
|
||||
*fwd_stream);
|
||||
|
||||
if (IS_SET(use_scale_shift))
|
||||
context_.weights_mem->set_data_handle(
|
||||
static_cast<void*>(const_cast<U*>(weights_data)), *fwd_stream);
|
||||
|
||||
if ((context_.pkind == prop_kind::forward_training) ||
|
||||
(IS_SET(use_global_stats))) {
|
||||
context_.mean_mem->set_data_handle(static_cast<void*>(mean_data),
|
||||
*fwd_stream);
|
||||
context_.variance_mem->set_data_handle(static_cast<void*>(variance_data),
|
||||
*fwd_stream);
|
||||
}
|
||||
if (workspace_data != nullptr) {
|
||||
context_.ws_mem->set_data_handle(workspace_data, *fwd_stream);
|
||||
}
|
||||
#else
|
||||
context_.src_mem->set_data_handle(
|
||||
static_cast<void*>(const_cast<T*>(src_data)));
|
||||
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
|
||||
@ -110,6 +132,7 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive {
|
||||
if (workspace_data != nullptr) {
|
||||
context_.ws_mem->set_data_handle(workspace_data);
|
||||
}
|
||||
#endif // ENABLE_MKLDNN_THREADPOOL
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
// Execute batch-normalization forward primitives.
|
||||
execute_primitives(context_.fwd_primitives, fwd_stream, context_.net_args);
|
||||
@ -503,6 +526,27 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive {
|
||||
const T* diff_dst_data, const U* weights_data, T* diff_src_data,
|
||||
U* diff_weights_data, U* res_space_data,
|
||||
std::shared_ptr<stream> bwd_stream) {
|
||||
// TODO: Create a common function and avoid the duplicate code
|
||||
#ifdef ENABLE_MKLDNN_THREADPOOL
|
||||
context_.src_mem->set_data_handle(
|
||||
static_cast<void*>(const_cast<T*>(src_data)), *bwd_stream);
|
||||
context_.mean_mem->set_data_handle(
|
||||
static_cast<void*>(const_cast<U*>(mean_data)), *bwd_stream);
|
||||
context_.variance_mem->set_data_handle(
|
||||
static_cast<void*>(const_cast<U*>(variance_data)), *bwd_stream);
|
||||
context_.diff_dst_mem->set_data_handle(
|
||||
static_cast<void*>(const_cast<T*>(diff_dst_data)), *bwd_stream);
|
||||
|
||||
if (IS_SET(use_scale_shift)) {
|
||||
context_.weights_mem->set_data_handle(
|
||||
static_cast<void*>(const_cast<U*>(weights_data)), *bwd_stream);
|
||||
context_.diff_weights_mem->set_data_handle(
|
||||
static_cast<void*>(diff_weights_data), *bwd_stream);
|
||||
}
|
||||
|
||||
context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data),
|
||||
*bwd_stream);
|
||||
#else
|
||||
context_.src_mem->set_data_handle(
|
||||
static_cast<void*>(const_cast<T*>(src_data)));
|
||||
context_.mean_mem->set_data_handle(
|
||||
@ -520,7 +564,7 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive {
|
||||
}
|
||||
|
||||
context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data));
|
||||
|
||||
#endif // ENABLE_MKLDNN_THREADPOOL
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
// Execute backward batch-normalization primitives.
|
||||
DCHECK_EQ(context_.bwd_primitives.size(), context_.net_args.size());
|
||||
|
@ -127,6 +127,17 @@ template <typename T>
|
||||
void MklPoolingFwdPrimitive<T>::Execute(const T* src_data, T* dst_data,
|
||||
void* ws_data,
|
||||
std::shared_ptr<stream> fwd_stream) {
|
||||
#ifdef ENABLE_MKLDNN_THREADPOOL
|
||||
context_.src_mem->set_data_handle(
|
||||
static_cast<void*>(const_cast<T*>(src_data)), *fwd_stream);
|
||||
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data), *fwd_stream);
|
||||
if (context_.alg_kind == ALGORITHM::pooling_max &&
|
||||
context_.prop_kind ==
|
||||
prop_kind::forward_training) { // Max pooling must have workspace.
|
||||
DCHECK(ws_data != nullptr);
|
||||
context_.ws_mem->set_data_handle(ws_data, *fwd_stream);
|
||||
}
|
||||
#else
|
||||
context_.src_mem->set_data_handle(
|
||||
static_cast<void*>(const_cast<T*>(src_data)));
|
||||
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
|
||||
@ -136,7 +147,7 @@ void MklPoolingFwdPrimitive<T>::Execute(const T* src_data, T* dst_data,
|
||||
DCHECK(ws_data != nullptr);
|
||||
context_.ws_mem->set_data_handle(ws_data);
|
||||
}
|
||||
|
||||
#endif // ENABLE_MKLDNN_THREADPOOL
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
execute_primitives(context_.fwd_primitives, fwd_stream, context_.net_args);
|
||||
#else
|
||||
@ -269,6 +280,16 @@ template <typename T>
|
||||
void MklPoolingBwdPrimitive<T>::Execute(const T* diff_dst_data,
|
||||
T* diff_src_data, const void* ws_data,
|
||||
std::shared_ptr<stream> bwd_stream) {
|
||||
#ifdef ENABLE_MKLDNN_THREADPOOL
|
||||
context_.diff_dst_mem->set_data_handle(
|
||||
static_cast<void*>(const_cast<T*>(diff_dst_data)), *bwd_stream);
|
||||
context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data),
|
||||
*bwd_stream);
|
||||
if (context_.alg_kind == ALGORITHM::pooling_max) {
|
||||
DCHECK(ws_data != nullptr);
|
||||
context_.ws_mem->set_data_handle(const_cast<void*>(ws_data), *bwd_stream);
|
||||
}
|
||||
#else
|
||||
context_.diff_dst_mem->set_data_handle(
|
||||
static_cast<void*>(const_cast<T*>(diff_dst_data)));
|
||||
context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data));
|
||||
@ -276,7 +297,7 @@ void MklPoolingBwdPrimitive<T>::Execute(const T* diff_dst_data,
|
||||
DCHECK(ws_data != nullptr);
|
||||
context_.ws_mem->set_data_handle(const_cast<void*>(ws_data));
|
||||
}
|
||||
|
||||
#endif // ENABLE_MKLDNN_THREADPOOL
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
execute_primitives(context_.bwd_primitives, bwd_stream, context_.net_args);
|
||||
#else
|
||||
|
@ -88,8 +88,13 @@ class MklReorderWithScalePrimitive : public MklPrimitive {
|
||||
|
||||
void Execute(void* src_data, void* dst_data,
|
||||
std::shared_ptr<stream> reorder_stream) {
|
||||
#ifdef ENABLE_MKLDNN_THREADPOOL
|
||||
context_.src_mem->set_data_handle(src_data, *reorder_stream);
|
||||
context_.dst_mem->set_data_handle(dst_data, *reorder_stream);
|
||||
#else
|
||||
context_.src_mem->set_data_handle(src_data);
|
||||
context_.dst_mem->set_data_handle(dst_data);
|
||||
#endif // ENABLE_MKLDNN_THREADPOOL
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
reorder_stream->submit(context_.net);
|
||||
#else
|
||||
|
@ -79,10 +79,16 @@ class MklEltwiseFwdPrimitive : public MklPrimitive {
|
||||
// dst_data: output data buffer of dst
|
||||
void Execute(const T* src_data, T* dst_data,
|
||||
std::shared_ptr<stream> fwd_stream) {
|
||||
#ifdef ENABLE_MKLDNN_THREADPOOL
|
||||
context_.src_mem->set_data_handle(
|
||||
static_cast<void*>(const_cast<T*>(src_data)), *fwd_stream);
|
||||
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data),
|
||||
*fwd_stream);
|
||||
#else
|
||||
context_.src_mem->set_data_handle(
|
||||
static_cast<void*>(const_cast<T*>(src_data)));
|
||||
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
|
||||
|
||||
#endif // ENABLE_MKLDNN_THREADPOOL
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
DCHECK_EQ(context_.fwd_primitives.size(),
|
||||
context_.fwd_primitives_args.size());
|
||||
@ -293,12 +299,20 @@ class MklEltwiseBwdPrimitive : public MklPrimitive {
|
||||
// diff_src_data: output data buffer of diff_src
|
||||
void Execute(const T* src_data, const T* diff_dst_data, T* diff_src_data,
|
||||
std::shared_ptr<stream> bwd_stream) {
|
||||
#ifdef ENABLE_MKLDNN_THREADPOOL
|
||||
context_.src_mem->set_data_handle(
|
||||
static_cast<void*>(const_cast<T*>(src_data)), *bwd_stream);
|
||||
context_.diff_dst_mem->set_data_handle(
|
||||
static_cast<void*>(const_cast<T*>(diff_dst_data)), *bwd_stream);
|
||||
context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data),
|
||||
*bwd_stream);
|
||||
#else
|
||||
context_.src_mem->set_data_handle(
|
||||
static_cast<void*>(const_cast<T*>(src_data)));
|
||||
context_.diff_dst_mem->set_data_handle(
|
||||
static_cast<void*>(const_cast<T*>(diff_dst_data)));
|
||||
context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data));
|
||||
|
||||
#endif // ENABLE_MKLDNN_THREADPOOL
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
DCHECK_EQ(context_.bwd_primitives.size(),
|
||||
context_.bwd_primitives_args.size());
|
||||
|
@ -189,9 +189,15 @@ class MklSlicePrimitive : public MklPrimitive {
|
||||
|
||||
void Execute(const MklSliceParams& sliceParams,
|
||||
std::shared_ptr<stream> slice_stream) {
|
||||
#ifdef ENABLE_MKLDNN_THREADPOOL
|
||||
context_.src_mem->set_data_handle(sliceParams.from->get_data_handle(),
|
||||
*slice_stream);
|
||||
context_.dst_mem->set_data_handle(sliceParams.to->get_data_handle(),
|
||||
*slice_stream);
|
||||
#else
|
||||
context_.src_mem->set_data_handle(sliceParams.from->get_data_handle());
|
||||
context_.dst_mem->set_data_handle(sliceParams.to->get_data_handle());
|
||||
|
||||
#endif // ENABLE_MKLDNN_THREADPOOL
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
execute_primitives(context_.slice_primitives, slice_stream,
|
||||
context_.slice_primitives_args);
|
||||
|
@ -59,10 +59,16 @@ class MklSoftmaxPrimitive : public MklPrimitive {
|
||||
// dst_data: output data buffer of dst
|
||||
void Execute(const T* src_data, T* dst_data,
|
||||
std::shared_ptr<stream> fwd_cpu_stream) {
|
||||
#ifdef ENABLE_MKLDNN_THREADPOOL
|
||||
context_.src_mem->set_data_handle(
|
||||
static_cast<void*>(const_cast<T*>(src_data)), *fwd_cpu_stream);
|
||||
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data),
|
||||
*fwd_cpu_stream);
|
||||
#else
|
||||
context_.src_mem->set_data_handle(
|
||||
static_cast<void*>(const_cast<T*>(src_data)));
|
||||
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
|
||||
|
||||
#endif // ENABLE_MKLDNN_THREADPOOL
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
DCHECK_EQ(context_.fwd_primitives.size(), context_.fwd_net_args.size());
|
||||
execute_primitives(context_.fwd_primitives, fwd_cpu_stream,
|
||||
|
Loading…
Reference in New Issue
Block a user