From eaa9fed5a1a18b92f36635607db8aaee66152cec Mon Sep 17 00:00:00 2001 From: Mahmoud Abuzaina Date: Sun, 9 Feb 2020 17:14:44 -0800 Subject: [PATCH] Addressing review comments --- tensorflow/core/kernels/mkl_relu_op.cc | 114 ++----------------------- tensorflow/core/util/mkl_types.h | 4 + 2 files changed, 13 insertions(+), 105 deletions(-) diff --git a/tensorflow/core/kernels/mkl_relu_op.cc b/tensorflow/core/kernels/mkl_relu_op.cc index a9ab6a84669..d4c90c0f2a0 100644 --- a/tensorflow/core/kernels/mkl_relu_op.cc +++ b/tensorflow/core/kernels/mkl_relu_op.cc @@ -164,11 +164,7 @@ class MklEltwiseFwdPrimitive : public MklPrimitive { context_.src_md.reset(new memory::desc(fwdParams.src_md.data)); context_.src_mpd.reset( -#ifdef ENABLE_MKLDNN_V1 - new MEMORY_PRIMITIVE_DESC(*context_.src_md)); -#else - new MEMORY_PRIMITIVE_DESC(*context_.src_md, cpu_engine_)); -#endif + new MEMORY_PD_CONSTRUCTOR_2_PARAMS(*context_.src_md, cpu_engine_)); // Create an eltwise forward descriptor and primitive descriptor context_.fwd_desc.reset(new eltwise_forward::desc( @@ -210,7 +206,6 @@ class MklEltwiseFwdPrimitiveFactory : public MklPrimitiveFactory { const MklEltwiseFwdParams& fwdParams) { MklEltwiseFwdPrimitive* eltwise_forward = nullptr; -#ifdef ENABLE_MKLDNN_V1 // Get a eltwise fwd primitive from the cached pool eltwise_forward = static_cast*>( MklEltwiseFwdPrimitiveFactory::GetInstance().GetEltwiseFwd( @@ -220,20 +215,6 @@ class MklEltwiseFwdPrimitiveFactory : public MklPrimitiveFactory { MklEltwiseFwdPrimitiveFactory::GetInstance().SetEltwiseFwd( fwdParams, eltwise_forward); } -#else - auto src_fmt = - static_cast(fwdParams.src_md.data.format); - - // Get a eltwise fwd primitive from the cached pool - eltwise_forward = static_cast*>( - MklEltwiseFwdPrimitiveFactory::GetInstance().GetEltwiseFwd(fwdParams, - src_fmt)); - if (eltwise_forward == nullptr) { - eltwise_forward = new MklEltwiseFwdPrimitive(fwdParams); - MklEltwiseFwdPrimitiveFactory::GetInstance().SetEltwiseFwd( - fwdParams, src_fmt, eltwise_forward); - } -#endif // ENABLE_MKLDNN_V1 return eltwise_forward; } @@ -247,7 +228,6 @@ class MklEltwiseFwdPrimitiveFactory : public MklPrimitiveFactory { MklEltwiseFwdPrimitiveFactory() {} ~MklEltwiseFwdPrimitiveFactory() {} -#ifdef ENABLE_MKLDNN_V1 static string CreateKey(const MklEltwiseFwdParams& fwdParams) { string prefix = "eltwise_fwd"; FactoryKeyCreator key_creator; @@ -256,6 +236,9 @@ class MklEltwiseFwdPrimitiveFactory : public MklPrimitiveFactory { key_creator.AddAsKey(static_cast(fwdParams.alg_kind)); key_creator.AddAsKey(static_cast(fwdParams.alpha)); key_creator.AddAsKey(static_cast(fwdParams.beta)); +#ifndef ENABLE_MKLDNN_V1 + key_creator.AddAsKey(static_cast(fwdParams.src_md.data.format)); +#endif // !ENABLE_MKLDNN_V1 return key_creator.GetKey(); } @@ -269,32 +252,6 @@ class MklEltwiseFwdPrimitiveFactory : public MklPrimitiveFactory { string key = CreateKey(fwdParams); this->SetOp(key, op); } -#else - static string CreateKey(const MklEltwiseFwdParams& fwdParams, - memory::format src_fmt) { - string prefix = "eltwise_fwd"; - FactoryKeyCreator key_creator; - key_creator.AddAsKey(prefix); - key_creator.AddAsKey(fwdParams.src_dims); - key_creator.AddAsKey(static_cast(fwdParams.alg_kind)); - key_creator.AddAsKey(static_cast(fwdParams.alpha)); - key_creator.AddAsKey(static_cast(fwdParams.beta)); - key_creator.AddAsKey(static_cast(src_fmt)); - return key_creator.GetKey(); - } - - MklPrimitive* GetEltwiseFwd(const MklEltwiseFwdParams& fwdParams, - memory::format src_fmt) { - string key = CreateKey(fwdParams, src_fmt); - return this->GetOp(key); - } - - void SetEltwiseFwd(const MklEltwiseFwdParams& fwdParams, - memory::format src_fmt, MklPrimitive* op) { - string key = CreateKey(fwdParams, src_fmt); - this->SetOp(key, op); - } -#endif // ENABLE_MKLDNN_V1 }; template @@ -441,16 +398,10 @@ class MklEltwiseBwdPrimitive : public MklPrimitive { context_.src_md.reset(new memory::desc(bwdParams.common_md.data)); context_.diff_dst_md.reset(new memory::desc(bwdParams.common_md.data)); -#ifdef ENABLE_MKLDNN_V1 - context_.src_mpd.reset(new MEMORY_PRIMITIVE_DESC(*context_.src_md)); - context_.diff_dst_mpd.reset( - new MEMORY_PRIMITIVE_DESC(*context_.diff_dst_md)); -#else context_.src_mpd.reset( - new MEMORY_PRIMITIVE_DESC(*context_.src_md, cpu_engine_)); + new MEMORY_PD_CONSTRUCTOR_2_PARAMS(*context_.src_md, cpu_engine_)); context_.diff_dst_mpd.reset( - new MEMORY_PRIMITIVE_DESC(*context_.diff_dst_md, cpu_engine_)); -#endif + new MEMORY_PD_CONSTRUCTOR_2_PARAMS(*context_.diff_dst_md, cpu_engine_)); // Create forward eltwise primitive. context_.fwd_desc.reset(new mkldnn::eltwise_forward::desc( @@ -508,7 +459,6 @@ class MklEltwiseBwdPrimitiveFactory : public MklPrimitiveFactory { const MklEltwiseBwdParams& bwdParams) { MklEltwiseBwdPrimitive* eltwise_backward = nullptr; -#ifdef ENABLE_MKLDNN_V1 // try to find a suitable one in pool eltwise_backward = static_cast*>( MklEltwiseBwdPrimitiveFactory::GetInstance().GetEltwiseBwd( @@ -519,24 +469,6 @@ class MklEltwiseBwdPrimitiveFactory : public MklPrimitiveFactory { MklEltwiseBwdPrimitiveFactory::GetInstance().SetEltwiseBwd( bwdParams, eltwise_backward); } -#else - auto src_fmt = - static_cast(bwdParams.common_md.data.format); - auto diff_dst_fmt = - static_cast(bwdParams.common_md.data.format); - - // try to find a suitable one in pool - eltwise_backward = static_cast*>( - MklEltwiseBwdPrimitiveFactory::GetInstance().GetEltwiseBwd( - bwdParams, src_fmt, diff_dst_fmt)); - - if (eltwise_backward == nullptr) { - eltwise_backward = new MklEltwiseBwdPrimitive(bwdParams); - MklEltwiseBwdPrimitiveFactory::GetInstance().SetEltwiseBwd( - bwdParams, src_fmt, diff_dst_fmt, eltwise_backward); - } -#endif // ENABLE_MKLDNN_V1 - return eltwise_backward; } @@ -546,7 +478,6 @@ class MklEltwiseBwdPrimitiveFactory : public MklPrimitiveFactory { } private: -#ifdef ENABLE_MKLDNN_V1 static string CreateKey(const MklEltwiseBwdParams& bwdParams) { string prefix = "eltwise_bwd"; FactoryKeyCreator key_creator; @@ -555,6 +486,9 @@ class MklEltwiseBwdPrimitiveFactory : public MklPrimitiveFactory { key_creator.AddAsKey(static_cast(bwdParams.alg_kind)); key_creator.AddAsKey(static_cast(bwdParams.alpha)); key_creator.AddAsKey(static_cast(bwdParams.beta)); +#ifndef ENABLE_MKLDNN_V1 + key_creator.AddAsKey(static_cast(bwdParams.common_md.data.format)); +#endif // !ENABLE_MKLDNN_V1 return key_creator.GetKey(); } @@ -568,36 +502,6 @@ class MklEltwiseBwdPrimitiveFactory : public MklPrimitiveFactory { string key = CreateKey(bwdParams); this->SetOp(key, op); } -#else - static string CreateKey(const MklEltwiseBwdParams& bwdParams, - const memory::format& src_fmt, - const memory::format& diff_dst_fmt) { - string prefix = "eltwise_bwd"; - FactoryKeyCreator key_creator; - key_creator.AddAsKey(prefix); - key_creator.AddAsKey(bwdParams.src_dims); - key_creator.AddAsKey(static_cast(bwdParams.alg_kind)); - key_creator.AddAsKey(static_cast(bwdParams.alpha)); - key_creator.AddAsKey(static_cast(bwdParams.beta)); - key_creator.AddAsKey(static_cast(src_fmt)); - key_creator.AddAsKey(static_cast(diff_dst_fmt)); - return key_creator.GetKey(); - } - - MklPrimitive* GetEltwiseBwd(const MklEltwiseBwdParams& bwdParams, - const memory::format& src_fmt, - const memory::format& diff_dst_fmt) { - string key = CreateKey(bwdParams, src_fmt, diff_dst_fmt); - return this->GetOp(key); - } - - void SetEltwiseBwd(const MklEltwiseBwdParams& bwdParams, - const memory::format& src_fmt, - const memory::format& diff_dst_fmt, MklPrimitive* op) { - string key = CreateKey(bwdParams, src_fmt, diff_dst_fmt); - this->SetOp(key, op); - } -#endif // ENABLE_MKLDNN_V1 }; typedef Eigen::ThreadPoolDevice CPUDevice; diff --git a/tensorflow/core/util/mkl_types.h b/tensorflow/core/util/mkl_types.h index c88ff37f53c..5738d27ff5b 100644 --- a/tensorflow/core/util/mkl_types.h +++ b/tensorflow/core/util/mkl_types.h @@ -77,6 +77,8 @@ namespace tensorflow { memory::desc({dims}, MklDnnType(), fm) #define MEMORY_PD_WITHOUT_DATA(md, engine) md, engine #define MEMORY_PRIMITIVE_DESC memory::desc +#define MEMORY_PD_CONSTRUCTOR_2_PARAMS(md, engine) \ + MEMORY_PRIMITIVE_DESC(md) #define MKL_FMT_TAG mkl_fmt_tag #define MKL_TENSOR_FORMAT MklTensorFormat #define MKL_TENSOR_FORMAT_BLOCKED MklTensorFormat::FORMAT_BLOCKED @@ -169,6 +171,8 @@ namespace tensorflow { memory::primitive_desc(GET_MEMORY_DESC_CONSTRUCTOR(dims, type, fm), engine) #define MEMORY_PD_WITHOUT_DATA(pd, engine) pd #define MEMORY_PRIMITIVE_DESC memory::primitive_desc +#define MEMORY_PD_CONSTRUCTOR_2_PARAMS(md, engine) \ + MEMORY_PRIMITIVE_DESC(md, engine) #define MKL_FMT_TAG tf_fmt #define MKL_TENSOR_FORMAT memory::format #define MKL_TENSOR_FORMAT_BLOCKED memory::format::blocked