Addressing review comments

This commit is contained in:
Mahmoud Abuzaina 2020-02-09 17:14:44 -08:00
parent b3b18d5e50
commit eaa9fed5a1
2 changed files with 13 additions and 105 deletions

View File

@ -164,11 +164,7 @@ class MklEltwiseFwdPrimitive : public MklPrimitive {
context_.src_md.reset(new memory::desc(fwdParams.src_md.data));
context_.src_mpd.reset(
#ifdef ENABLE_MKLDNN_V1
new MEMORY_PRIMITIVE_DESC(*context_.src_md));
#else
new MEMORY_PRIMITIVE_DESC(*context_.src_md, cpu_engine_));
#endif
new MEMORY_PD_CONSTRUCTOR_2_PARAMS(*context_.src_md, cpu_engine_));
// Create an eltwise forward descriptor and primitive descriptor
context_.fwd_desc.reset(new eltwise_forward::desc(
@ -210,7 +206,6 @@ class MklEltwiseFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
const MklEltwiseFwdParams<T>& fwdParams) {
MklEltwiseFwdPrimitive<T>* eltwise_forward = nullptr;
#ifdef ENABLE_MKLDNN_V1
// Get a eltwise fwd primitive from the cached pool
eltwise_forward = static_cast<MklEltwiseFwdPrimitive<T>*>(
MklEltwiseFwdPrimitiveFactory<T>::GetInstance().GetEltwiseFwd(
@ -220,20 +215,6 @@ class MklEltwiseFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
MklEltwiseFwdPrimitiveFactory<T>::GetInstance().SetEltwiseFwd(
fwdParams, eltwise_forward);
}
#else
auto src_fmt =
static_cast<mkldnn::memory::format>(fwdParams.src_md.data.format);
// Get a eltwise fwd primitive from the cached pool
eltwise_forward = static_cast<MklEltwiseFwdPrimitive<T>*>(
MklEltwiseFwdPrimitiveFactory<T>::GetInstance().GetEltwiseFwd(fwdParams,
src_fmt));
if (eltwise_forward == nullptr) {
eltwise_forward = new MklEltwiseFwdPrimitive<T>(fwdParams);
MklEltwiseFwdPrimitiveFactory<T>::GetInstance().SetEltwiseFwd(
fwdParams, src_fmt, eltwise_forward);
}
#endif // ENABLE_MKLDNN_V1
return eltwise_forward;
}
@ -247,7 +228,6 @@ class MklEltwiseFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
MklEltwiseFwdPrimitiveFactory() {}
~MklEltwiseFwdPrimitiveFactory() {}
#ifdef ENABLE_MKLDNN_V1
static string CreateKey(const MklEltwiseFwdParams<T>& fwdParams) {
string prefix = "eltwise_fwd";
FactoryKeyCreator key_creator;
@ -256,6 +236,9 @@ class MklEltwiseFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
key_creator.AddAsKey<int>(static_cast<int>(fwdParams.alg_kind));
key_creator.AddAsKey<float>(static_cast<float>(fwdParams.alpha));
key_creator.AddAsKey<float>(static_cast<float>(fwdParams.beta));
#ifndef ENABLE_MKLDNN_V1
key_creator.AddAsKey<int>(static_cast<int>(fwdParams.src_md.data.format));
#endif // !ENABLE_MKLDNN_V1
return key_creator.GetKey();
}
@ -269,32 +252,6 @@ class MklEltwiseFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
string key = CreateKey(fwdParams);
this->SetOp(key, op);
}
#else
static string CreateKey(const MklEltwiseFwdParams<T>& fwdParams,
memory::format src_fmt) {
string prefix = "eltwise_fwd";
FactoryKeyCreator key_creator;
key_creator.AddAsKey(prefix);
key_creator.AddAsKey(fwdParams.src_dims);
key_creator.AddAsKey<int>(static_cast<int>(fwdParams.alg_kind));
key_creator.AddAsKey<float>(static_cast<float>(fwdParams.alpha));
key_creator.AddAsKey<float>(static_cast<float>(fwdParams.beta));
key_creator.AddAsKey<int>(static_cast<int>(src_fmt));
return key_creator.GetKey();
}
MklPrimitive* GetEltwiseFwd(const MklEltwiseFwdParams<T>& fwdParams,
memory::format src_fmt) {
string key = CreateKey(fwdParams, src_fmt);
return this->GetOp(key);
}
void SetEltwiseFwd(const MklEltwiseFwdParams<T>& fwdParams,
memory::format src_fmt, MklPrimitive* op) {
string key = CreateKey(fwdParams, src_fmt);
this->SetOp(key, op);
}
#endif // ENABLE_MKLDNN_V1
};
template <typename T>
@ -441,16 +398,10 @@ class MklEltwiseBwdPrimitive : public MklPrimitive {
context_.src_md.reset(new memory::desc(bwdParams.common_md.data));
context_.diff_dst_md.reset(new memory::desc(bwdParams.common_md.data));
#ifdef ENABLE_MKLDNN_V1
context_.src_mpd.reset(new MEMORY_PRIMITIVE_DESC(*context_.src_md));
context_.diff_dst_mpd.reset(
new MEMORY_PRIMITIVE_DESC(*context_.diff_dst_md));
#else
context_.src_mpd.reset(
new MEMORY_PRIMITIVE_DESC(*context_.src_md, cpu_engine_));
new MEMORY_PD_CONSTRUCTOR_2_PARAMS(*context_.src_md, cpu_engine_));
context_.diff_dst_mpd.reset(
new MEMORY_PRIMITIVE_DESC(*context_.diff_dst_md, cpu_engine_));
#endif
new MEMORY_PD_CONSTRUCTOR_2_PARAMS(*context_.diff_dst_md, cpu_engine_));
// Create forward eltwise primitive.
context_.fwd_desc.reset(new mkldnn::eltwise_forward::desc(
@ -508,7 +459,6 @@ class MklEltwiseBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
const MklEltwiseBwdParams<T>& bwdParams) {
MklEltwiseBwdPrimitive<T>* eltwise_backward = nullptr;
#ifdef ENABLE_MKLDNN_V1
// try to find a suitable one in pool
eltwise_backward = static_cast<MklEltwiseBwdPrimitive<T>*>(
MklEltwiseBwdPrimitiveFactory<T>::GetInstance().GetEltwiseBwd(
@ -519,24 +469,6 @@ class MklEltwiseBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
MklEltwiseBwdPrimitiveFactory<T>::GetInstance().SetEltwiseBwd(
bwdParams, eltwise_backward);
}
#else
auto src_fmt =
static_cast<mkldnn::memory::format>(bwdParams.common_md.data.format);
auto diff_dst_fmt =
static_cast<mkldnn::memory::format>(bwdParams.common_md.data.format);
// try to find a suitable one in pool
eltwise_backward = static_cast<MklEltwiseBwdPrimitive<T>*>(
MklEltwiseBwdPrimitiveFactory<T>::GetInstance().GetEltwiseBwd(
bwdParams, src_fmt, diff_dst_fmt));
if (eltwise_backward == nullptr) {
eltwise_backward = new MklEltwiseBwdPrimitive<T>(bwdParams);
MklEltwiseBwdPrimitiveFactory<T>::GetInstance().SetEltwiseBwd(
bwdParams, src_fmt, diff_dst_fmt, eltwise_backward);
}
#endif // ENABLE_MKLDNN_V1
return eltwise_backward;
}
@ -546,7 +478,6 @@ class MklEltwiseBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
}
private:
#ifdef ENABLE_MKLDNN_V1
static string CreateKey(const MklEltwiseBwdParams<T>& bwdParams) {
string prefix = "eltwise_bwd";
FactoryKeyCreator key_creator;
@ -555,6 +486,9 @@ class MklEltwiseBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
key_creator.AddAsKey(static_cast<int>(bwdParams.alg_kind));
key_creator.AddAsKey(static_cast<float>(bwdParams.alpha));
key_creator.AddAsKey(static_cast<float>(bwdParams.beta));
#ifndef ENABLE_MKLDNN_V1
key_creator.AddAsKey(static_cast<int>(bwdParams.common_md.data.format));
#endif // !ENABLE_MKLDNN_V1
return key_creator.GetKey();
}
@ -568,36 +502,6 @@ class MklEltwiseBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
string key = CreateKey(bwdParams);
this->SetOp(key, op);
}
#else
static string CreateKey(const MklEltwiseBwdParams<T>& bwdParams,
const memory::format& src_fmt,
const memory::format& diff_dst_fmt) {
string prefix = "eltwise_bwd";
FactoryKeyCreator key_creator;
key_creator.AddAsKey(prefix);
key_creator.AddAsKey(bwdParams.src_dims);
key_creator.AddAsKey(static_cast<int>(bwdParams.alg_kind));
key_creator.AddAsKey(static_cast<float>(bwdParams.alpha));
key_creator.AddAsKey(static_cast<float>(bwdParams.beta));
key_creator.AddAsKey(static_cast<int>(src_fmt));
key_creator.AddAsKey(static_cast<int>(diff_dst_fmt));
return key_creator.GetKey();
}
MklPrimitive* GetEltwiseBwd(const MklEltwiseBwdParams<T>& bwdParams,
const memory::format& src_fmt,
const memory::format& diff_dst_fmt) {
string key = CreateKey(bwdParams, src_fmt, diff_dst_fmt);
return this->GetOp(key);
}
void SetEltwiseBwd(const MklEltwiseBwdParams<T>& bwdParams,
const memory::format& src_fmt,
const memory::format& diff_dst_fmt, MklPrimitive* op) {
string key = CreateKey(bwdParams, src_fmt, diff_dst_fmt);
this->SetOp(key, op);
}
#endif // ENABLE_MKLDNN_V1
};
typedef Eigen::ThreadPoolDevice CPUDevice;

