From b3b18d5e507eedae4d3cd85a1842a2b85df80893 Mon Sep 17 00:00:00 2001 From: Mahmoud Abuzaina Date: Tue, 4 Feb 2020 09:32:48 -0800 Subject: [PATCH 1/2] DNNL 1.x integration for Relu op --- tensorflow/core/graph/mkl_layout_pass.cc | 4 +- tensorflow/core/kernels/mkl_relu_op.cc | 406 ++++++++++++++++------- 2 files changed, 290 insertions(+), 120 deletions(-) diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc index 551193262e2..2662ae71cf5 100644 --- a/tensorflow/core/graph/mkl_layout_pass.cc +++ b/tensorflow/core/graph/mkl_layout_pass.cc @@ -654,6 +654,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { mkl_op_registry::GetMklOpName(csinfo_.quantize_v2), CopyAttrsAll, QuantizeOpRewrite, kRewriteForLayoutPropagation}); +#endif // !ENABLE_MKLDNN_V1 rinfo_.push_back({csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu), CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation}); @@ -666,10 +667,10 @@ class MklLayoutRewritePass : public GraphOptimizationPass { rinfo_.push_back( {csinfo_.relu6_grad, mkl_op_registry::GetMklOpName(csinfo_.relu6_grad), CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation}); +#ifndef ENABLE_MKLDNN_V1 rinfo_.push_back( {csinfo_.requantize, mkl_op_registry::GetMklOpName(csinfo_.requantize), CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation}); -#endif // !ENABLE_MKLDNN_V1 // Disable these two MKL operators for now due to some test failures caused // by these two ops /* @@ -682,7 +683,6 @@ rinfo_.push_back({csinfo_.tanh_grad, CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation}); */ -#ifndef ENABLE_MKLDNN_V1 rinfo_.push_back( {csinfo_.reshape, mkl_op_registry::GetMklOpName(csinfo_.reshape), CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation}); diff --git a/tensorflow/core/kernels/mkl_relu_op.cc b/tensorflow/core/kernels/mkl_relu_op.cc index 48ee5e0de3f..a9ab6a84669 100644 --- a/tensorflow/core/kernels/mkl_relu_op.cc +++ b/tensorflow/core/kernels/mkl_relu_op.cc @@ -16,31 +16,33 @@ limitations under the License. // See docs in ../ops/nn_ops.cc. #ifdef INTEL_MKL +#include + #include "mkldnn.hpp" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/util/mkl_types.h" #include "tensorflow/core/util/mkl_util.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" using mkldnn::algorithm; -using mkldnn::eltwise_bounded_relu; -using mkldnn::eltwise_elu; using mkldnn::eltwise_forward; -using mkldnn::eltwise_relu; -using mkldnn::eltwise_tanh; using mkldnn::memory; using mkldnn::prop_kind; using mkldnn::stream; +using EltwiseFwdPd = mkldnn::eltwise_forward::primitive_desc; +using EltwiseBwdPd = mkldnn::eltwise_backward::primitive_desc; + namespace tensorflow { template class MklEltwiseFwdParams { public: - memory::dims src_dims; // check if this is needed + memory::dims src_dims; memory::desc src_md; algorithm alg_kind; float alpha; @@ -59,11 +61,12 @@ template class MklEltwiseFwdPrimitive : public MklPrimitive { public: explicit MklEltwiseFwdPrimitive(const MklEltwiseFwdParams& fwdParams) - : cpu_engine_(engine::cpu, 0) { - // store expected format + : cpu_engine_(ENGINE_CPU, 0) { +#ifndef ENABLE_MKLDNN_V1 context_.src_fmt = static_cast(fwdParams.src_md.data.format); - context_.fwd_stream.reset(new stream(stream::kind::eager)); +#endif + context_.fwd_stream.reset(new CPU_STREAM(cpu_engine_)); // create eltwise primitive if (context_.eltwise_fwd == nullptr) { @@ -80,24 +83,38 @@ class MklEltwiseFwdPrimitive : public MklPrimitive { context_.src_mem->set_data_handle( static_cast(const_cast(src_data))); context_.dst_mem->set_data_handle(static_cast(dst_data)); - context_.fwd_stream->submit(context_.fwd_primitives); - // after execution, set data handle back +#ifdef ENABLE_MKLDNN_V1 + DCHECK_EQ(context_.fwd_primitives.size(), + context_.fwd_primitives_args.size()); + for (size_t i = 0; i < context_.fwd_primitives.size(); ++i) { + context_.fwd_primitives.at(i).execute(*context_.fwd_stream, + context_.fwd_primitives_args.at(i)); + } +#else + context_.fwd_stream->submit(context_.fwd_primitives); +#endif + + // After execution, set data handle back. context_.src_mem->set_data_handle(DummyData); context_.dst_mem->set_data_handle(DummyData); } - std::shared_ptr GetEltwiseFwdPd() { - return context_.fwd_pd; - } + std::shared_ptr GetEltwiseFwdPd() { return context_.fwd_pd; } +#ifndef ENABLE_MKLDNN_V1 + // In MKL-DNN v1.x, memory format tags only provide a partial description + // of the memory layout. Hence, these functions are disabled for v1.x. memory::format GetSrcMemoryFormat() { return context_.src_fmt; } +#endif private: // Primitive reuse context for eltwise Fwd ops: Relu, Elu, Tanh struct EltwiseFwdContext { - // expected memory format for this primitive instance +#ifndef ENABLE_MKLDNN_V1 + // Expected memory format for this primitive instance mkldnn::memory::format src_fmt; +#endif // MKLDNN memory std::shared_ptr src_mem; @@ -105,14 +122,14 @@ class MklEltwiseFwdPrimitive : public MklPrimitive { // desc & prmitive desc std::shared_ptr fwd_desc; - std::shared_ptr fwd_pd; + std::shared_ptr fwd_pd; // memory desc std::shared_ptr src_md; std::shared_ptr dst_md; // memory primitive desc - std::shared_ptr src_mpd; + std::shared_ptr src_mpd; // Eltwise primitive std::shared_ptr eltwise_fwd; @@ -120,8 +137,15 @@ class MklEltwiseFwdPrimitive : public MklPrimitive { std::shared_ptr fwd_stream; std::vector fwd_primitives; +#ifdef ENABLE_MKLDNN_V1 + std::vector> fwd_primitives_args; +#endif + EltwiseFwdContext() - : src_fmt(memory::format::any), + : +#ifndef ENABLE_MKLDNN_V1 + src_fmt(memory::format::any), +#endif src_mem(nullptr), dst_mem(nullptr), fwd_desc(nullptr), @@ -130,31 +154,47 @@ class MklEltwiseFwdPrimitive : public MklPrimitive { dst_md(nullptr), src_mpd(nullptr), eltwise_fwd(nullptr), - fwd_stream(nullptr) {} + fwd_stream(nullptr) { + } }; // Eltwise forward primitive setup void Setup(const MklEltwiseFwdParams& fwdParams) { // create memory descriptors for eltwise data with specified format context_.src_md.reset(new memory::desc(fwdParams.src_md.data)); - context_.src_mpd.reset( - new memory::primitive_desc(*context_.src_md, cpu_engine_)); - // create a eltwise - context_.fwd_desc.reset(new mkldnn::eltwise_forward::desc( + context_.src_mpd.reset( +#ifdef ENABLE_MKLDNN_V1 + new MEMORY_PRIMITIVE_DESC(*context_.src_md)); +#else + new MEMORY_PRIMITIVE_DESC(*context_.src_md, cpu_engine_)); +#endif + + // Create an eltwise forward descriptor and primitive descriptor + context_.fwd_desc.reset(new eltwise_forward::desc( prop_kind::forward, fwdParams.alg_kind, *context_.src_md, fwdParams.alpha, fwdParams.beta)); - context_.fwd_pd.reset(new mkldnn::eltwise_forward::primitive_desc( - *context_.fwd_desc, cpu_engine_)); + context_.fwd_pd.reset(new EltwiseFwdPd(*context_.fwd_desc, cpu_engine_)); + auto fwd_pd = context_.fwd_pd.get(); - // create memory primitive based on dummy data +#ifdef ENABLE_MKLDNN_V1 + // Create memory primitive based on dummy data + context_.src_mem.reset(new MEMORY_CONSTRUCTOR(fwd_pd->PRIMITIVE_DESC_SRC, + cpu_engine_, DummyData)); + context_.dst_mem.reset(new MEMORY_CONSTRUCTOR(fwd_pd->PRIMITIVE_DESC_DST, + cpu_engine_, DummyData)); + // Create eltwise primitive and add it to net + context_.eltwise_fwd.reset(new eltwise_forward(*context_.fwd_pd)); + context_.fwd_primitives_args.push_back({{MKLDNN_ARG_SRC, *context_.src_mem}, + { MKLDNN_ARG_DST, + *context_.dst_mem }}); +#else context_.src_mem.reset(new memory(*context_.src_mpd, DummyData)); context_.dst_mem.reset( new memory(context_.fwd_pd.get()->dst_primitive_desc(), DummyData)); - - // create eltwise primitive and add it to net - context_.eltwise_fwd.reset(new mkldnn::eltwise_forward( + context_.eltwise_fwd.reset(new eltwise_forward( *context_.fwd_pd, *context_.src_mem, *context_.dst_mem)); +#endif context_.fwd_primitives.push_back(*context_.eltwise_fwd); } @@ -170,6 +210,17 @@ class MklEltwiseFwdPrimitiveFactory : public MklPrimitiveFactory { const MklEltwiseFwdParams& fwdParams) { MklEltwiseFwdPrimitive* eltwise_forward = nullptr; +#ifdef ENABLE_MKLDNN_V1 + // Get a eltwise fwd primitive from the cached pool + eltwise_forward = static_cast*>( + MklEltwiseFwdPrimitiveFactory::GetInstance().GetEltwiseFwd( + fwdParams)); + if (eltwise_forward == nullptr) { + eltwise_forward = new MklEltwiseFwdPrimitive(fwdParams); + MklEltwiseFwdPrimitiveFactory::GetInstance().SetEltwiseFwd( + fwdParams, eltwise_forward); + } +#else auto src_fmt = static_cast(fwdParams.src_md.data.format); @@ -182,6 +233,8 @@ class MklEltwiseFwdPrimitiveFactory : public MklPrimitiveFactory { MklEltwiseFwdPrimitiveFactory::GetInstance().SetEltwiseFwd( fwdParams, src_fmt, eltwise_forward); } +#endif // ENABLE_MKLDNN_V1 + return eltwise_forward; } @@ -194,6 +247,29 @@ class MklEltwiseFwdPrimitiveFactory : public MklPrimitiveFactory { MklEltwiseFwdPrimitiveFactory() {} ~MklEltwiseFwdPrimitiveFactory() {} +#ifdef ENABLE_MKLDNN_V1 + static string CreateKey(const MklEltwiseFwdParams& fwdParams) { + string prefix = "eltwise_fwd"; + FactoryKeyCreator key_creator; + key_creator.AddAsKey(prefix); + key_creator.AddAsKey(fwdParams.src_dims); + key_creator.AddAsKey(static_cast(fwdParams.alg_kind)); + key_creator.AddAsKey(static_cast(fwdParams.alpha)); + key_creator.AddAsKey(static_cast(fwdParams.beta)); + return key_creator.GetKey(); + } + + MklPrimitive* GetEltwiseFwd(const MklEltwiseFwdParams& fwdParams) { + string key = CreateKey(fwdParams); + return this->GetOp(key); + } + + void SetEltwiseFwd(const MklEltwiseFwdParams& fwdParams, + MklPrimitive* op) { + string key = CreateKey(fwdParams); + this->SetOp(key, op); + } +#else static string CreateKey(const MklEltwiseFwdParams& fwdParams, memory::format src_fmt) { string prefix = "eltwise_fwd"; @@ -218,6 +294,7 @@ class MklEltwiseFwdPrimitiveFactory : public MklPrimitiveFactory { string key = CreateKey(fwdParams, src_fmt); this->SetOp(key, op); } +#endif // ENABLE_MKLDNN_V1 }; template @@ -243,12 +320,14 @@ template class MklEltwiseBwdPrimitive : public MklPrimitive { public: explicit MklEltwiseBwdPrimitive(const MklEltwiseBwdParams& bwdParams) - : cpu_engine_(engine::cpu, 0) { + : cpu_engine_(ENGINE_CPU, 0) { +#ifndef ENABLE_MKLDNN_V1 context_.src_fmt = static_cast(bwdParams.common_md.data.format); context_.diff_dst_fmt = static_cast(bwdParams.common_md.data.format); - context_.bwd_stream.reset(new stream(stream::kind::eager)); +#endif + context_.bwd_stream.reset(new stream(CPU_STREAM(cpu_engine_))); // create eltwise primitive if (context_.eltwise_bwd == nullptr) { Setup(bwdParams); @@ -267,7 +346,17 @@ class MklEltwiseBwdPrimitive : public MklPrimitive { context_.diff_dst_mem->set_data_handle( static_cast(const_cast(diff_dst_data))); context_.diff_src_mem->set_data_handle(static_cast(diff_src_data)); + +#ifdef ENABLE_MKLDNN_V1 + DCHECK_EQ(context_.bwd_primitives.size(), + context_.bwd_primitives_args.size()); + for (size_t i = 0; i < context_.bwd_primitives.size(); ++i) { + context_.bwd_primitives.at(i).execute(*context_.bwd_stream, + context_.bwd_primitives_args.at(i)); + } +#else context_.bwd_stream->submit(context_.bwd_primitives); +#endif // ENABLE_MKLDNN_V1 // after execution, set data handle back context_.src_mem->set_data_handle(DummyData); @@ -275,52 +364,61 @@ class MklEltwiseBwdPrimitive : public MklPrimitive { context_.diff_src_mem->set_data_handle(DummyData); } - std::shared_ptr GetEltwiseBwdPd() { - return context_.bwd_pd; - } + std::shared_ptr GetEltwiseBwdPd() { return context_.bwd_pd; } +#ifndef ENABLE_MKLDNN_V1 memory::format GetSrcMemoryFormat() { return context_.src_fmt; } - memory::format GetDiffDstMemoryFormat() { return context_.diff_dst_fmt; } +#endif // !ENABLE_MKLDNN_V1 private: // Primitive reuse context for eltwise Bwd ops: Relu, Elu, Tanh struct EltwiseBwdContext { - // expected memory format for this primitive instance +#ifndef ENABLE_MKLDNN_V1 memory::format src_fmt; memory::format diff_dst_fmt; +#endif // MKLDNN memory std::shared_ptr src_mem; std::shared_ptr diff_dst_mem; std::shared_ptr diff_src_mem; - // desc & prmitive desc + // Backward Eltwise descriptor. std::shared_ptr bwd_desc; - // memory desc + // Memory descriptors. std::shared_ptr src_md; std::shared_ptr diff_dst_md; std::shared_ptr common_md; - // memory primitive desc - std::shared_ptr src_mpd; - std::shared_ptr diff_dst_mpd; + // Memory primitive descriptor. + // TODO(gzmkl): for MKL-DNN 1.0, src_mpd is same as src_md + // So it should be removed once MKL-DNN 0.x is cleaned. + std::shared_ptr src_mpd; + std::shared_ptr diff_dst_mpd; - // fwd primitive desc + // Forward and backward descriptors and primitive descriptors. std::shared_ptr fwd_desc; - std::shared_ptr fwd_pd; - std::shared_ptr bwd_pd; + std::shared_ptr fwd_pd; + std::shared_ptr bwd_pd; - // Eltwise primitive + // Eltwise primitive. std::shared_ptr eltwise_bwd; std::shared_ptr bwd_stream; std::vector bwd_primitives; +#ifdef ENABLE_MKLDNN_V1 + std::vector bwd_primitives_args; +#endif // ENABLE_MKLDNN_V1 + EltwiseBwdContext() - : src_fmt(memory::format::any), + : +#ifndef ENABLE_MKLDNN_V1 + src_fmt(memory::format::any), diff_dst_fmt(memory::format::any), +#endif // !ENABLE_MKLDNN_V1 src_mem(nullptr), diff_dst_mem(nullptr), diff_src_mem(nullptr), @@ -333,42 +431,64 @@ class MklEltwiseBwdPrimitive : public MklPrimitive { fwd_pd(nullptr), bwd_pd(nullptr), eltwise_bwd(nullptr), - bwd_stream(nullptr) {} + bwd_stream(nullptr) { + } }; // Eltwise backward primitive setup void Setup(const MklEltwiseBwdParams& bwdParams) { - // create memory descriptors for eltwise data w/ no specified format + // Create memory descriptors for eltwise data w/ no specified format context_.src_md.reset(new memory::desc(bwdParams.common_md.data)); context_.diff_dst_md.reset(new memory::desc(bwdParams.common_md.data)); - context_.src_mpd.reset( - new memory::primitive_desc(*context_.src_md, cpu_engine_)); +#ifdef ENABLE_MKLDNN_V1 + context_.src_mpd.reset(new MEMORY_PRIMITIVE_DESC(*context_.src_md)); context_.diff_dst_mpd.reset( - new memory::primitive_desc(*context_.diff_dst_md, cpu_engine_)); + new MEMORY_PRIMITIVE_DESC(*context_.diff_dst_md)); +#else + context_.src_mpd.reset( + new MEMORY_PRIMITIVE_DESC(*context_.src_md, cpu_engine_)); + context_.diff_dst_mpd.reset( + new MEMORY_PRIMITIVE_DESC(*context_.diff_dst_md, cpu_engine_)); +#endif - // create forward eltwise primitive + // Create forward eltwise primitive. context_.fwd_desc.reset(new mkldnn::eltwise_forward::desc( prop_kind::forward_training, bwdParams.alg_kind, *context_.src_md, bwdParams.alpha, bwdParams.beta)); - context_.fwd_pd.reset(new mkldnn::eltwise_forward::primitive_desc( - *context_.fwd_desc, cpu_engine_)); + context_.fwd_pd.reset(new EltwiseFwdPd(*context_.fwd_desc, cpu_engine_)); context_.bwd_desc.reset(new mkldnn::eltwise_backward::desc( bwdParams.alg_kind, *context_.diff_dst_md, *context_.src_md, bwdParams.alpha, bwdParams.beta)); - context_.bwd_pd.reset(new mkldnn::eltwise_backward::primitive_desc( - *context_.bwd_desc, cpu_engine_, *context_.fwd_pd)); + context_.bwd_pd.reset( + new EltwiseBwdPd(*context_.bwd_desc, cpu_engine_, *context_.fwd_pd)); - // create memory primitive based on dummy data + auto bwd_pd = context_.bwd_pd.get(); + +#ifdef ENABLE_MKLDNN_V1 + // Create memory primitive based on dummy data. + context_.src_mem.reset(new MEMORY_CONSTRUCTOR(bwd_pd->PRIMITIVE_DESC_SRC, + cpu_engine_, DummyData)); + context_.diff_dst_mem.reset(new MEMORY_CONSTRUCTOR( + bwd_pd->PRIMITIVE_DESC_DIFF_DST, cpu_engine_, DummyData)); + context_.diff_src_mem.reset(new MEMORY_CONSTRUCTOR( + bwd_pd->PRIMITIVE_DESC_DIFF_SRC, cpu_engine_, DummyData)); + // Create eltwise primitive and add it to net. + context_.eltwise_bwd.reset(new mkldnn::eltwise_backward(*context_.bwd_pd)); + context_.bwd_primitives_args.push_back( + {{MKLDNN_ARG_SRC, *context_.src_mem}, + {MKLDNN_ARG_DIFF_DST, *context_.diff_dst_mem}, + { MKLDNN_ARG_DIFF_SRC, + *context_.diff_src_mem }}); +#else context_.src_mem.reset(new memory(*context_.src_mpd, DummyData)); context_.diff_dst_mem.reset(new memory(*context_.diff_dst_mpd, DummyData)); context_.diff_src_mem.reset(new memory( context_.bwd_pd.get()->diff_src_primitive_desc(), DummyData)); - - // create eltwise primitive and add it to net context_.eltwise_bwd.reset(new mkldnn::eltwise_backward( *context_.bwd_pd, *context_.src_mem, *context_.diff_dst_mem, *context_.diff_src_mem)); +#endif // ENABLE_MKLDNN_V1 context_.bwd_primitives.push_back(*context_.eltwise_bwd); } @@ -388,6 +508,18 @@ class MklEltwiseBwdPrimitiveFactory : public MklPrimitiveFactory { const MklEltwiseBwdParams& bwdParams) { MklEltwiseBwdPrimitive* eltwise_backward = nullptr; +#ifdef ENABLE_MKLDNN_V1 + // try to find a suitable one in pool + eltwise_backward = static_cast*>( + MklEltwiseBwdPrimitiveFactory::GetInstance().GetEltwiseBwd( + bwdParams)); + + if (eltwise_backward == nullptr) { + eltwise_backward = new MklEltwiseBwdPrimitive(bwdParams); + MklEltwiseBwdPrimitiveFactory::GetInstance().SetEltwiseBwd( + bwdParams, eltwise_backward); + } +#else auto src_fmt = static_cast(bwdParams.common_md.data.format); auto diff_dst_fmt = @@ -403,6 +535,8 @@ class MklEltwiseBwdPrimitiveFactory : public MklPrimitiveFactory { MklEltwiseBwdPrimitiveFactory::GetInstance().SetEltwiseBwd( bwdParams, src_fmt, diff_dst_fmt, eltwise_backward); } +#endif // ENABLE_MKLDNN_V1 + return eltwise_backward; } @@ -412,6 +546,29 @@ class MklEltwiseBwdPrimitiveFactory : public MklPrimitiveFactory { } private: +#ifdef ENABLE_MKLDNN_V1 + static string CreateKey(const MklEltwiseBwdParams& bwdParams) { + string prefix = "eltwise_bwd"; + FactoryKeyCreator key_creator; + key_creator.AddAsKey(prefix); + key_creator.AddAsKey(bwdParams.src_dims); + key_creator.AddAsKey(static_cast(bwdParams.alg_kind)); + key_creator.AddAsKey(static_cast(bwdParams.alpha)); + key_creator.AddAsKey(static_cast(bwdParams.beta)); + return key_creator.GetKey(); + } + + MklPrimitive* GetEltwiseBwd(const MklEltwiseBwdParams& bwdParams) { + string key = CreateKey(bwdParams); + return this->GetOp(key); + } + + void SetEltwiseBwd(const MklEltwiseBwdParams& bwdParams, + MklPrimitive* op) { + string key = CreateKey(bwdParams); + this->SetOp(key, op); + } +#else static string CreateKey(const MklEltwiseBwdParams& bwdParams, const memory::format& src_fmt, const memory::format& diff_dst_fmt) { @@ -440,6 +597,7 @@ class MklEltwiseBwdPrimitiveFactory : public MklPrimitiveFactory { string key = CreateKey(bwdParams, src_fmt, diff_dst_fmt); this->SetOp(key, op); } +#endif // ENABLE_MKLDNN_V1 }; typedef Eigen::ThreadPoolDevice CPUDevice; @@ -481,7 +639,7 @@ class MklReluOpBase : public OpKernel { // Set DNN primitive - src MklDnnData src(&cpu_engine); memory::dims src_dims; - memory::desc src_md({}, memory::data_undef, memory::format_undef); + memory::desc src_md({}, MEMORY_DATA_TYPE_UNDEF, MEMORY_FORMAT_UNDEF); if (dnn_shape_src.IsMklTensor()) { src_md = dnn_shape_src.GetMklLayout(); src_dims = dnn_shape_src.GetSizesAsMklDnnDims(); @@ -492,31 +650,29 @@ class MklReluOpBase : public OpKernel { src_md = MklDnnData::CreateBlockedMemDesc(src_dims, src_strides); } - // get a eltwise fwd from primitive pool + // Try to get an eltwise forward primitive from caching pool MklEltwiseFwdParams fwdParams(src_dims, src_md, alg_kind, alpha_, beta_); + MklEltwiseFwdPrimitive* eltwise_fwd = MklEltwiseFwdPrimitiveFactory::Get(fwdParams); - // prepare for execuation + auto eltwise_fwd_pd = eltwise_fwd->GetEltwiseFwdPd(); + + // Check if src needs to be reordered const T* src_data = src_tensor.flat().data(); - // check wehther src need to reorder - if (src_md.data.format != eltwise_fwd->GetSrcMemoryFormat()) { + if (IS_SRC_REORDER_NEEDED(src_md, eltwise_fwd_pd, eltwise_fwd)) { src.SetUsrMem(src_md, &src_tensor); - auto src_target_pd = memory::primitive_desc( - {{src_dims}, MklDnnType(), eltwise_fwd->GetSrcMemoryFormat()}, - cpu_engine); - src.CheckReorderToOpMem(src_target_pd); + src.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA( + eltwise_fwd_pd->PRIMITIVE_DESC_SRC, cpu_engine)); src_data = const_cast( reinterpret_cast(src.GetOpMem().get_data_handle())); } - // allocate dst tensor, always set it as MKL-DNN layout - std::shared_ptr eltwise_fwd_pd = - eltwise_fwd->GetEltwiseFwdPd(); + // Allocate dst tensor, always set it as MKL-DNN layout if (dnn_shape_src.IsMklTensor()) { dnn_shape_dst.SetMklTensor(true); - auto dst_pd = eltwise_fwd_pd->dst_primitive_desc(); + auto dst_pd = eltwise_fwd_pd->PRIMITIVE_DESC_DST; dnn_shape_dst.SetMklLayout(&dst_pd); dnn_shape_dst.SetElemType(MklDnnType()); dnn_shape_dst.SetTfLayout(dnn_shape_src.GetDimension(), @@ -539,9 +695,9 @@ class MklReluOpBase : public OpKernel { // execute eltwise eltwise_fwd->Execute(src_data, dst_data); } catch (mkldnn::error& e) { - string error_msg = "Status: " + std::to_string(e.status) + - ", message: " + string(e.message) + ", in file " + - string(__FILE__) + ":" + std::to_string(__LINE__); + string error_msg = "Status: " + std::to_string(e.status) + ", message: " + + string(e.message) + ", in file " + string(__FILE__) + + ":" + std::to_string(__LINE__); OP_REQUIRES_OK( context, errors::Aborted("Operation received an exception:", error_msg)); @@ -549,8 +705,8 @@ class MklReluOpBase : public OpKernel { } private: - engine cpu_engine = engine(engine::cpu, 0); - std::shared_ptr relu_fwd_pd; + engine cpu_engine = engine(ENGINE_CPU, 0); + std::shared_ptr relu_fwd_pd; protected: float alpha_; @@ -604,8 +760,8 @@ class MklReluGradOpBase : public OpKernel { // get a eltwise bwd from primitive pool memory::dims src_dims = {}; - memory::desc src_md({}, memory::data_undef, memory::format_undef); - memory::desc diff_dst_md({}, memory::data_undef, memory::format_undef); + memory::desc src_md({}, MEMORY_DATA_TYPE_UNDEF, MEMORY_FORMAT_UNDEF); + memory::desc diff_dst_md({}, MEMORY_DATA_TYPE_UNDEF, MEMORY_FORMAT_UNDEF); if (!dnn_shape_src.IsMklTensor() && !dnn_shape_diff_dst.IsMklTensor()) { src_dims = TFShapeToMklDnnDims(src_tensor.shape()); auto src_strides = CalculateTFStrides(src_dims); @@ -616,18 +772,18 @@ class MklReluGradOpBase : public OpKernel { src_md = dnn_shape_src.GetMklLayout(); src_dims = dnn_shape_src.GetSizesAsMklDnnDims(); - memory::format src_mkl_data_format = dnn_shape_src.GetTfDataFormat(); + MKL_TENSOR_FORMAT src_mkl_data_format = dnn_shape_src.GetTfDataFormat(); auto src_tf_data_format = MklDnnDataFormatToTFDataFormat(src_mkl_data_format); auto diff_dst_dims = TFShapeToMklDnnDimsInNCHW(diff_dst_tensor.shape(), src_tf_data_format); - diff_dst_md = - memory::desc(diff_dst_dims, MklDnnType(), src_mkl_data_format); + diff_dst_md = memory::desc(diff_dst_dims, MklDnnType(), + GET_TENSOR_FORMAT(src_mkl_data_format)); } else if (!dnn_shape_src.IsMklTensor() && dnn_shape_diff_dst.IsMklTensor()) { diff_dst_md = dnn_shape_diff_dst.GetMklLayout(); - memory::format diff_dst_mkl_data_format = + MKL_TENSOR_FORMAT diff_dst_mkl_data_format = dnn_shape_diff_dst.GetTfDataFormat(); auto diff_dst_tf_data_format = MklDnnDataFormatToTFDataFormat(diff_dst_mkl_data_format); @@ -637,8 +793,8 @@ class MklReluGradOpBase : public OpKernel { diff_dst_tf_data_format) : TFShapeToMklDnnDimsInNCDHW(src_tensor.shape(), diff_dst_tf_data_format); - src_md = - memory::desc(src_dims, MklDnnType(), diff_dst_mkl_data_format); + src_md = memory::desc(src_dims, MklDnnType(), + GET_TENSOR_FORMAT(diff_dst_mkl_data_format)); } else { src_md = dnn_shape_src.GetMklLayout(); diff_dst_md = dnn_shape_diff_dst.GetMklLayout(); @@ -649,7 +805,7 @@ class MklReluGradOpBase : public OpKernel { // format. So we set common memory descriptor in MKL format, if any of the // inputs are in MKL format. Let's get memory descriptor that we will use // for both the inputs. - memory::desc common_md({}, memory::data_undef, memory::format_undef); + memory::desc common_md({}, MEMORY_DATA_TYPE_UNDEF, MEMORY_FORMAT_UNDEF); if (dnn_shape_src.IsMklTensor() || dnn_shape_diff_dst.IsMklTensor()) { common_md = dnn_shape_src.IsMklTensor() ? src_md : diff_dst_md; } else { @@ -662,30 +818,32 @@ class MklReluGradOpBase : public OpKernel { beta_); MklEltwiseBwdPrimitive* eltwise_bwd = MklEltwiseBwdPrimitiveFactory::Get(bwdParams); + auto eltwise_bwd_pd = eltwise_bwd->GetEltwiseBwdPd(); // check whether need reorder for src / diff_dst const T* src_data = src_tensor.flat().data(); - if (src_md.data.format != eltwise_bwd->GetSrcMemoryFormat()) { + if (IS_SRC_REORDER_NEEDED(src_md, eltwise_bwd_pd, eltwise_bwd)) { src.SetUsrMem(src_md, &src_tensor); - src.CheckReorderToOpMem( - eltwise_bwd_pd.get()->diff_src_primitive_desc()); + src.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA( + eltwise_bwd_pd.get()->PRIMITIVE_DESC_DIFF_SRC, cpu_engine)); src_data = const_cast( reinterpret_cast(src.GetOpMem().get_data_handle())); } const T* diff_dst_data = diff_dst_tensor.flat().data(); - if (diff_dst_md.data.format != eltwise_bwd->GetDiffDstMemoryFormat()) { + if (IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, eltwise_bwd_pd, + eltwise_bwd)) { diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor); - diff_dst.CheckReorderToOpMem( - eltwise_bwd_pd.get()->diff_src_primitive_desc()); + diff_dst.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA( + eltwise_bwd_pd.get()->PRIMITIVE_DESC_DIFF_SRC, cpu_engine)); diff_dst_data = const_cast( reinterpret_cast(diff_dst.GetOpMem().get_data_handle())); } // allocate diff_src tensor if (dnn_shape_src.IsMklTensor() || dnn_shape_diff_dst.IsMklTensor()) { - auto diff_src_pd = eltwise_bwd_pd->diff_src_primitive_desc(); + auto diff_src_pd = eltwise_bwd_pd->PRIMITIVE_DESC_DIFF_SRC; dnn_shape_diff_src.SetMklTensor(true); dnn_shape_diff_src.SetMklLayout(&diff_src_pd); dnn_shape_diff_src.SetElemType(MklDnnType()); @@ -716,9 +874,9 @@ class MklReluGradOpBase : public OpKernel { // execute eltwise bwd eltwise_bwd->Execute(src_data, diff_dst_data, diff_src_data); } catch (mkldnn::error& e) { - string error_msg = "Status: " + std::to_string(e.status) + - ", message: " + string(e.message) + ", in file " + - string(__FILE__) + ":" + std::to_string(__LINE__); + string error_msg = "Status: " + std::to_string(e.status) + ", message: " + + string(e.message) + ", in file " + string(__FILE__) + + ":" + std::to_string(__LINE__); OP_REQUIRES_OK( context, errors::Aborted("Operation received an exception:", error_msg)); @@ -726,8 +884,8 @@ class MklReluGradOpBase : public OpKernel { } private: - engine cpu_engine = engine(engine::cpu, 0); - std::shared_ptr relu_fwd_pd; + engine cpu_engine = engine(ENGINE_CPU, 0); + std::shared_ptr relu_fwd_pd; protected: float alpha_; @@ -735,12 +893,13 @@ class MklReluGradOpBase : public OpKernel { }; template -class MklReluOp : public MklReluOpBase { +class MklReluOp : public MklReluOpBase { public: ~MklReluOp() {} explicit MklReluOp(OpKernelConstruction* context) - : MklReluOpBase(context, 0.0f, 0.0f) {} + : MklReluOpBase(context, 0.0f, 0.0f) { + } virtual void Compute_Scalar(OpKernelContext* context) { const size_t src_index = 0; // index of src input tensor @@ -764,12 +923,14 @@ class MklReluOp : public MklReluOpBase { }; template -class MklReluGradOp : public MklReluGradOpBase { +class MklReluGradOp + : public MklReluGradOpBase { public: ~MklReluGradOp() {} explicit MklReluGradOp(OpKernelConstruction* context) - : MklReluGradOpBase(context, 0.0f, 0.0f) {} + : MklReluGradOpBase(context, 0.0f, + 0.0f) {} virtual void Compute_Scalar(OpKernelContext* context) { const size_t diff_dst_index = 0; // index of diff_dst input tensor @@ -799,12 +960,12 @@ class MklReluGradOp : public MklReluGradOpBase { }; template -class MklEluOp : public MklReluOpBase { +class MklEluOp : public MklReluOpBase { public: ~MklEluOp() {} explicit MklEluOp(OpKernelConstruction* context) - : MklReluOpBase(context, 0.0f, 0.0f) {} + : MklReluOpBase(context, 0.0f, 0.0f) {} virtual void Compute_Scalar(OpKernelContext* context) { const size_t src_index = 0; // index of src input tensor @@ -832,12 +993,14 @@ class MklEluOp : public MklReluOpBase { }; template -class MklEluGradOp : public MklReluGradOpBase { +class MklEluGradOp + : public MklReluGradOpBase { public: ~MklEluGradOp() {} explicit MklEluGradOp(OpKernelConstruction* context) - : MklReluGradOpBase(context, 0.0f, 0.0f) {} + : MklReluGradOpBase(context, 0.0f, + 0.0f) {} virtual void Compute_Scalar(OpKernelContext* context) { const size_t diff_dst_index = 0; // index of diff_dst input tensor @@ -872,12 +1035,13 @@ class MklEluGradOp : public MklReluGradOpBase { }; template -class MklTanhOp : public MklReluOpBase { +class MklTanhOp : public MklReluOpBase { public: ~MklTanhOp() {} explicit MklTanhOp(OpKernelConstruction* context) - : MklReluOpBase(context, 0.0f, 0.0f) {} + : MklReluOpBase(context, 0.0f, 0.0f) { + } virtual void Compute_Scalar(OpKernelContext* context) { const size_t src_index = 0; // index of src input tensor @@ -904,12 +1068,14 @@ class MklTanhOp : public MklReluOpBase { }; template -class MklTanhGradOp : public MklReluGradOpBase { +class MklTanhGradOp + : public MklReluGradOpBase { public: ~MklTanhGradOp() {} explicit MklTanhGradOp(OpKernelConstruction* context) - : MklReluGradOpBase(context, 0.0f, 0.0f) {} + : MklReluGradOpBase(context, 0.0f, + 0.0f) {} virtual void Compute_Scalar(OpKernelContext* context) { const size_t diff_dst_index = 0; // index of diff_dst input tensor @@ -943,12 +1109,13 @@ class MklTanhGradOp : public MklReluGradOpBase { #define RELU6_UPPER_BOUND 6.0f template -class MklRelu6Op : public MklReluOpBase { +class MklRelu6Op + : public MklReluOpBase { public: ~MklRelu6Op() {} explicit MklRelu6Op(OpKernelConstruction* context) - : MklReluOpBase( + : MklReluOpBase( context, RELU6_UPPER_BOUND, 0.0f) {} virtual void Compute_Scalar(OpKernelContext* context) { @@ -973,12 +1140,12 @@ class MklRelu6Op : public MklReluOpBase { template class MklRelu6GradOp - : public MklReluGradOpBase { + : public MklReluGradOpBase { public: ~MklRelu6GradOp() {} explicit MklRelu6GradOp(OpKernelConstruction* context) - : MklReluGradOpBase( + : MklReluGradOpBase( context, RELU6_UPPER_BOUND, 0.0f) {} virtual void Compute_Scalar(OpKernelContext* context) { @@ -1007,12 +1174,13 @@ class MklRelu6GradOp }; template -class MklLeakyReluOp : public MklReluOpBase { +class MklLeakyReluOp + : public MklReluOpBase { public: ~MklLeakyReluOp() {} explicit MklLeakyReluOp(OpKernelConstruction* context) - : MklReluOpBase(context, 0.0f, 0.0f) { + : MklReluOpBase(context, 0.0f, 0.0f) { float alpha; OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha)); OP_REQUIRES( @@ -1044,12 +1212,14 @@ class MklLeakyReluOp : public MklReluOpBase { }; template -class MklLeakyReluGradOp : public MklReluGradOpBase { +class MklLeakyReluGradOp + : public MklReluGradOpBase { public: ~MklLeakyReluGradOp() {} explicit MklLeakyReluGradOp(OpKernelConstruction* context) - : MklReluGradOpBase(context, 0.0f, 0.0f) { + : MklReluGradOpBase(context, 0.0f, + 0.0f) { float alpha; OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha)); OP_REQUIRES( From eaa9fed5a1a18b92f36635607db8aaee66152cec Mon Sep 17 00:00:00 2001 From: Mahmoud Abuzaina Date: Sun, 9 Feb 2020 17:14:44 -0800 Subject: [PATCH 2/2] Addressing review comments --- tensorflow/core/kernels/mkl_relu_op.cc | 114 ++----------------------- tensorflow/core/util/mkl_types.h | 4 + 2 files changed, 13 insertions(+), 105 deletions(-) diff --git a/tensorflow/core/kernels/mkl_relu_op.cc b/tensorflow/core/kernels/mkl_relu_op.cc index a9ab6a84669..d4c90c0f2a0 100644 --- a/tensorflow/core/kernels/mkl_relu_op.cc +++ b/tensorflow/core/kernels/mkl_relu_op.cc @@ -164,11 +164,7 @@ class MklEltwiseFwdPrimitive : public MklPrimitive { context_.src_md.reset(new memory::desc(fwdParams.src_md.data)); context_.src_mpd.reset( -#ifdef ENABLE_MKLDNN_V1 - new MEMORY_PRIMITIVE_DESC(*context_.src_md)); -#else - new MEMORY_PRIMITIVE_DESC(*context_.src_md, cpu_engine_)); -#endif + new MEMORY_PD_CONSTRUCTOR_2_PARAMS(*context_.src_md, cpu_engine_)); // Create an eltwise forward descriptor and primitive descriptor context_.fwd_desc.reset(new eltwise_forward::desc( @@ -210,7 +206,6 @@ class MklEltwiseFwdPrimitiveFactory : public MklPrimitiveFactory { const MklEltwiseFwdParams& fwdParams) { MklEltwiseFwdPrimitive* eltwise_forward = nullptr; -#ifdef ENABLE_MKLDNN_V1 // Get a eltwise fwd primitive from the cached pool eltwise_forward = static_cast*>( MklEltwiseFwdPrimitiveFactory::GetInstance().GetEltwiseFwd( @@ -220,20 +215,6 @@ class MklEltwiseFwdPrimitiveFactory : public MklPrimitiveFactory { MklEltwiseFwdPrimitiveFactory::GetInstance().SetEltwiseFwd( fwdParams, eltwise_forward); } -#else - auto src_fmt = - static_cast(fwdParams.src_md.data.format); - - // Get a eltwise fwd primitive from the cached pool - eltwise_forward = static_cast*>( - MklEltwiseFwdPrimitiveFactory::GetInstance().GetEltwiseFwd(fwdParams, - src_fmt)); - if (eltwise_forward == nullptr) { - eltwise_forward = new MklEltwiseFwdPrimitive(fwdParams); - MklEltwiseFwdPrimitiveFactory::GetInstance().SetEltwiseFwd( - fwdParams, src_fmt, eltwise_forward); - } -#endif // ENABLE_MKLDNN_V1 return eltwise_forward; } @@ -247,7 +228,6 @@ class MklEltwiseFwdPrimitiveFactory : public MklPrimitiveFactory { MklEltwiseFwdPrimitiveFactory() {} ~MklEltwiseFwdPrimitiveFactory() {} -#ifdef ENABLE_MKLDNN_V1 static string CreateKey(const MklEltwiseFwdParams& fwdParams) { string prefix = "eltwise_fwd"; FactoryKeyCreator key_creator; @@ -256,6 +236,9 @@ class MklEltwiseFwdPrimitiveFactory : public MklPrimitiveFactory { key_creator.AddAsKey(static_cast(fwdParams.alg_kind)); key_creator.AddAsKey(static_cast(fwdParams.alpha)); key_creator.AddAsKey(static_cast(fwdParams.beta)); +#ifndef ENABLE_MKLDNN_V1 + key_creator.AddAsKey(static_cast(fwdParams.src_md.data.format)); +#endif // !ENABLE_MKLDNN_V1 return key_creator.GetKey(); } @@ -269,32 +252,6 @@ class MklEltwiseFwdPrimitiveFactory : public MklPrimitiveFactory { string key = CreateKey(fwdParams); this->SetOp(key, op); } -#else - static string CreateKey(const MklEltwiseFwdParams& fwdParams, - memory::format src_fmt) { - string prefix = "eltwise_fwd"; - FactoryKeyCreator key_creator; - key_creator.AddAsKey(prefix); - key_creator.AddAsKey(fwdParams.src_dims); - key_creator.AddAsKey(static_cast(fwdParams.alg_kind)); - key_creator.AddAsKey(static_cast(fwdParams.alpha)); - key_creator.AddAsKey(static_cast(fwdParams.beta)); - key_creator.AddAsKey(static_cast(src_fmt)); - return key_creator.GetKey(); - } - - MklPrimitive* GetEltwiseFwd(const MklEltwiseFwdParams& fwdParams, - memory::format src_fmt) { - string key = CreateKey(fwdParams, src_fmt); - return this->GetOp(key); - } - - void SetEltwiseFwd(const MklEltwiseFwdParams& fwdParams, - memory::format src_fmt, MklPrimitive* op) { - string key = CreateKey(fwdParams, src_fmt); - this->SetOp(key, op); - } -#endif // ENABLE_MKLDNN_V1 }; template @@ -441,16 +398,10 @@ class MklEltwiseBwdPrimitive : public MklPrimitive { context_.src_md.reset(new memory::desc(bwdParams.common_md.data)); context_.diff_dst_md.reset(new memory::desc(bwdParams.common_md.data)); -#ifdef ENABLE_MKLDNN_V1 - context_.src_mpd.reset(new MEMORY_PRIMITIVE_DESC(*context_.src_md)); - context_.diff_dst_mpd.reset( - new MEMORY_PRIMITIVE_DESC(*context_.diff_dst_md)); -#else context_.src_mpd.reset( - new MEMORY_PRIMITIVE_DESC(*context_.src_md, cpu_engine_)); + new MEMORY_PD_CONSTRUCTOR_2_PARAMS(*context_.src_md, cpu_engine_)); context_.diff_dst_mpd.reset( - new MEMORY_PRIMITIVE_DESC(*context_.diff_dst_md, cpu_engine_)); -#endif + new MEMORY_PD_CONSTRUCTOR_2_PARAMS(*context_.diff_dst_md, cpu_engine_)); // Create forward eltwise primitive. context_.fwd_desc.reset(new mkldnn::eltwise_forward::desc( @@ -508,7 +459,6 @@ class MklEltwiseBwdPrimitiveFactory : public MklPrimitiveFactory { const MklEltwiseBwdParams& bwdParams) { MklEltwiseBwdPrimitive* eltwise_backward = nullptr; -#ifdef ENABLE_MKLDNN_V1 // try to find a suitable one in pool eltwise_backward = static_cast*>( MklEltwiseBwdPrimitiveFactory::GetInstance().GetEltwiseBwd( @@ -519,24 +469,6 @@ class MklEltwiseBwdPrimitiveFactory : public MklPrimitiveFactory { MklEltwiseBwdPrimitiveFactory::GetInstance().SetEltwiseBwd( bwdParams, eltwise_backward); } -#else - auto src_fmt = - static_cast(bwdParams.common_md.data.format); - auto diff_dst_fmt = - static_cast(bwdParams.common_md.data.format); - - // try to find a suitable one in pool - eltwise_backward = static_cast*>( - MklEltwiseBwdPrimitiveFactory::GetInstance().GetEltwiseBwd( - bwdParams, src_fmt, diff_dst_fmt)); - - if (eltwise_backward == nullptr) { - eltwise_backward = new MklEltwiseBwdPrimitive(bwdParams); - MklEltwiseBwdPrimitiveFactory::GetInstance().SetEltwiseBwd( - bwdParams, src_fmt, diff_dst_fmt, eltwise_backward); - } -#endif // ENABLE_MKLDNN_V1 - return eltwise_backward; } @@ -546,7 +478,6 @@ class MklEltwiseBwdPrimitiveFactory : public MklPrimitiveFactory { } private: -#ifdef ENABLE_MKLDNN_V1 static string CreateKey(const MklEltwiseBwdParams& bwdParams) { string prefix = "eltwise_bwd"; FactoryKeyCreator key_creator; @@ -555,6 +486,9 @@ class MklEltwiseBwdPrimitiveFactory : public MklPrimitiveFactory { key_creator.AddAsKey(static_cast(bwdParams.alg_kind)); key_creator.AddAsKey(static_cast(bwdParams.alpha)); key_creator.AddAsKey(static_cast(bwdParams.beta)); +#ifndef ENABLE_MKLDNN_V1 + key_creator.AddAsKey(static_cast(bwdParams.common_md.data.format)); +#endif // !ENABLE_MKLDNN_V1 return key_creator.GetKey(); } @@ -568,36 +502,6 @@ class MklEltwiseBwdPrimitiveFactory : public MklPrimitiveFactory { string key = CreateKey(bwdParams); this->SetOp(key, op); } -#else - static string CreateKey(const MklEltwiseBwdParams& bwdParams, - const memory::format& src_fmt, - const memory::format& diff_dst_fmt) { - string prefix = "eltwise_bwd"; - FactoryKeyCreator key_creator; - key_creator.AddAsKey(prefix); - key_creator.AddAsKey(bwdParams.src_dims); - key_creator.AddAsKey(static_cast(bwdParams.alg_kind)); - key_creator.AddAsKey(static_cast(bwdParams.alpha)); - key_creator.AddAsKey(static_cast(bwdParams.beta)); - key_creator.AddAsKey(static_cast(src_fmt)); - key_creator.AddAsKey(static_cast(diff_dst_fmt)); - return key_creator.GetKey(); - } - - MklPrimitive* GetEltwiseBwd(const MklEltwiseBwdParams& bwdParams, - const memory::format& src_fmt, - const memory::format& diff_dst_fmt) { - string key = CreateKey(bwdParams, src_fmt, diff_dst_fmt); - return this->GetOp(key); - } - - void SetEltwiseBwd(const MklEltwiseBwdParams& bwdParams, - const memory::format& src_fmt, - const memory::format& diff_dst_fmt, MklPrimitive* op) { - string key = CreateKey(bwdParams, src_fmt, diff_dst_fmt); - this->SetOp(key, op); - } -#endif // ENABLE_MKLDNN_V1 }; typedef Eigen::ThreadPoolDevice CPUDevice; diff --git a/tensorflow/core/util/mkl_types.h b/tensorflow/core/util/mkl_types.h index c88ff37f53c..5738d27ff5b 100644 --- a/tensorflow/core/util/mkl_types.h +++ b/tensorflow/core/util/mkl_types.h @@ -77,6 +77,8 @@ namespace tensorflow { memory::desc({dims}, MklDnnType(), fm) #define MEMORY_PD_WITHOUT_DATA(md, engine) md, engine #define MEMORY_PRIMITIVE_DESC memory::desc +#define MEMORY_PD_CONSTRUCTOR_2_PARAMS(md, engine) \ + MEMORY_PRIMITIVE_DESC(md) #define MKL_FMT_TAG mkl_fmt_tag #define MKL_TENSOR_FORMAT MklTensorFormat #define MKL_TENSOR_FORMAT_BLOCKED MklTensorFormat::FORMAT_BLOCKED @@ -169,6 +171,8 @@ namespace tensorflow { memory::primitive_desc(GET_MEMORY_DESC_CONSTRUCTOR(dims, type, fm), engine) #define MEMORY_PD_WITHOUT_DATA(pd, engine) pd #define MEMORY_PRIMITIVE_DESC memory::primitive_desc +#define MEMORY_PD_CONSTRUCTOR_2_PARAMS(md, engine) \ + MEMORY_PRIMITIVE_DESC(md, engine) #define MKL_FMT_TAG tf_fmt #define MKL_TENSOR_FORMAT memory::format #define MKL_TENSOR_FORMAT_BLOCKED memory::format::blocked