diff --git a/tensorflow/core/common_runtime/process_util.cc b/tensorflow/core/common_runtime/process_util.cc
index 2c7f66bd5bc..8f87873a5bd 100644
--- a/tensorflow/core/common_runtime/process_util.cc
+++ b/tensorflow/core/common_runtime/process_util.cc
@@ -101,8 +101,7 @@ int32 NumIntraOpThreadsFromEnvironment() {
   const char* val = std::getenv("TF_NUM_INTRAOP_THREADS");
   return (val && strings::safe_strto32(val, &num)) ? num : 0;
 }
-
-#ifdef INTEL_MKL
+#if !defined(ENABLE_MKLDNN_THREADPOOL) && defined(INTEL_MKL)
 int32 OMPThreadsFromEnvironment() {
   // 1) std::getenv is thread-safe (as long as no other function modifies the
   // host env) from C++11 onward. 2) Most of TF code (except tests and
@@ -122,14 +121,14 @@ int32 DefaultNumIntraOpThreads() {
   // Default to the maximum parallelism for the current process.
   return port::MaxParallelism();
 }
-#endif  // INTEL_MKL
+#endif  // !defined(ENABLE_MKLDNN_THREADPOOL) && defined(INTEL_MKL)
 
 int32 NumInterOpThreadsFromSessionOptions(const SessionOptions& options) {
   const int32 inter_op = options.config.inter_op_parallelism_threads();
   if (inter_op > 0) return inter_op;
   const int32 env_inter_op = GetEnvNumInterOpThreads();
   if (env_inter_op > 0) return env_inter_op;
-#ifdef INTEL_MKL
+#if !defined(ENABLE_MKLDNN_THREADPOOL) && defined(INTEL_MKL)
   if (!DisableMKL()) {
     // MKL library executes ops in parallel using OMP threads.
     // Setting inter_op conservatively to avoid thread oversubscription that
@@ -150,7 +149,7 @@ int32 NumInterOpThreadsFromSessionOptions(const SessionOptions& options) {
             << ". Tune using inter_op_parallelism_threads for best performance.";
     return mkl_inter_op;
   }
-#endif  // INTEL_MKL
+#endif  // !defined(ENABLE_MKLDNN_THREADPOOL) && defined(INTEL_MKL)
   return DefaultNumInterOpThreads();
 }

diff --git a/tensorflow/core/common_runtime/threadpool_device.cc b/tensorflow/core/common_runtime/threadpool_device.cc
index 68fcc9a079a..44fa5bf2d3a 100644
--- a/tensorflow/core/common_runtime/threadpool_device.cc
+++ b/tensorflow/core/common_runtime/threadpool_device.cc
@@ -50,7 +50,7 @@ ThreadPoolDevice::ThreadPoolDevice(const SessionOptions& options,
                                    name, DEVICE_CPU, memory_limit, locality)),
       allocator_(allocator),
       scoped_allocator_mgr_(new ScopedAllocatorMgr(name)) {
-#ifdef INTEL_MKL
+#if !defined(ENABLE_MKLDNN_THREADPOOL) && defined(INTEL_MKL)
   // Early return when MKL is disabled
   if (DisableMKL()) return;
 #ifdef _OPENMP
@@ -69,7 +69,7 @@ ThreadPoolDevice::ThreadPoolDevice(const SessionOptions& options,
     }
   }
 #endif  // _OPENMP
-#endif  // INTEL_MKL
+#endif  // !defined(ENABLE_MKLDNN_THREADPOOL) && defined(INTEL_MKL)
 }
 
 ThreadPoolDevice::~ThreadPoolDevice() {}
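Note: the two files above narrow every former `#ifdef INTEL_MKL` thread-tuning block so it only applies when oneDNN's threadpool mode is off. A sketch (illustrative, not part of the patch) of how the new guard composes in an MKL build:

    #if !defined(ENABLE_MKLDNN_THREADPOOL) && defined(INTEL_MKL)
      // OpenMP build: oneDNN runs primitives on its own OMP threads, so TF
      // consults OMP_NUM_THREADS and caps inter-op parallelism to avoid
      // oversubscribing the machine.
    #else
      // Threadpool build (or non-MKL build): oneDNN work is scheduled onto
      // TF's own Eigen threadpool, so the stock defaults apply unchanged.
    #endif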
diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc
index 52bb1e404e1..59de3229211 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_ops.cc
@@ -51,12 +51,10 @@ limitations under the License.
 using mkldnn::convolution_forward;
 using mkldnn::prop_kind;
 using mkldnn::stream;
-
 using ConvFwdPd = mkldnn::convolution_forward::primitive_desc;
 using ReorderPd = mkldnn::reorder::primitive_desc;
 
 namespace tensorflow {
-
 // This structure aggregates multiple inputs to Conv2DFwd* methods.
 struct MklConvFwdParams {
   memory::dims src_dims;
@@ -96,14 +94,12 @@ template <typename Tinput, typename Tfilter, typename Tbias, typename Toutput>
 class MklConvFwdPrimitive : public MklPrimitive {
  public:
   explicit MklConvFwdPrimitive(const MklConvFwdParams& convFwdDims)
-      : cpu_engine_(ENGINE_CPU, 0) {
-    context_.fwd_stream.reset(new CPU_STREAM(cpu_engine_));
+      : MklPrimitive(engine(ENGINE_CPU, 0)) {
     // Create convolution primitive
     if (context_.conv_fwd == nullptr) {
       Setup(convFwdDims);
     }
   }
-
   ~MklConvFwdPrimitive() {}
 
   // Convolution forward execute with bias
@@ -112,7 +108,8 @@ class MklConvFwdPrimitive : public MklPrimitive {
   //   bias_data:   input data buffer of bias
   //   dst_data:    output data buffer of dst
   void Execute(const Tinput* src_data, const Tfilter* filter_data,
-               const Tbias* bias_data, const Toutput* dst_data) {
+               const Tbias* bias_data, const Toutput* dst_data,
+               std::shared_ptr<stream> fwd_stream) {
     context_.src_mem->set_data_handle(
         static_cast<void*>(const_cast<Tinput*>(src_data)));
     context_.filter_mem->set_data_handle(
        static_cast<void*>(const_cast<Tfilter*>(filter_data)));
@@ -127,11 +124,11 @@ class MklConvFwdPrimitive : public MklPrimitive {
     DCHECK_EQ(context_.fwd_primitives.size(),
               context_.fwd_primitives_args.size());
     for (size_t i = 0; i < context_.fwd_primitives.size(); ++i) {
-      context_.fwd_primitives.at(i).execute(*context_.fwd_stream,
+      context_.fwd_primitives.at(i).execute(*fwd_stream,
                                             context_.fwd_primitives_args.at(i));
     }
 #else
-    context_.fwd_stream->submit(context_.fwd_primitives);
+    fwd_stream->submit(context_.fwd_primitives);
 #endif  // ENABLE_MKLDNN_V1
 
     // After execution, set data handle back
@@ -148,8 +145,8 @@ class MklConvFwdPrimitive : public MklPrimitive {
   //   filter_data: input data buffer of filter (weights)
   //   dst_data:    output data buffer of dst
   void Execute(const Tinput* src_data, const Tfilter* filter_data,
-               const Toutput* dst_data) {
-    Execute(src_data, filter_data, nullptr, dst_data);
+               const Toutput* dst_data, std::shared_ptr<stream> fwd_stream) {
+    Execute(src_data, filter_data, nullptr, dst_data, fwd_stream);
   }
 
 #ifndef ENABLE_MKLDNN_V1
@@ -191,7 +188,6 @@ class MklConvFwdPrimitive : public MklPrimitive {
     std::shared_ptr<ConvFwdPd> fwd_pd;
     std::shared_ptr<mkldnn::primitive> conv_fwd;
-    std::shared_ptr<mkldnn::stream> fwd_stream;
     std::vector<mkldnn::primitive> fwd_primitives;
 
 #ifdef ENABLE_MKLDNN_V1
@@ -213,8 +209,7 @@ class MklConvFwdPrimitive : public MklPrimitive {
           filter_md(nullptr),
           bias_md(nullptr),
           fwd_pd(nullptr),
-          conv_fwd(nullptr),
-          fwd_stream(nullptr) {
+          conv_fwd(nullptr) {
     }
   };
@@ -346,7 +341,6 @@ class MklConvFwdPrimitive : public MklPrimitive {
   }
 
   struct ConvFwdContext context_;
-  engine cpu_engine_;
 };
 
 // TODO(nhasabni): We should not require passing a type to MklPrimitiveFactory.
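Note: with `fwd_stream` removed from `ConvFwdContext`, a cached `MklConvFwdPrimitive` no longer owns a stream; each `Execute()` call receives the stream it should run on. A usage sketch consistent with the call sites later in this patch (`context` is the executing kernel's `OpKernelContext*`; all names come from the patch):

    // Build a per-call, threadpool-aware stream and run the cached primitive.
    std::shared_ptr<stream> fwd_cpu_stream(
        CreateStream(context, conv_fwd->GetEngine()));
    conv_fwd->Execute(src_data, filter_data, bias_data, dst_data,
                      fwd_cpu_stream);

This matters because the primitive factory caches primitives globally: a stream bound to one kernel invocation's threadpool must not be reused by an unrelated invocation.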
@@ -678,11 +672,9 @@ class MklConvOp : public OpKernel {
       // TODO(mdfaijul): Extend the basic parameters for data types and fusions
       this->ExtendConvFwdParams(context, convFwdDims);
-
       conv_fwd =
           MklConvFwdPrimitiveFactory<Tinput, Tfilter, Tbias, Ttemp_output>::Get(
               convFwdDims, do_not_cache);
-
       // Allocate output tensors `output_tensor` and `filter_out_tensor`
       MklDnnShape output_mkl_shape;
       std::shared_ptr<ConvFwdPd> conv_fwd_pd = conv_fwd->GetPrimitiveDesc();
@@ -703,8 +695,10 @@ class MklConvOp : public OpKernel {
       Tinput* src_data = nullptr;
       if (IS_SRC_REORDER_NEEDED(src_md, conv_fwd_pd, conv_fwd)) {
         src.SetUsrMem(src_md, &src_tensor);
-        src.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
-            GET_SRC_DESC_FROM_OP_PD(conv_fwd_pd), cpu_engine_));
+        src.CheckReorderToOpMem(
+            MEMORY_PD_WITHOUT_DATA(GET_SRC_DESC_FROM_OP_PD(conv_fwd_pd),
+                                   cpu_engine_),
+            context);
         src_data = static_cast<Tinput*>(src.GetOpMem().get_data_handle());
       } else {
         src_data = static_cast<Tinput*>(
@@ -735,13 +729,16 @@ class MklConvOp : public OpKernel {
       if (!is_filter_cached) {
         filter.SetUsrMem(filter_md, &filter_tensor);
         if (filter_out_tensor == nullptr) {
-          filter.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
-              GET_WEIGHTS_DESC_FROM_OP_PD(conv_fwd_pd), cpu_engine_));
+          filter.CheckReorderToOpMem(
+              MEMORY_PD_WITHOUT_DATA(GET_WEIGHTS_DESC_FROM_OP_PD(conv_fwd_pd),
+                                     cpu_engine_),
+              context);
         } else {
           filter.CheckReorderToOpMem(
               GET_WEIGHTS_DESC_FROM_OP_PD(conv_fwd_pd),
               DATA_WITH_ENGINE(filter.GetTensorBuffer(filter_out_tensor),
-                               cpu_engine_));
+                               cpu_engine_),
+              context);
         }
         filter_data =
             static_cast<Tfilter*>(filter.GetOpMem().get_data_handle());
@@ -752,20 +749,23 @@ class MklConvOp : public OpKernel {
       }
 
       // Execute convolution
+      std::shared_ptr<stream> fwd_cpu_stream;
+      fwd_cpu_stream.reset(CreateStream(context, conv_fwd->GetEngine()));
       if (fuse_biasadd_) {
         const Tensor& bias_tensor = MklGetInput(context, kInputIndex_Bias);
         Tbias* bias_data =
             this->GetBiasHandle(context, conv_fwd_pd, bias_tensor);
-        conv_fwd->Execute(src_data, filter_data, bias_data, dst_data);
+        conv_fwd->Execute(src_data, filter_data, bias_data, dst_data,
+                          fwd_cpu_stream);
       } else {
         if (!eager_mode) {
-          conv_fwd->Execute(src_data, filter_data, dst_data);
+          conv_fwd->Execute(src_data, filter_data, dst_data, fwd_cpu_stream);
         } else {
           // In eager mode we first write the output to temporary
           // buffer in MKL format. Then we convert the data to TF format.
           Ttemp_output* tmp_data = reinterpret_cast<Ttemp_output*>(
               tmp_tensor.flat<Ttemp_output>().data());
-          conv_fwd->Execute(src_data, filter_data, tmp_data);
+          conv_fwd->Execute(src_data, filter_data, tmp_data, fwd_cpu_stream);
 
           // Now we need to convert the output to TF format.
           auto output_tf_md = output_mkl_shape.GetTfLayout();
@@ -780,12 +780,13 @@ class MklConvOp : public OpKernel {
           memory* dst_data_mem =
               new MEMORY_CONSTRUCTOR(OUTPUT_TF_MD, cpu_engine_, dst_data);
           CreateAndExecuteReorder(reorder_pd, *tmp_data_mem, *dst_data_mem,
-                                  cpu_engine_);
+                                  cpu_engine_, context);
         }
       }
 
       // Delete primitive since it is not cached.
       if (do_not_cache) delete conv_fwd;
+
     } catch (mkldnn::error& e) {
       string error_msg = tensorflow::strings::StrCat(
           "Status: ", e.status, ", message: ", string(e.message), ", in file ",
@@ -970,8 +971,9 @@ class MklConvOp : public OpKernel {
           new MEMORY_CONSTRUCTOR(DST_MD, this->cpu_engine_, dst_buf));
       auto reorder_desc =
           REORDER_PD_CONSTRUCTOR(ADD_MD, DST_MD, this->cpu_engine_);
+
       CreateAndExecuteReorder(reorder_desc, *fuse_add_src_, *fuse_add_dst_,
-                              this->cpu_engine_);
+                              this->cpu_engine_, context);
     }
   } else {
     AllocateOutputSetMklShape(context, kOutputIndex_Dst, output_tensor,
@@ -1097,6 +1099,7 @@ class MklConvOp : public OpKernel {
                               filter_tf_shape, filter_mkl_shape);
   }
 
+  // TODO(intel-mkl): This function does not seem to be called. Remove it.
   // Prepare and execute net - checks for input and output reorders.
   void PrepareAndExecuteNet(const ConvFwdPd& conv_prim_desc,
                             MklDnnData<Tinput>* src,
@@ -1185,7 +1188,7 @@ class MklConvOp : public OpKernel {
       // Otherwise, cache filter
       filter.SetUsrMem(filter_md, &filter_tensor);
       filter.CheckReorderToOpMem(conv_fwd_pd.get()->weights_desc(),
-                                 this->cpu_engine_);
+                                 this->cpu_engine_, context);
       filter_data = static_cast<Tfilter*>(filter.GetOpMem().get_data_handle());
 
       Tensor* filter_tensor_ptr = nullptr;
@@ -1251,9 +1254,9 @@ class MklConvOp : public OpKernel {
     const Tensor& cached_filter_md =
         *cached_filter_md_ptensor_.AccessTensor(context);
 
-    // Check if the memory descriptor of the cached weights is same as
-    // filter_md. If so, we can use the cached weights; otherwise
-    // return nullptr.
+// Check if the memory descriptor of the cached weights is same as
+// filter_md. If so, we can use the cached weights; otherwise
+// return nullptr.
 #ifdef ENABLE_MKLDNN_V1
     if (filter_md == *static_cast<memory::desc*>(cached_filter_md.data())) {
 #else
@@ -1652,7 +1655,7 @@ class MklQuantizedConv2DOp
           input_bias_->GET_DESC, scaled_bias_->GET_DESC, this->cpu_engine_,
           bias_attr);
       CreateAndExecuteReorder(reorder_desc, *input_bias_, *scaled_bias_,
-                              this->cpu_engine_);
+                              this->cpu_engine_, context);
 
       Tbias* bias_data =
           reinterpret_cast<Tbias*>(scaled_bias_->get_data_handle());
@@ -1908,7 +1911,8 @@ class MklQuantizedConv2DSumReluOp
     auto reorder_desc = REORDER_PD_CONSTRUCTOR_WITH_ATTR(
         SUMMAND_MD, conv_prim_desc.PRIMITIVE_DESC_DST, this->cpu_engine_,
         reorder_attr);
-    CreateAndExecuteReorder(reorder_desc, *summand_, *dst_, this->cpu_engine_);
+    CreateAndExecuteReorder(reorder_desc, *summand_, *dst_, this->cpu_engine_,
+                            context);
   }
 
   std::shared_ptr<mkldnn::memory> summand_;

diff --git a/tensorflow/core/kernels/mkl_input_conversion_op.cc b/tensorflow/core/kernels/mkl_input_conversion_op.cc
index e69fddd327a..f7866cbcea6 100644
--- a/tensorflow/core/kernels/mkl_input_conversion_op.cc
+++ b/tensorflow/core/kernels/mkl_input_conversion_op.cc
@@ -165,7 +165,7 @@ class MklInputConversionOp : public OpKernel {
                     input1_md, tensor_out, net, net_args, cpu_engine)),
             errors::Internal(
                 "MklInputConversionOp: Failed to create reorder for input0"));
-        ExecutePrimitive(net, NET_ARGS_PTR, cpu_engine);
+        ExecutePrimitive(net, NET_ARGS_PTR, cpu_engine, context);
         // Input1 will be passed through
         ForwardMklTensorInToOut(context, kInputIndex_1, kInputIndex_1);
         return;
@@ -273,7 +273,7 @@ class MklInputConversionOp : public OpKernel {
               errors::Internal("MklInputConversionOp: Failed to forward "
                                "input tensor to output"));
         } else {
-          ExecutePrimitive(net, NET_ARGS_PTR, cpu_engine);
+          ExecutePrimitive(net, NET_ARGS_PTR, cpu_engine, context);
         }
 
         // -- The tensor in MKL format passes through --
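Note: the same mechanical change repeats across these kernels. Every reorder helper gains a trailing `OpKernelContext*` argument (defaulted to `nullptr` in mkl_util.h) so the stream it builds can be bound to the calling context's threadpool; passing nothing preserves the old plain-CPU-stream behavior. For reference, the updated helper signature as declared later in this patch:

    inline void CreateAndExecuteReorder(const ReorderPd& reorder_desc,
                                        const memory& src_mem,
                                        const memory& dst_mem,
                                        const engine& engine,
                                        OpKernelContext* ctx = nullptr);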
diff --git a/tensorflow/core/kernels/mkl_reshape_op.cc b/tensorflow/core/kernels/mkl_reshape_op.cc
index c31e67b84cb..bda3fad38cf 100644
--- a/tensorflow/core/kernels/mkl_reshape_op.cc
+++ b/tensorflow/core/kernels/mkl_reshape_op.cc
@@ -172,7 +172,8 @@ class MklReshapeOp : public OpKernel {
         // shape_from != shape_to), then we just copy input tensor to
         // output tensor with target shape (we cannot forward Mkl layout
         // in such case because shape has changed.)
-        if (dnn_data_input.CheckReorderToOpMem(OUTPUT_TF_MD, output_tensor)) {
+        if (dnn_data_input.CheckReorderToOpMem(OUTPUT_TF_MD, output_tensor,
+                                               context)) {
         } else {
           OP_REQUIRES(context,
                       output_tensor->CopyFrom(input_tensor, shape_to),

diff --git a/tensorflow/core/kernels/mkl_tfconv_op.h b/tensorflow/core/kernels/mkl_tfconv_op.h
index 2e489617a40..f7aa4d2bebf 100644
--- a/tensorflow/core/kernels/mkl_tfconv_op.h
+++ b/tensorflow/core/kernels/mkl_tfconv_op.h
@@ -111,7 +111,8 @@ class MklToTfOp : public OpKernel {
       if (input.IsReorderNeeded(OUTPUT_TF_MD)) {
         // Insert reorder between MKL layout and TensorFlow layout
         OP_REQUIRES(
-            context, input.CheckReorderToOpMem(OUTPUT_TF_MD, output_tensor),
+            context,
+            input.CheckReorderToOpMem(OUTPUT_TF_MD, output_tensor, context),
             errors::Internal("MklToTfOp: Failed to create input reorder"));
       } else {
         // If not, just forward input tensor to output tensor.

diff --git a/tensorflow/core/util/BUILD b/tensorflow/core/util/BUILD
index b8c2b3b4f59..de2dce9c0c2 100644
--- a/tensorflow/core/util/BUILD
+++ b/tensorflow/core/util/BUILD
@@ -144,6 +144,7 @@ filegroup(
         "matmul_autotune.h",
         "matmul_bcast.h",
         "mirror_pad_mode.h",
+        "mkl_threadpool.h",
         "mkl_types.h",
         "mkl_util.h",
         "overflow.h",
@@ -273,6 +274,7 @@ filegroup(
 
 filegroup(
     name = "mkl_util_hdrs",
     srcs = [
+        "mkl_threadpool.h",
         "mkl_util.h",
     ],
     visibility = ["//tensorflow/core:__pkg__"],
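Note: the new header below adapts TF's Eigen threadpool to oneDNN. A paraphrase of the interface it fills in (assuming DNNL's threadpool API as declared in oneDNN's dnnl_threadpool_iface.hpp; consult that header for the authoritative form):

    // The virtuals MklDnnThreadPool must provide:
    struct threadpool_iface {
      virtual int get_num_threads() = 0;   // available parallelism
      virtual bool get_in_parallel() = 0;  // called from inside a task?
      virtual void parallel_for(int n,
                                const std::function<void(int, int)>& fn) = 0;
      virtual uint64_t get_flags() = 0;    // e.g. ASYNCHRONOUS
      virtual ~threadpool_iface() {}
    };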
diff --git a/tensorflow/core/util/mkl_threadpool.h b/tensorflow/core/util/mkl_threadpool.h
new file mode 100644
index 00000000000..8c9db0a1940
--- /dev/null
+++ b/tensorflow/core/util/mkl_threadpool.h
@@ -0,0 +1,138 @@
+
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_UTIL_MKL_THREADPOOL_H_
+#define TENSORFLOW_CORE_UTIL_MKL_THREADPOOL_H_
+#ifdef INTEL_MKL
+
+#include <list>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "mkldnn.hpp"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/platform/threadpool.h"
+#define EIGEN_USE_THREADS
+#ifdef ENABLE_MKLDNN_THREADPOOL
+using dnnl::stream_attr;
+using dnnl::threadpool_iface;
+
+namespace tensorflow {
+
+// Divide 'n' units of work equally among 'teams' threads. If 'n' is not
+// divisible by 'teams' and has a remainder 'r', the first 'r' teams have one
+// unit of work more than the rest. Returns the range of work that belongs to
+// the team 'tid'.
+// Parameters
+//   n        Total number of jobs.
+//   team     Number of workers.
+//   tid      Current thread_id.
+//   n_start  start of range operated by the thread.
+//   n_end    end of the range operated by the thread.
+template <typename T, typename U>
+inline void balance211(T n, U team, U tid, T* n_start, T* n_end) {
+  if (team <= 1 || n == 0) {
+    *n_start = 0;
+    *n_end = n;
+    return;
+  }
+  T min_per_team = n / team;
+  T remainder = n - min_per_team * team;  // i.e., n % team.
+  *n_start = tid * min_per_team + std::min(tid, remainder);
+  *n_end = *n_start + min_per_team + (tid < remainder);
+}
+
+struct MklDnnThreadPool : public dnnl::threadpool_iface {
+  MklDnnThreadPool() = default;
+
+  MklDnnThreadPool(OpKernelContext* ctx)
+      : eigen_interface_(ctx->device()
+                             ->tensorflow_cpu_worker_threads()
+                             ->workers->AsEigenThreadPool()) {}
+  virtual int get_num_threads() override {
+    return eigen_interface_->NumThreads();
+  }
+  virtual bool get_in_parallel() override {
+    return (eigen_interface_->CurrentThreadId() != -1) ? true : false;
+  }
+  virtual uint64_t get_flags() override { return ASYNCHRONOUS; }
+  virtual void parallel_for(int n,
+                            const std::function<void(int, int)>& fn) override {
+    // Should never happen (handled by DNNL)
+    if (n == 0) return;
+
+    // Should never happen (handled by DNNL)
+    if (n == 1) {
+      fn(0, 1);
+      return;
+    }
+
+    int nthr = get_num_threads();
+    int njobs = std::min(n, nthr);
+    for (int i = 0; i < njobs; i++) {
+      eigen_interface_->ScheduleWithHint(
+          [i, n, njobs, fn]() {
+            int start, end;
+            balance211(n, njobs, i, &start, &end);
+            for (int j = start; j < end; j++) fn(j, n);
+          },
+          i, i + 1);
+    }
+  }
+  ~MklDnnThreadPool() {}
+
+ private:
+  Eigen::ThreadPoolInterface* eigen_interface_ = nullptr;
+};
+
+class MklDnnThreadPoolWrapper {
+ public:
+  static MklDnnThreadPoolWrapper& GetInstance() {
+    static MklDnnThreadPoolWrapper instance_;
+    return instance_;
+  }
+  MklDnnThreadPool* CreateThreadPoolPtr(OpKernelContext* ctx) {
+    if (threadpool_map_.empty() ||
+        threadpool_map_.find(ctx->device()) == threadpool_map_.end()) {
+      auto tp_iface = new MklDnnThreadPool(ctx);
+      threadpool_map_.emplace(std::make_pair(ctx->device(), tp_iface));
+      return tp_iface;
+    } else {
+      auto entry = threadpool_map_.find(ctx->device());
+      return entry->second;
+    }
+  }
+
+ private:
+  std::unordered_map<DeviceBase*, MklDnnThreadPool*> threadpool_map_;
+  MklDnnThreadPoolWrapper() {}
+  MklDnnThreadPoolWrapper(const MklDnnThreadPoolWrapper&) = delete;
+  MklDnnThreadPoolWrapper& operator=(const MklDnnThreadPoolWrapper&) = delete;
+  ~MklDnnThreadPoolWrapper() {
+    for (auto& tp : threadpool_map_) {
+      delete tp.second;
+    }
+  }
+};
+
+}  // namespace tensorflow
+#endif  // ENABLE_MKLDNN_THREADPOOL
+#endif  // INTEL_MKL
+#endif  // TENSORFLOW_CORE_UTIL_MKL_THREADPOOL_H_
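Note: a worked example of how balance211 partitions work (illustrative only, not part of the patch). For n = 10 jobs over a team of 4 workers, min_per_team = 10 / 4 = 2 and remainder = 2, so workers 0 and 1 take 3 jobs each and workers 2 and 3 take 2 each:

    #include <cstdio>
    // Assumes mkl_threadpool.h is included with ENABLE_MKLDNN_THREADPOOL set.
    int start, end;
    for (int tid = 0; tid < 4; ++tid) {
      tensorflow::balance211(10, 4, tid, &start, &end);
      std::printf("tid %d -> [%d, %d)\n", tid, start, end);
    }
    // Prints: [0, 3)  [3, 6)  [6, 8)  [8, 10)

Also worth noting: MklDnnThreadPoolWrapper is a singleton that caches one MklDnnThreadPool per DeviceBase*, so repeated kernel invocations on the same device reuse a single Eigen-backed adapter rather than allocating a new one per op.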
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/util/env_var.h" +#include "tensorflow/core/util/mkl_threadpool.h" #include "tensorflow/core/util/mkl_types.h" #include "tensorflow/core/util/padding.h" #include "tensorflow/core/util/tensor_format.h" @@ -48,7 +49,6 @@ using mkldnn::padding_kind; using mkldnn::primitive; using mkldnn::reorder; using mkldnn::stream; - using CPUDevice = Eigen::ThreadPoolDevice; using MemoryArgsMap = std::unordered_map; using ReorderPd = mkldnn::reorder::primitive_desc; @@ -232,6 +232,27 @@ inline bool array_cmp(const T* a1, const T* a2, size_t size) { return true; } +inline mkldnn::stream* CreateStream(OpKernelContext* ctx, + const engine& engine) { +#ifdef ENABLE_MKLDNN_THREADPOOL + stream_attr tp_stream_attr(ENGINE_CPU); + if (ctx != nullptr) { + auto eigen_tp = + MklDnnThreadPoolWrapper::GetInstance().CreateThreadPoolPtr(ctx); + tp_stream_attr.set_threadpool(eigen_tp); + stream* tp_stream = + new stream(engine, stream::flags::default_flags, tp_stream_attr); + return tp_stream; + } else { + stream* tp_stream = new CPU_STREAM(engine); + return tp_stream; + } +#else + stream* tp_stream = new CPU_STREAM(engine); + return tp_stream; +#endif // ENABLE_MKLDNN_THREADPOOL +} + class MklDnnShape { private: typedef struct { @@ -679,20 +700,21 @@ class MklDnnData; // TODO merge with the execute_primitives. inline void ExecutePrimitive(const std::vector& net, const std::vector* net_args, - const engine& cpu_engine) { + const engine& cpu_engine, + OpKernelContext* context = nullptr) { #ifdef ENABLE_MKLDNN_V1 DCHECK(net_args); DCHECK_EQ(net.size(), net_args->size()); - stream cpu_stream(cpu_engine); + stream* cpu_stream = CreateStream(context, cpu_engine); for (size_t i = 0; i < net.size(); ++i) { - net.at(i).execute(cpu_stream, net_args->at(i)); + net.at(i).execute(*cpu_stream, net_args->at(i)); } - cpu_stream.wait(); + cpu_stream->wait(); + delete cpu_stream; #else stream(stream::kind::eager_nostore).submit(net).wait(); #endif // ENABLE_MKLDNN_V1 } - template inline Status ConvertMklToTF(OpKernelContext* context, const Tensor& input_mkl_tensor, @@ -731,7 +753,7 @@ inline Status ConvertMklToTF(OpKernelContext* context, return Status(error::Code::INTERNAL, "ConvertMklToTF(): Failed to create reorder for input"); } - ExecutePrimitive(net, NET_ARGS_PTR, cpu_engine); + ExecutePrimitive(net, NET_ARGS_PTR, cpu_engine, context); } else { // If not, just forward input tensor to output tensor. 
@@ -731,7 +753,7 @@ inline Status ConvertMklToTF(OpKernelContext* context,
       return Status(error::Code::INTERNAL,
                     "ConvertMklToTF(): Failed to create reorder for input");
     }
-    ExecutePrimitive(net, NET_ARGS_PTR, cpu_engine);
+    ExecutePrimitive(net, NET_ARGS_PTR, cpu_engine, context);
   } else {
     // If not, just forward input tensor to output tensor.
     bool status =
@@ -1301,8 +1323,8 @@ inline Status CreateBlockedMemDescHelper(const memory::dims& dim,
 
 inline void CreateAndExecuteReorder(const ReorderPd& reorder_desc,
                                     const memory& src_mem,
-                                    const memory& dst_mem,
-                                    const engine& engine) {
+                                    const memory& dst_mem, const engine& engine,
+                                    OpKernelContext* ctx = nullptr) {
   std::vector<primitive> net;
 #ifdef ENABLE_MKLDNN_V1
   net.push_back(mkldnn::reorder(reorder_desc));
@@ -1311,7 +1333,7 @@ inline void CreateAndExecuteReorder(const ReorderPd& reorder_desc,
 #else
   net.push_back(mkldnn::reorder(reorder_desc, src_mem, dst_mem));
 #endif  // ENABLE_MKLDNN_V1
-  ExecutePrimitive(net, NET_ARGS_PTR, engine);
+  ExecutePrimitive(net, NET_ARGS_PTR, engine, ctx);
 }
 
 class MklReorderPrimitive;
@@ -1629,22 +1651,26 @@ class MklDnnData {
 #ifdef ENABLE_MKLDNN_V1
   inline bool CheckReorderToOpMem(const memory::desc& op_md,
-                                  const engine& engine) {
+                                  const engine& engine,
+                                  OpKernelContext* context = nullptr) {
     DCHECK(user_memory_);
     if (IsReorderNeeded(op_md)) {
       // TODO(nhasabni): can we remove dynamic memory allocation?
       // primitive reuse don't allow two same reorder prim in
       // one stream, so submit it immediately
       reorder_memory_ = new memory(op_md, engine);
-      std::vector<primitive> net;
       auto* prim = FindOrCreateReorder<T>(user_memory_, reorder_memory_);
+      std::shared_ptr<stream> cpu_stream;
+      cpu_stream.reset(CreateStream(context, prim->GetEngine()));
+      std::vector<primitive> net;
       net.push_back(*(prim->GetPrimitive()));
       std::vector<MemoryArgsMap> net_args;
       net_args.push_back({{MKLDNN_ARG_FROM, *user_memory_},
                           {MKLDNN_ARG_TO, *reorder_memory_}});
-      execute_primitives(net, prim->GetStream(), net_args);
+      execute_primitives(net, cpu_stream, net_args);
 #else
-  inline bool CheckReorderToOpMem(const memory::primitive_desc& op_pd) {
+  inline bool CheckReorderToOpMem(const memory::primitive_desc& op_pd,
+                                  OpKernelContext* ctx = nullptr) {
     CHECK_NOTNULL(user_memory_);
     if (IsReorderNeeded(op_pd)) {
       reorder_memory_ = new memory(op_pd);
@@ -1708,7 +1734,8 @@ class MklDnnData {
   /// TODO(bhavanis): Need to use reorder cache here for better performance.
   inline bool CheckReorderToOpMem(const memory::desc& op_md,
                                   void* reorder_data_handle,
-                                  const engine& engine) {
+                                  const engine& engine,
+                                  OpKernelContext* context = nullptr) {
     DCHECK(reorder_data_handle);
     DCHECK(user_memory_);
     if (IsReorderNeeded(op_md)) {
@@ -1716,16 +1743,19 @@ class MklDnnData {
       // primitive reuse don't allow two same reorder prim in
       // one stream, so submit it immediately
       reorder_memory_ = new memory(op_md, engine, reorder_data_handle);
-      std::vector<primitive> net;
       auto* prim = FindOrCreateReorder<T>(user_memory_, reorder_memory_);
+      std::shared_ptr<stream> cpu_stream;
+      cpu_stream.reset(CreateStream(context, prim->GetEngine()));
+      std::vector<primitive> net;
       net.push_back(*(prim->GetPrimitive()));
       std::vector<MemoryArgsMap> net_args;
       net_args.push_back({{MKLDNN_ARG_FROM, *user_memory_},
                           {MKLDNN_ARG_TO, *reorder_memory_}});
-      execute_primitives(net, prim->GetStream(), net_args);
+      execute_primitives(net, cpu_stream, net_args);
 #else
   inline bool CheckReorderToOpMem(const memory::primitive_desc& op_pd,
-                                  void* reorder_data_handle) {
+                                  void* reorder_data_handle,
+                                  OpKernelContext* context = nullptr) {
     CHECK_NOTNULL(reorder_data_handle);
     CHECK_NOTNULL(user_memory_);
     if (IsReorderNeeded(op_pd)) {
@@ -1778,13 +1808,14 @@ class MklDnnData {
   /// remove
   /// slow path in the future
   inline bool CheckReorderToOpMem(const MEMORY_PRIMITIVE_DESC& op_pd,
-                                  Tensor* reorder_tensor) {
+                                  Tensor* reorder_tensor,
+                                  OpKernelContext* ctx = nullptr) {
     DCHECK(reorder_tensor);
 #ifdef ENABLE_MKLDNN_V1
     return CheckReorderToOpMem(op_pd, GetTensorBuffer(reorder_tensor),
-                               *cpu_engine_);
+                               *cpu_engine_, ctx);
 #else
-    return CheckReorderToOpMem(op_pd, GetTensorBuffer(reorder_tensor));
+    return CheckReorderToOpMem(op_pd, GetTensorBuffer(reorder_tensor), ctx);
 #endif  // ENABLE_MKLDNN_V1
   }
 
@@ -1843,7 +1874,7 @@ class MklDnnData {
   /// TODO: this is a faster path with reorder primitive cache compared with
   /// InsertReorderToUserMem(net, net_args), will remove
   /// slow path in the future
-  inline void InsertReorderToUserMem() {
+  inline void InsertReorderToUserMem(OpKernelContext* ctx = nullptr) {
     DCHECK(user_memory_);
     DCHECK(reorder_memory_);
     DCHECK(cpu_engine_);
@@ -1857,8 +1888,8 @@ class MklDnnData {
     net_args.push_back(
         {{MKLDNN_ARG_FROM, *reorder_memory_}, {MKLDNN_ARG_TO, *user_memory_}});
     std::shared_ptr<stream> cpu_stream;
-    cpu_stream.reset(new stream(*cpu_engine_));
-    execute_primitives(net, prim->GetStream(), net_args);
+    cpu_stream.reset(CreateStream(ctx, prim->GetEngine()));
+    execute_primitives(net, cpu_stream, net_args);
 #else
     net.push_back(FindOrCreateReorder<T>(reorder_memory_, user_memory_));
     ExecutePrimitive(net, NET_ARGS_PTR, *cpu_engine_);
@@ -1870,9 +1901,12 @@ class MklPrimitive {
 class MklPrimitive {
  public:
   virtual ~MklPrimitive() {}
-
+  MklPrimitive() {}
+  MklPrimitive(const engine& cpu_engine) { cpu_engine_ = cpu_engine; }
   // Dummy data which MKL DNN never operates on
   unsigned char* DummyData = nullptr;
+  engine cpu_engine_ = engine(ENGINE_CPU, 0);
+  const engine& GetEngine() { return cpu_engine_; }
 };
 
 const mkldnn::memory::dims NONE_DIMS = {};
@@ -2058,7 +2092,8 @@ class FactoryKeyCreator {
 
 class MklReorderPrimitive : public MklPrimitive {
  public:
-  explicit MklReorderPrimitive(const memory* from, const memory* to) {
+  explicit MklReorderPrimitive(const memory* from, const memory* to)
+      : MklPrimitive(engine(ENGINE_CPU, 0)) {
     Setup(from, to);
   }
   ~MklReorderPrimitive() {}
@@ -2081,7 +2116,6 @@ class MklReorderPrimitive : public MklPrimitive {
         : src_mem(nullptr), dst_mem(nullptr), reorder_prim(nullptr) {}
   } context_;
 
-  engine cpu_engine_ = engine(ENGINE_CPU, 0);
   std::shared_ptr<stream> stream_;
 
   void Setup(const memory* from, const memory* to) {