View File

@ -77,6 +77,8 @@ namespace tensorflow {
memory::desc({dims}, MklDnnType<type>(), fm)
#define MEMORY_PD_WITHOUT_DATA(md, engine) md, engine
#define MEMORY_PRIMITIVE_DESC memory::desc
#define MEMORY_PD_CONSTRUCTOR_2_PARAMS(md, engine) \
MEMORY_PRIMITIVE_DESC(md)
#define MKL_FMT_TAG mkl_fmt_tag
#define MKL_TENSOR_FORMAT MklTensorFormat
#define MKL_TENSOR_FORMAT_BLOCKED MklTensorFormat::FORMAT_BLOCKED
@ -169,6 +171,8 @@ namespace tensorflow {
memory::primitive_desc(GET_MEMORY_DESC_CONSTRUCTOR(dims, type, fm), engine)
#define MEMORY_PD_WITHOUT_DATA(pd, engine) pd
#define MEMORY_PRIMITIVE_DESC memory::primitive_desc
#define MEMORY_PD_CONSTRUCTOR_2_PARAMS(md, engine) \
MEMORY_PRIMITIVE_DESC(md, engine)
#define MKL_FMT_TAG tf_fmt
#define MKL_TENSOR_FORMAT memory::format
#define MKL_TENSOR_FORMAT_BLOCKED memory::format::blocked