Boiler plate code for dnnl threadpool and mkl_conv_ops.cc
This commit is contained in:
parent
765157843f
commit
dbd14b1587
tensorflow/core
@ -101,7 +101,7 @@ int32 NumIntraOpThreadsFromEnvironment() {
|
||||
const char* val = std::getenv("TF_NUM_INTRAOP_THREADS");
|
||||
return (val && strings::safe_strto32(val, &num)) ? num : 0;
|
||||
}
|
||||
|
||||
#ifndef ENABLE_MKLDNN_THREADPOOL
|
||||
#ifdef INTEL_MKL
|
||||
int32 OMPThreadsFromEnvironment() {
|
||||
// 1) std::getenv is thread-safe (as long as no other function modifies the
|
||||
@ -123,12 +123,14 @@ int32 DefaultNumIntraOpThreads() {
|
||||
return port::MaxParallelism();
|
||||
}
|
||||
#endif // INTEL_MKL
|
||||
#endif // ENABLE_MKLDNN_THREADPOOL
|
||||
int32 NumInterOpThreadsFromSessionOptions(const SessionOptions& options) {
|
||||
const int32 inter_op = options.config.inter_op_parallelism_threads();
|
||||
if (inter_op > 0) return inter_op;
|
||||
const int32 env_inter_op = GetEnvNumInterOpThreads();
|
||||
if (env_inter_op > 0) return env_inter_op;
|
||||
|
||||
#ifndef ENABLE_MKLDNN_THREADPOOL
|
||||
#ifdef INTEL_MKL
|
||||
if (!DisableMKL()) {
|
||||
// MKL library executes ops in parallel using OMP threads.
|
||||
@ -151,6 +153,7 @@ int32 NumInterOpThreadsFromSessionOptions(const SessionOptions& options) {
|
||||
return mkl_inter_op;
|
||||
}
|
||||
#endif // INTEL_MKL
|
||||
#endif // ENABLE_MKLDNN_THREADPOOL
|
||||
return DefaultNumInterOpThreads();
|
||||
}
|
||||
|
||||
|
@ -50,6 +50,7 @@ ThreadPoolDevice::ThreadPoolDevice(const SessionOptions& options,
|
||||
name, DEVICE_CPU, memory_limit, locality)),
|
||||
allocator_(allocator),
|
||||
scoped_allocator_mgr_(new ScopedAllocatorMgr(name)) {
|
||||
#ifndef ENABLE_MKLDNN_THREADPOOL
|
||||
#ifdef INTEL_MKL
|
||||
// Early return when MKL is disabled
|
||||
if (DisableMKL()) return;
|
||||
@ -70,6 +71,7 @@ ThreadPoolDevice::ThreadPoolDevice(const SessionOptions& options,
|
||||
}
|
||||
#endif // _OPENMP
|
||||
#endif // INTEL_MKL
|
||||
#endif // ENABLE_MKLDNN_THREADPOOL
|
||||
}
|
||||
|
||||
ThreadPoolDevice::~ThreadPoolDevice() {}
|
||||
|
@ -24,8 +24,8 @@ limitations under the License.
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "mkldnn.hpp"
|
||||
#include "absl/strings/str_join.h"
|
||||
#include "mkldnn.hpp"
|
||||
#include "tensorflow/core/framework/bounds_check.h"
|
||||
#include "tensorflow/core/framework/numeric_op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
@ -51,12 +51,10 @@ limitations under the License.
|
||||
using mkldnn::convolution_forward;
|
||||
using mkldnn::prop_kind;
|
||||
using mkldnn::stream;
|
||||
|
||||
using ConvFwdPd = mkldnn::convolution_forward::primitive_desc;
|
||||
using ReorderPd = mkldnn::reorder::primitive_desc;
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// This structure aggregates multiple inputs to Conv2DFwd* methods.
|
||||
struct MklConvFwdParams {
|
||||
memory::dims src_dims;
|
||||
@ -96,14 +94,12 @@ template <typename Tinput, typename Tfilter, typename Tbias, typename Toutput>
|
||||
class MklConvFwdPrimitive : public MklPrimitive {
|
||||
public:
|
||||
explicit MklConvFwdPrimitive(const MklConvFwdParams& convFwdDims)
|
||||
: cpu_engine_(ENGINE_CPU, 0) {
|
||||
context_.fwd_stream.reset(new CPU_STREAM(cpu_engine_));
|
||||
: MklPrimitive(engine(ENGINE_CPU, 0)) {
|
||||
// Create convolution primitive
|
||||
if (context_.conv_fwd == nullptr) {
|
||||
Setup(convFwdDims);
|
||||
}
|
||||
}
|
||||
|
||||
~MklConvFwdPrimitive() {}
|
||||
|
||||
// Convolution forward execute with bias
|
||||
@ -112,7 +108,8 @@ class MklConvFwdPrimitive : public MklPrimitive {
|
||||
// bias_data: input data buffer of bias
|
||||
// dst_data: output data buffer of dst
|
||||
void Execute(const Tinput* src_data, const Tfilter* filter_data,
|
||||
const Tbias* bias_data, const Toutput* dst_data) {
|
||||
const Tbias* bias_data, const Toutput* dst_data,
|
||||
std::shared_ptr<stream> fwd_stream) {
|
||||
context_.src_mem->set_data_handle(
|
||||
static_cast<void*>(const_cast<Tinput*>(src_data)));
|
||||
context_.filter_mem->set_data_handle(
|
||||
@ -127,11 +124,11 @@ class MklConvFwdPrimitive : public MklPrimitive {
|
||||
DCHECK_EQ(context_.fwd_primitives.size(),
|
||||
context_.fwd_primitives_args.size());
|
||||
for (size_t i = 0; i < context_.fwd_primitives.size(); ++i) {
|
||||
context_.fwd_primitives.at(i).execute(*context_.fwd_stream,
|
||||
context_.fwd_primitives.at(i).execute(*fwd_stream,
|
||||
context_.fwd_primitives_args.at(i));
|
||||
}
|
||||
#else
|
||||
context_.fwd_stream->submit(context_.fwd_primitives);
|
||||
fwd_stream->submit(context_.fwd_primitives);
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
|
||||
// After execution, set data handle back
|
||||
@ -148,8 +145,8 @@ class MklConvFwdPrimitive : public MklPrimitive {
|
||||
// filter_data: input data buffer of filter (weights)
|
||||
// dst_data: output data buffer of dst
|
||||
void Execute(const Tinput* src_data, const Tfilter* filter_data,
|
||||
const Toutput* dst_data) {
|
||||
Execute(src_data, filter_data, nullptr, dst_data);
|
||||
const Toutput* dst_data, std::shared_ptr<stream> fwd_stream) {
|
||||
Execute(src_data, filter_data, nullptr, dst_data, fwd_stream);
|
||||
}
|
||||
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
@ -191,7 +188,6 @@ class MklConvFwdPrimitive : public MklPrimitive {
|
||||
std::shared_ptr<ConvFwdPd> fwd_pd;
|
||||
std::shared_ptr<mkldnn::primitive> conv_fwd;
|
||||
|
||||
std::shared_ptr<mkldnn::stream> fwd_stream;
|
||||
std::vector<mkldnn::primitive> fwd_primitives;
|
||||
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
@ -213,8 +209,7 @@ class MklConvFwdPrimitive : public MklPrimitive {
|
||||
filter_md(nullptr),
|
||||
bias_md(nullptr),
|
||||
fwd_pd(nullptr),
|
||||
conv_fwd(nullptr),
|
||||
fwd_stream(nullptr) {
|
||||
conv_fwd(nullptr) {
|
||||
}
|
||||
};
|
||||
|
||||
@ -346,7 +341,6 @@ class MklConvFwdPrimitive : public MklPrimitive {
|
||||
}
|
||||
|
||||
struct ConvFwdContext context_;
|
||||
engine cpu_engine_;
|
||||
};
|
||||
|
||||
// TODO(nhasabni): We should not require passing a type to MklPrimitiveFactory.
|
||||
@ -496,17 +490,15 @@ class MklConvOp : public OpKernel {
|
||||
OP_REQUIRES(context, dilations_.size() == 5,
|
||||
errors::InvalidArgument("Dilation rates field must "
|
||||
"specify 5 dimensions"));
|
||||
OP_REQUIRES(context,
|
||||
(GetTensorDim(dilations_, data_format_, 'N') == 1 &&
|
||||
GetTensorDim(dilations_, data_format_, 'C') == 1),
|
||||
OP_REQUIRES(context, (GetTensorDim(dilations_, data_format_, 'N') == 1 &&
|
||||
GetTensorDim(dilations_, data_format_, 'C') == 1),
|
||||
errors::InvalidArgument(
|
||||
"Current implementation does not yet support "
|
||||
"dilations rates in the batch and depth dimensions."));
|
||||
OP_REQUIRES(
|
||||
context,
|
||||
(GetTensorDim(dilations_, data_format_, '0') > 0 &&
|
||||
GetTensorDim(dilations_, data_format_, '1') > 0 &&
|
||||
GetTensorDim(dilations_, data_format_, '2') > 0),
|
||||
context, (GetTensorDim(dilations_, data_format_, '0') > 0 &&
|
||||
GetTensorDim(dilations_, data_format_, '1') > 0 &&
|
||||
GetTensorDim(dilations_, data_format_, '2') > 0),
|
||||
errors::InvalidArgument("Dilated rates should be larger than 0."));
|
||||
}
|
||||
}
|
||||
@ -678,11 +670,9 @@ class MklConvOp : public OpKernel {
|
||||
|
||||
// TODO(mdfaijul): Extend the basic parameters for data types and fusions
|
||||
this->ExtendConvFwdParams(context, convFwdDims);
|
||||
|
||||
conv_fwd =
|
||||
MklConvFwdPrimitiveFactory<Tinput, Tfilter, Tbias, Ttemp_output>::Get(
|
||||
convFwdDims, do_not_cache);
|
||||
|
||||
// Allocate output tensors `output_tensor` and `filter_out_tensor`
|
||||
MklDnnShape output_mkl_shape;
|
||||
std::shared_ptr<ConvFwdPd> conv_fwd_pd = conv_fwd->GetPrimitiveDesc();
|
||||
@ -703,8 +693,10 @@ class MklConvOp : public OpKernel {
|
||||
Tinput* src_data = nullptr;
|
||||
if (IS_SRC_REORDER_NEEDED(src_md, conv_fwd_pd, conv_fwd)) {
|
||||
src.SetUsrMem(src_md, &src_tensor);
|
||||
src.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
|
||||
GET_SRC_DESC_FROM_OP_PD(conv_fwd_pd), cpu_engine_));
|
||||
src.CheckReorderToOpMem(
|
||||
MEMORY_PD_WITHOUT_DATA(GET_SRC_DESC_FROM_OP_PD(conv_fwd_pd),
|
||||
cpu_engine_),
|
||||
context);
|
||||
src_data = static_cast<Tinput*>(src.GetOpMem().get_data_handle());
|
||||
} else {
|
||||
src_data = static_cast<Tinput*>(
|
||||
@ -735,13 +727,16 @@ class MklConvOp : public OpKernel {
|
||||
if (!is_filter_cached) {
|
||||
filter.SetUsrMem(filter_md, &filter_tensor);
|
||||
if (filter_out_tensor == nullptr) {
|
||||
filter.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
|
||||
GET_WEIGHTS_DESC_FROM_OP_PD(conv_fwd_pd), cpu_engine_));
|
||||
filter.CheckReorderToOpMem(
|
||||
MEMORY_PD_WITHOUT_DATA(GET_WEIGHTS_DESC_FROM_OP_PD(conv_fwd_pd),
|
||||
cpu_engine_),
|
||||
context);
|
||||
} else {
|
||||
filter.CheckReorderToOpMem(
|
||||
GET_WEIGHTS_DESC_FROM_OP_PD(conv_fwd_pd),
|
||||
DATA_WITH_ENGINE(filter.GetTensorBuffer(filter_out_tensor),
|
||||
cpu_engine_));
|
||||
cpu_engine_),
|
||||
context);
|
||||
}
|
||||
filter_data =
|
||||
static_cast<Tfilter*>(filter.GetOpMem().get_data_handle());
|
||||
@ -752,20 +747,23 @@ class MklConvOp : public OpKernel {
|
||||
}
|
||||
|
||||
// Execute convolution
|
||||
std::shared_ptr<stream> fwd_cpu_stream;
|
||||
fwd_cpu_stream.reset(CreateStream(context, conv_fwd->GetEngine()));
|
||||
if (fuse_biasadd_) {
|
||||
const Tensor& bias_tensor = MklGetInput(context, kInputIndex_Bias);
|
||||
Tbias* bias_data =
|
||||
this->GetBiasHandle(context, conv_fwd_pd, bias_tensor);
|
||||
conv_fwd->Execute(src_data, filter_data, bias_data, dst_data);
|
||||
conv_fwd->Execute(src_data, filter_data, bias_data, dst_data,
|
||||
fwd_cpu_stream);
|
||||
} else {
|
||||
if (!eager_mode) {
|
||||
conv_fwd->Execute(src_data, filter_data, dst_data);
|
||||
conv_fwd->Execute(src_data, filter_data, dst_data, fwd_cpu_stream);
|
||||
} else {
|
||||
// In eager mode we first write the output to temporary
|
||||
// buffer in MKL format. Then we convert the data to TF format.
|
||||
Ttemp_output* tmp_data = reinterpret_cast<Ttemp_output*>(
|
||||
tmp_tensor.flat<Toutput>().data());
|
||||
conv_fwd->Execute(src_data, filter_data, tmp_data);
|
||||
conv_fwd->Execute(src_data, filter_data, tmp_data, fwd_cpu_stream);
|
||||
|
||||
// Now we need to convert the output to TF format.
|
||||
auto output_tf_md = output_mkl_shape.GetTfLayout();
|
||||
@ -780,12 +778,13 @@ class MklConvOp : public OpKernel {
|
||||
memory* dst_data_mem =
|
||||
new MEMORY_CONSTRUCTOR(OUTPUT_TF_MD, cpu_engine_, dst_data);
|
||||
CreateAndExecuteReorder(reorder_pd, *tmp_data_mem, *dst_data_mem,
|
||||
cpu_engine_);
|
||||
cpu_engine_, context);
|
||||
}
|
||||
}
|
||||
|
||||
// Delete primitive since it is not cached.
|
||||
if (do_not_cache) delete conv_fwd;
|
||||
|
||||
} catch (mkldnn::error& e) {
|
||||
string error_msg = tensorflow::strings::StrCat(
|
||||
"Status: ", e.status, ", message: ", string(e.message), ", in file ",
|
||||
@ -970,8 +969,9 @@ class MklConvOp : public OpKernel {
|
||||
new MEMORY_CONSTRUCTOR(DST_MD, this->cpu_engine_, dst_buf));
|
||||
auto reorder_desc =
|
||||
REORDER_PD_CONSTRUCTOR(ADD_MD, DST_MD, this->cpu_engine_);
|
||||
|
||||
CreateAndExecuteReorder(reorder_desc, *fuse_add_src_, *fuse_add_dst_,
|
||||
this->cpu_engine_);
|
||||
this->cpu_engine_, context);
|
||||
}
|
||||
} else {
|
||||
AllocateOutputSetMklShape(context, kOutputIndex_Dst, output_tensor,
|
||||
@ -1097,6 +1097,8 @@ class MklConvOp : public OpKernel {
|
||||
filter_tf_shape, filter_mkl_shape);
|
||||
}
|
||||
|
||||
// TODO(intel-mkl): This function does not seem to be called. Remove it.
|
||||
// LOCKS_EXCLUDED annotation ensures that the lock (mu_) cannot
|
||||
// Prepare and execute net - checks for input and output reorders.
|
||||
void PrepareAndExecuteNet(const ConvFwdPd& conv_prim_desc,
|
||||
MklDnnData<Tinput>* src,
|
||||
@ -1185,7 +1187,7 @@ class MklConvOp : public OpKernel {
|
||||
// Otherwise, cache filter
|
||||
filter.SetUsrMem(filter_md, &filter_tensor);
|
||||
filter.CheckReorderToOpMem(conv_fwd_pd.get()->weights_desc(),
|
||||
this->cpu_engine_);
|
||||
this->cpu_engine_, context);
|
||||
filter_data = static_cast<Tfilter*>(filter.GetOpMem().get_data_handle());
|
||||
|
||||
Tensor* filter_tensor_ptr = nullptr;
|
||||
@ -1251,9 +1253,9 @@ class MklConvOp : public OpKernel {
|
||||
const Tensor& cached_filter_md =
|
||||
*cached_filter_md_ptensor_.AccessTensor(context);
|
||||
|
||||
// Check if the memory descriptor of the cached weights is same as
|
||||
// filter_md. If so, we can use the cached weights; otherwise
|
||||
// return nullptr.
|
||||
// Check if the memory descriptor of the cached weights is same as
|
||||
// filter_md. If so, we can use the cached weights; otherwise
|
||||
// return nullptr.
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
if (filter_md == *static_cast<memory::desc*>(cached_filter_md.data())) {
|
||||
#else
|
||||
@ -1600,7 +1602,7 @@ class MklQuantizedConv2DOp
|
||||
input_bias_->GET_DESC, scaled_bias_->GET_DESC, this->cpu_engine_,
|
||||
bias_attr);
|
||||
CreateAndExecuteReorder(reorder_desc, *input_bias_, *scaled_bias_,
|
||||
this->cpu_engine_);
|
||||
this->cpu_engine_, context);
|
||||
|
||||
Tbias* bias_data =
|
||||
reinterpret_cast<Tbias*>(scaled_bias_->get_data_handle());
|
||||
@ -1856,7 +1858,8 @@ class MklQuantizedConv2DSumReluOp
|
||||
auto reorder_desc = REORDER_PD_CONSTRUCTOR_WITH_ATTR(
|
||||
SUMMAND_MD, conv_prim_desc.PRIMITIVE_DESC_DST, this->cpu_engine_,
|
||||
reorder_attr);
|
||||
CreateAndExecuteReorder(reorder_desc, *summand_, *dst_, this->cpu_engine_);
|
||||
CreateAndExecuteReorder(reorder_desc, *summand_, *dst_, this->cpu_engine_,
|
||||
context);
|
||||
}
|
||||
|
||||
std::shared_ptr<mkldnn::memory> summand_;
|
||||
|
@ -165,7 +165,7 @@ class MklInputConversionOp : public OpKernel {
|
||||
input1_md, tensor_out, net, net_args, cpu_engine)),
|
||||
errors::Internal(
|
||||
"MklInputConversionOp: Failed to create reorder for input0"));
|
||||
ExecutePrimitive(net, NET_ARGS_PTR, cpu_engine);
|
||||
ExecutePrimitive(net, NET_ARGS_PTR, cpu_engine, context);
|
||||
// Input1 will be passed through
|
||||
ForwardMklTensorInToOut(context, kInputIndex_1, kInputIndex_1);
|
||||
return;
|
||||
@ -273,7 +273,7 @@ class MklInputConversionOp : public OpKernel {
|
||||
errors::Internal("MklInputConversionOp: Failed to forward "
|
||||
"input tensor to output"));
|
||||
} else {
|
||||
ExecutePrimitive(net, NET_ARGS_PTR, cpu_engine);
|
||||
ExecutePrimitive(net, NET_ARGS_PTR, cpu_engine, context);
|
||||
}
|
||||
|
||||
// -- The tensor in MKL format passes through --
|
||||
|
@ -172,7 +172,8 @@ class MklReshapeOp : public OpKernel {
|
||||
// shape_from != shape_to), then we just copy input tensor to
|
||||
// output tensor with target shape (we cannot forward Mkl layout
|
||||
// in such case because shape has changed.)
|
||||
if (dnn_data_input.CheckReorderToOpMem(OUTPUT_TF_MD, output_tensor)) {
|
||||
if (dnn_data_input.CheckReorderToOpMem(OUTPUT_TF_MD, output_tensor,
|
||||
context)) {
|
||||
} else {
|
||||
OP_REQUIRES(context,
|
||||
output_tensor->CopyFrom(input_tensor, shape_to),
|
||||
|
@ -111,7 +111,8 @@ class MklToTfOp : public OpKernel {
|
||||
if (input.IsReorderNeeded(OUTPUT_TF_MD)) {
|
||||
// Insert reorder between MKL layout and TensorFlow layout
|
||||
OP_REQUIRES(
|
||||
context, input.CheckReorderToOpMem(OUTPUT_TF_MD, output_tensor),
|
||||
context,
|
||||
input.CheckReorderToOpMem(OUTPUT_TF_MD, output_tensor, context),
|
||||
errors::Internal("MklToTfOp: Failed to create input reorder"));
|
||||
} else {
|
||||
// If not, just forward input tensor to output tensor.
|
||||
|
@ -146,6 +146,7 @@ filegroup(
|
||||
"mirror_pad_mode.h",
|
||||
"mkl_types.h",
|
||||
"mkl_util.h",
|
||||
"mkl_threadpool.h",
|
||||
"overflow.h",
|
||||
"padding.h",
|
||||
"permutation_input_iterator.h",
|
||||
@ -274,6 +275,7 @@ filegroup(
|
||||
name = "mkl_util_hdrs",
|
||||
srcs = [
|
||||
"mkl_util.h",
|
||||
"mkl_threadpool.h",
|
||||
],
|
||||
visibility = ["//tensorflow/core:__pkg__"],
|
||||
)
|
||||
|
149
tensorflow/core/util/mkl_threadpool.h
Normal file
149
tensorflow/core/util/mkl_threadpool.h
Normal file
@ -0,0 +1,149 @@
|
||||
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_UTIL_MKL_THREADPOOL_H_
|
||||
#define TENSORFLOW_CORE_UTIL_MKL_THREADPOOL_H_
|
||||
#ifdef INTEL_MKL
|
||||
|
||||
#include <list>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
#include "mkldnn.hpp"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#define EIGEN_USE_THREADS
|
||||
#include "tensorflow/core/platform/threadpool.h"
|
||||
#ifdef ENABLE_MKLDNN_THREADPOOL
|
||||
using dnnl::threadpool_iface;
|
||||
using dnnl::stream_attr;
|
||||
|
||||
namespace tensorflow {
|
||||
// balance211 function tries to divide n jobs equally among 'team' threads.
|
||||
// This is the same as DNNL load balancer.
|
||||
template <typename T, typename U>
|
||||
inline void balance211(T n, U team, U tid, T& n_start, T& n_end) {
|
||||
T& n_my = n_end;
|
||||
if (team <= 1 || n == 0) {
|
||||
n_start = 0;
|
||||
n_my = n;
|
||||
} else {
|
||||
// team = T1 + T2
|
||||
// n = T1*n1 + T2*n2 (n1 - n2 = 1)
|
||||
T n1 = (n + (T)team - 1) / team;
|
||||
T n2 = n1 - 1;
|
||||
T T1 = n - n2 * (T)team;
|
||||
n_my = (T)tid < T1 ? n1 : n2;
|
||||
n_start = (T)tid <= T1 ? tid * n1 : T1 * n1 + ((T)tid - T1) * n2;
|
||||
}
|
||||
|
||||
n_end += n_start;
|
||||
}
|
||||
|
||||
struct MklDnnThreadPool : public dnnl::threadpool_iface {
|
||||
MklDnnThreadPool() = default;
|
||||
|
||||
MklDnnThreadPool(OpKernelContext* ctx)
|
||||
: eigen_interface_(ctx->device()
|
||||
->tensorflow_cpu_worker_threads()
|
||||
->workers->AsEigenThreadPool())
|
||||
#if DNNL_PRINT_STATS
|
||||
,
|
||||
jobs_per_thread(eigen_interface_->NumThreads())
|
||||
#endif
|
||||
{
|
||||
}
|
||||
virtual int get_num_threads() override {
|
||||
return eigen_interface_->NumThreads();
|
||||
}
|
||||
virtual bool get_in_parallel() override {
|
||||
return (eigen_interface_->CurrentThreadId() != -1) ? true : false;
|
||||
}
|
||||
virtual uint64_t get_flags() override { return ASYNCHRONOUS; }
|
||||
virtual void parallel_for(int n,
|
||||
const std::function<void(int, int)>& fn) override {
|
||||
// Should never happen (handled by DNNL)
|
||||
if (n == 0) return;
|
||||
|
||||
// Should never happen (handled by DNNL)
|
||||
if (n == 1) {
|
||||
fn(0, 1);
|
||||
return;
|
||||
}
|
||||
|
||||
int nthr = get_num_threads();
|
||||
int njobs = std::min(n, nthr);
|
||||
for (int i = 0; i < njobs; i++) {
|
||||
eigen_interface_->ScheduleWithHint(
|
||||
[i, n, njobs, fn]() {
|
||||
int start, end;
|
||||
balance211(n, njobs, i, start, end);
|
||||
for (int j = start; j < end; j++) fn(j, n);
|
||||
#if DNNL_PRINT_STATS
|
||||
jobs_per_thread[eigen_interface_->CurrentThreadId()]++;
|
||||
#endif
|
||||
},
|
||||
i, i + 1);
|
||||
}
|
||||
}
|
||||
#if DNNL_PRINT_STATS
|
||||
void print_thread_usage_stats() {
|
||||
for (int i = 0; i < jobs_per_thread_.size(); i++)
|
||||
std::cout << " Thread" << i << "," << jobs_per_thread[i] << std::endl;
|
||||
}
|
||||
#endif
|
||||
~MklDnnThreadPool() {}
|
||||
|
||||
private:
|
||||
Eigen::ThreadPoolInterface* eigen_interface_ = nullptr;
|
||||
std::vector<int> jobs_per_thread_;
|
||||
};
|
||||
|
||||
class MklDnnThreadPoolWrapper {
|
||||
public:
|
||||
static MklDnnThreadPoolWrapper& GetInstance() {
|
||||
static MklDnnThreadPoolWrapper instance_;
|
||||
return instance_;
|
||||
}
|
||||
MklDnnThreadPool* CreateThreadPoolPtr(OpKernelContext* ctx) {
|
||||
if (threadpool_map_.empty() ||
|
||||
threadpool_map_.find(ctx->device()) == threadpool_map_.end()) {
|
||||
auto tp_iface = new MklDnnThreadPool(ctx);
|
||||
threadpool_map_.emplace(std::make_pair(ctx->device(), tp_iface));
|
||||
return tp_iface;
|
||||
} else {
|
||||
auto entry = threadpool_map_.find(ctx->device());
|
||||
return entry->second;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::unordered_map<DeviceBase*, MklDnnThreadPool*> threadpool_map_;
|
||||
MklDnnThreadPoolWrapper() {}
|
||||
MklDnnThreadPoolWrapper(const MklDnnThreadPoolWrapper&) = delete;
|
||||
MklDnnThreadPoolWrapper& operator=(const MklDnnThreadPoolWrapper&) = delete;
|
||||
~MklDnnThreadPoolWrapper() {
|
||||
for (auto& tp : threadpool_map_) {
|
||||
delete tp.second;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
#endif // ENABLE_MKLDNN_THREADPOOL
|
||||
#endif // INTEL_MKL
|
||||
#endif // TENSORFLOW_CORE_UTIL_MKL_THREADPOOL_H_
|
@ -36,6 +36,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/util/env_var.h"
|
||||
#include "tensorflow/core/util/mkl_threadpool.h"
|
||||
#include "tensorflow/core/util/mkl_types.h"
|
||||
#include "tensorflow/core/util/padding.h"
|
||||
#include "tensorflow/core/util/tensor_format.h"
|
||||
@ -48,7 +49,6 @@ using mkldnn::padding_kind;
|
||||
using mkldnn::primitive;
|
||||
using mkldnn::reorder;
|
||||
using mkldnn::stream;
|
||||
|
||||
using CPUDevice = Eigen::ThreadPoolDevice;
|
||||
using MemoryArgsMap = std::unordered_map<int, memory>;
|
||||
using ReorderPd = mkldnn::reorder::primitive_desc;
|
||||
@ -232,6 +232,27 @@ inline bool array_cmp(const T* a1, const T* a2, size_t size) {
|
||||
return true;
|
||||
}
|
||||
|
||||
inline mkldnn::stream* CreateStream(OpKernelContext* ctx,
|
||||
const engine& engine) {
|
||||
#ifdef ENABLE_MKLDNN_THREADPOOL
|
||||
stream_attr tp_stream_attr(ENGINE_CPU);
|
||||
if (ctx != nullptr) {
|
||||
auto eigen_tp =
|
||||
MklDnnThreadPoolWrapper::GetInstance().CreateThreadPoolPtr(ctx);
|
||||
tp_stream_attr.set_threadpool(eigen_tp);
|
||||
stream* tp_stream =
|
||||
new stream(engine, stream::flags::default_flags, tp_stream_attr);
|
||||
return tp_stream;
|
||||
} else {
|
||||
stream* tp_stream = new CPU_STREAM(engine);
|
||||
return tp_stream;
|
||||
}
|
||||
#else
|
||||
stream* tp_stream = new CPU_STREAM(engine);
|
||||
return tp_stream;
|
||||
#endif // ENABLE_MKLDNN_THREADPOOL
|
||||
}
|
||||
|
||||
class MklDnnShape {
|
||||
private:
|
||||
typedef struct {
|
||||
@ -679,20 +700,21 @@ class MklDnnData;
|
||||
// TODO merge with the execute_primitives.
|
||||
inline void ExecutePrimitive(const std::vector<primitive>& net,
|
||||
const std::vector<MemoryArgsMap>* net_args,
|
||||
const engine& cpu_engine) {
|
||||
const engine& cpu_engine,
|
||||
OpKernelContext* context = nullptr) {
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
DCHECK(net_args);
|
||||
DCHECK_EQ(net.size(), net_args->size());
|
||||
stream cpu_stream(cpu_engine);
|
||||
stream* cpu_stream = CreateStream(context, cpu_engine);
|
||||
for (size_t i = 0; i < net.size(); ++i) {
|
||||
net.at(i).execute(cpu_stream, net_args->at(i));
|
||||
net.at(i).execute(*cpu_stream, net_args->at(i));
|
||||
}
|
||||
cpu_stream.wait();
|
||||
cpu_stream->wait();
|
||||
delete cpu_stream;
|
||||
#else
|
||||
stream(stream::kind::eager_nostore).submit(net).wait();
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline Status ConvertMklToTF(OpKernelContext* context,
|
||||
const Tensor& input_mkl_tensor,
|
||||
@ -731,7 +753,7 @@ inline Status ConvertMklToTF(OpKernelContext* context,
|
||||
return Status(error::Code::INTERNAL,
|
||||
"ConvertMklToTF(): Failed to create reorder for input");
|
||||
}
|
||||
ExecutePrimitive(net, NET_ARGS_PTR, cpu_engine);
|
||||
ExecutePrimitive(net, NET_ARGS_PTR, cpu_engine, context);
|
||||
} else {
|
||||
// If not, just forward input tensor to output tensor.
|
||||
bool status =
|
||||
@ -744,9 +766,9 @@ inline Status ConvertMklToTF(OpKernelContext* context,
|
||||
}
|
||||
return Status::OK();
|
||||
} catch (mkldnn::error& e) {
|
||||
string error_msg = "Status: " + std::to_string(e.status) +
|
||||
", message: " + string(e.message) + ", in file " +
|
||||
string(__FILE__) + ":" + std::to_string(__LINE__);
|
||||
string error_msg = "Status: " + std::to_string(e.status) + ", message: " +
|
||||
string(e.message) + ", in file " + string(__FILE__) +
|
||||
":" + std::to_string(__LINE__);
|
||||
LOG(FATAL) << "Operation received an exception: " << error_msg;
|
||||
}
|
||||
}
|
||||
@ -1273,8 +1295,8 @@ inline Status CreateBlockedMemDescHelper(const memory::dims& dim,
|
||||
} catch (mkldnn::error& e) {
|
||||
return Status(error::Code::INTERNAL,
|
||||
tensorflow::strings::StrCat(
|
||||
"Failed to create blocked memory descriptor.",
|
||||
"Status: ", e.status, ", message: ", e.message));
|
||||
"Failed to create blocked memory descriptor.", "Status: ",
|
||||
e.status, ", message: ", e.message));
|
||||
}
|
||||
#else
|
||||
// We have to construct memory descriptor in a C style. This is not at all
|
||||
@ -1301,8 +1323,8 @@ inline Status CreateBlockedMemDescHelper(const memory::dims& dim,
|
||||
|
||||
inline void CreateAndExecuteReorder(const ReorderPd& reorder_desc,
|
||||
const memory& src_mem,
|
||||
const memory& dst_mem,
|
||||
const engine& engine) {
|
||||
const memory& dst_mem, const engine& engine,
|
||||
OpKernelContext* ctx = nullptr) {
|
||||
std::vector<primitive> net;
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
net.push_back(mkldnn::reorder(reorder_desc));
|
||||
@ -1311,7 +1333,7 @@ inline void CreateAndExecuteReorder(const ReorderPd& reorder_desc,
|
||||
#else
|
||||
net.push_back(mkldnn::reorder(reorder_desc, src_mem, dst_mem));
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
ExecutePrimitive(net, NET_ARGS_PTR, engine);
|
||||
ExecutePrimitive(net, NET_ARGS_PTR, engine, ctx);
|
||||
}
|
||||
|
||||
class MklReorderPrimitive;
|
||||
@ -1629,22 +1651,26 @@ class MklDnnData {
|
||||
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
inline bool CheckReorderToOpMem(const memory::desc& op_md,
|
||||
const engine& engine) {
|
||||
const engine& engine,
|
||||
OpKernelContext* context = nullptr) {
|
||||
DCHECK(user_memory_);
|
||||
if (IsReorderNeeded(op_md)) {
|
||||
// TODO(nhasabni): can we remove dynamic memory allocation?
|
||||
// primitive reuse don't allow two same reorder prim in
|
||||
// one stream, so submit it immediately
|
||||
reorder_memory_ = new memory(op_md, engine);
|
||||
std::vector<primitive> net;
|
||||
auto* prim = FindOrCreateReorder<T>(user_memory_, reorder_memory_);
|
||||
std::shared_ptr<stream> cpu_stream;
|
||||
cpu_stream.reset(CreateStream(context, prim->GetEngine()));
|
||||
std::vector<primitive> net;
|
||||
net.push_back(*(prim->GetPrimitive()));
|
||||
std::vector<MemoryArgsMap> net_args;
|
||||
net_args.push_back({{MKLDNN_ARG_FROM, *user_memory_},
|
||||
{MKLDNN_ARG_TO, *reorder_memory_}});
|
||||
execute_primitives(net, prim->GetStream(), net_args);
|
||||
execute_primitives(net, cpu_stream, net_args);
|
||||
#else
|
||||
inline bool CheckReorderToOpMem(const memory::primitive_desc& op_pd) {
|
||||
inline bool CheckReorderToOpMem(const memory::primitive_desc& op_pd,
|
||||
OpKernelContext* ctx = nullptr) {
|
||||
CHECK_NOTNULL(user_memory_);
|
||||
if (IsReorderNeeded(op_pd)) {
|
||||
reorder_memory_ = new memory(op_pd);
|
||||
@ -1708,7 +1734,8 @@ class MklDnnData {
|
||||
/// TODO(bhavanis): Need to use reorder cache here for better performance.
|
||||
inline bool CheckReorderToOpMem(const memory::desc& op_md,
|
||||
void* reorder_data_handle,
|
||||
const engine& engine) {
|
||||
const engine& engine,
|
||||
OpKernelContext* context = nullptr) {
|
||||
DCHECK(reorder_data_handle);
|
||||
DCHECK(user_memory_);
|
||||
if (IsReorderNeeded(op_md)) {
|
||||
@ -1716,16 +1743,19 @@ class MklDnnData {
|
||||
// primitive reuse don't allow two same reorder prim in
|
||||
// one stream, so submit it immediately
|
||||
reorder_memory_ = new memory(op_md, engine, reorder_data_handle);
|
||||
std::vector<primitive> net;
|
||||
auto* prim = FindOrCreateReorder<T>(user_memory_, reorder_memory_);
|
||||
std::shared_ptr<stream> cpu_stream;
|
||||
cpu_stream.reset(CreateStream(context, prim->GetEngine()));
|
||||
std::vector<primitive> net;
|
||||
net.push_back(*(prim->GetPrimitive()));
|
||||
std::vector<MemoryArgsMap> net_args;
|
||||
net_args.push_back({{MKLDNN_ARG_FROM, *user_memory_},
|
||||
{MKLDNN_ARG_TO, *reorder_memory_}});
|
||||
execute_primitives(net, prim->GetStream(), net_args);
|
||||
execute_primitives(net, cpu_stream, net_args);
|
||||
#else
|
||||
inline bool CheckReorderToOpMem(const memory::primitive_desc& op_pd,
|
||||
void* reorder_data_handle) {
|
||||
void* reorder_data_handle,
|
||||
OpKernelContext* context = nullptr) {
|
||||
CHECK_NOTNULL(reorder_data_handle);
|
||||
CHECK_NOTNULL(user_memory_);
|
||||
if (IsReorderNeeded(op_pd)) {
|
||||
@ -1778,13 +1808,14 @@ class MklDnnData {
|
||||
/// remove
|
||||
/// slow path in the future
|
||||
inline bool CheckReorderToOpMem(const MEMORY_PRIMITIVE_DESC& op_pd,
|
||||
Tensor* reorder_tensor) {
|
||||
Tensor* reorder_tensor,
|
||||
OpKernelContext* ctx = nullptr) {
|
||||
DCHECK(reorder_tensor);
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
return CheckReorderToOpMem(op_pd, GetTensorBuffer(reorder_tensor),
|
||||
*cpu_engine_);
|
||||
*cpu_engine_, ctx);
|
||||
#else
|
||||
return CheckReorderToOpMem(op_pd, GetTensorBuffer(reorder_tensor));
|
||||
return CheckReorderToOpMem(op_pd, GetTensorBuffer(reorder_tensor), ctx);
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
}
|
||||
|
||||
@ -1843,7 +1874,7 @@ class MklDnnData {
|
||||
/// TODO: this is a faster path with reorder primitive cache compared with
|
||||
/// InsertReorderToUserMem(net, net_args), will remove
|
||||
/// slow path in the future
|
||||
inline void InsertReorderToUserMem() {
|
||||
inline void InsertReorderToUserMem(OpKernelContext* ctx = nullptr) {
|
||||
DCHECK(user_memory_);
|
||||
DCHECK(reorder_memory_);
|
||||
DCHECK(cpu_engine_);
|
||||
@ -1857,8 +1888,8 @@ class MklDnnData {
|
||||
net_args.push_back(
|
||||
{{MKLDNN_ARG_FROM, *reorder_memory_}, {MKLDNN_ARG_TO, *user_memory_}});
|
||||
std::shared_ptr<stream> cpu_stream;
|
||||
cpu_stream.reset(new stream(*cpu_engine_));
|
||||
execute_primitives(net, prim->GetStream(), net_args);
|
||||
cpu_stream.reset(CreateStream(ctx, prim->GetEngine()));
|
||||
execute_primitives(net, cpu_stream, net_args);
|
||||
#else
|
||||
net.push_back(FindOrCreateReorder<T>(reorder_memory_, user_memory_));
|
||||
ExecutePrimitive(net, NET_ARGS_PTR, *cpu_engine_);
|
||||
@ -1870,9 +1901,12 @@ class MklDnnData {
|
||||
class MklPrimitive {
|
||||
public:
|
||||
virtual ~MklPrimitive() {}
|
||||
|
||||
MklPrimitive() {}
|
||||
MklPrimitive(const engine& cpu_engine) { cpu_engine_ = cpu_engine; }
|
||||
// Dummy data which MKL DNN never operates on
|
||||
unsigned char* DummyData = nullptr;
|
||||
engine cpu_engine_ = engine(ENGINE_CPU, 0);
|
||||
const engine& GetEngine() { return cpu_engine_; }
|
||||
};
|
||||
|
||||
const mkldnn::memory::dims NONE_DIMS = {};
|
||||
@ -2058,7 +2092,8 @@ class FactoryKeyCreator {
|
||||
|
||||
class MklReorderPrimitive : public MklPrimitive {
|
||||
public:
|
||||
explicit MklReorderPrimitive(const memory* from, const memory* to) {
|
||||
explicit MklReorderPrimitive(const memory* from, const memory* to)
|
||||
: MklPrimitive(engine(ENGINE_CPU, 0)) {
|
||||
Setup(from, to);
|
||||
}
|
||||
~MklReorderPrimitive() {}
|
||||
@ -2081,7 +2116,6 @@ class MklReorderPrimitive : public MklPrimitive {
|
||||
: src_mem(nullptr), dst_mem(nullptr), reorder_prim(nullptr) {}
|
||||
} context_;
|
||||
|
||||
engine cpu_engine_ = engine(ENGINE_CPU, 0);
|
||||
std::shared_ptr<mkldnn::stream> stream_;
|
||||
|
||||
void Setup(const memory* from, const memory* to) {
|
||||
|
Loading…
Reference in New Issue
Block a user