Addressing review comments
This commit is contained in:
parent
b3b18d5e50
commit
eaa9fed5a1
@ -164,11 +164,7 @@ class MklEltwiseFwdPrimitive : public MklPrimitive {
|
||||
context_.src_md.reset(new memory::desc(fwdParams.src_md.data));
|
||||
|
||||
context_.src_mpd.reset(
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
new MEMORY_PRIMITIVE_DESC(*context_.src_md));
|
||||
#else
|
||||
new MEMORY_PRIMITIVE_DESC(*context_.src_md, cpu_engine_));
|
||||
#endif
|
||||
new MEMORY_PD_CONSTRUCTOR_2_PARAMS(*context_.src_md, cpu_engine_));
|
||||
|
||||
// Create an eltwise forward descriptor and primitive descriptor
|
||||
context_.fwd_desc.reset(new eltwise_forward::desc(
|
||||
@ -210,7 +206,6 @@ class MklEltwiseFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
|
||||
const MklEltwiseFwdParams<T>& fwdParams) {
|
||||
MklEltwiseFwdPrimitive<T>* eltwise_forward = nullptr;
|
||||
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
// Get a eltwise fwd primitive from the cached pool
|
||||
eltwise_forward = static_cast<MklEltwiseFwdPrimitive<T>*>(
|
||||
MklEltwiseFwdPrimitiveFactory<T>::GetInstance().GetEltwiseFwd(
|
||||
@ -220,20 +215,6 @@ class MklEltwiseFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
|
||||
MklEltwiseFwdPrimitiveFactory<T>::GetInstance().SetEltwiseFwd(
|
||||
fwdParams, eltwise_forward);
|
||||
}
|
||||
#else
|
||||
auto src_fmt =
|
||||
static_cast<mkldnn::memory::format>(fwdParams.src_md.data.format);
|
||||
|
||||
// Get a eltwise fwd primitive from the cached pool
|
||||
eltwise_forward = static_cast<MklEltwiseFwdPrimitive<T>*>(
|
||||
MklEltwiseFwdPrimitiveFactory<T>::GetInstance().GetEltwiseFwd(fwdParams,
|
||||
src_fmt));
|
||||
if (eltwise_forward == nullptr) {
|
||||
eltwise_forward = new MklEltwiseFwdPrimitive<T>(fwdParams);
|
||||
MklEltwiseFwdPrimitiveFactory<T>::GetInstance().SetEltwiseFwd(
|
||||
fwdParams, src_fmt, eltwise_forward);
|
||||
}
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
|
||||
return eltwise_forward;
|
||||
}
|
||||
@ -247,7 +228,6 @@ class MklEltwiseFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
|
||||
MklEltwiseFwdPrimitiveFactory() {}
|
||||
~MklEltwiseFwdPrimitiveFactory() {}
|
||||
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
static string CreateKey(const MklEltwiseFwdParams<T>& fwdParams) {
|
||||
string prefix = "eltwise_fwd";
|
||||
FactoryKeyCreator key_creator;
|
||||
@ -256,6 +236,9 @@ class MklEltwiseFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
|
||||
key_creator.AddAsKey<int>(static_cast<int>(fwdParams.alg_kind));
|
||||
key_creator.AddAsKey<float>(static_cast<float>(fwdParams.alpha));
|
||||
key_creator.AddAsKey<float>(static_cast<float>(fwdParams.beta));
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
key_creator.AddAsKey<int>(static_cast<int>(fwdParams.src_md.data.format));
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
return key_creator.GetKey();
|
||||
}
|
||||
|
||||
@ -269,32 +252,6 @@ class MklEltwiseFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
|
||||
string key = CreateKey(fwdParams);
|
||||
this->SetOp(key, op);
|
||||
}
|
||||
#else
|
||||
static string CreateKey(const MklEltwiseFwdParams<T>& fwdParams,
|
||||
memory::format src_fmt) {
|
||||
string prefix = "eltwise_fwd";
|
||||
FactoryKeyCreator key_creator;
|
||||
key_creator.AddAsKey(prefix);
|
||||
key_creator.AddAsKey(fwdParams.src_dims);
|
||||
key_creator.AddAsKey<int>(static_cast<int>(fwdParams.alg_kind));
|
||||
key_creator.AddAsKey<float>(static_cast<float>(fwdParams.alpha));
|
||||
key_creator.AddAsKey<float>(static_cast<float>(fwdParams.beta));
|
||||
key_creator.AddAsKey<int>(static_cast<int>(src_fmt));
|
||||
return key_creator.GetKey();
|
||||
}
|
||||
|
||||
MklPrimitive* GetEltwiseFwd(const MklEltwiseFwdParams<T>& fwdParams,
|
||||
memory::format src_fmt) {
|
||||
string key = CreateKey(fwdParams, src_fmt);
|
||||
return this->GetOp(key);
|
||||
}
|
||||
|
||||
void SetEltwiseFwd(const MklEltwiseFwdParams<T>& fwdParams,
|
||||
memory::format src_fmt, MklPrimitive* op) {
|
||||
string key = CreateKey(fwdParams, src_fmt);
|
||||
this->SetOp(key, op);
|
||||
}
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
@ -441,16 +398,10 @@ class MklEltwiseBwdPrimitive : public MklPrimitive {
|
||||
context_.src_md.reset(new memory::desc(bwdParams.common_md.data));
|
||||
context_.diff_dst_md.reset(new memory::desc(bwdParams.common_md.data));
|
||||
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
context_.src_mpd.reset(new MEMORY_PRIMITIVE_DESC(*context_.src_md));
|
||||
context_.diff_dst_mpd.reset(
|
||||
new MEMORY_PRIMITIVE_DESC(*context_.diff_dst_md));
|
||||
#else
|
||||
context_.src_mpd.reset(
|
||||
new MEMORY_PRIMITIVE_DESC(*context_.src_md, cpu_engine_));
|
||||
new MEMORY_PD_CONSTRUCTOR_2_PARAMS(*context_.src_md, cpu_engine_));
|
||||
context_.diff_dst_mpd.reset(
|
||||
new MEMORY_PRIMITIVE_DESC(*context_.diff_dst_md, cpu_engine_));
|
||||
#endif
|
||||
new MEMORY_PD_CONSTRUCTOR_2_PARAMS(*context_.diff_dst_md, cpu_engine_));
|
||||
|
||||
// Create forward eltwise primitive.
|
||||
context_.fwd_desc.reset(new mkldnn::eltwise_forward::desc(
|
||||
@ -508,7 +459,6 @@ class MklEltwiseBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
|
||||
const MklEltwiseBwdParams<T>& bwdParams) {
|
||||
MklEltwiseBwdPrimitive<T>* eltwise_backward = nullptr;
|
||||
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
// try to find a suitable one in pool
|
||||
eltwise_backward = static_cast<MklEltwiseBwdPrimitive<T>*>(
|
||||
MklEltwiseBwdPrimitiveFactory<T>::GetInstance().GetEltwiseBwd(
|
||||
@ -519,24 +469,6 @@ class MklEltwiseBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
|
||||
MklEltwiseBwdPrimitiveFactory<T>::GetInstance().SetEltwiseBwd(
|
||||
bwdParams, eltwise_backward);
|
||||
}
|
||||
#else
|
||||
auto src_fmt =
|
||||
static_cast<mkldnn::memory::format>(bwdParams.common_md.data.format);
|
||||
auto diff_dst_fmt =
|
||||
static_cast<mkldnn::memory::format>(bwdParams.common_md.data.format);
|
||||
|
||||
// try to find a suitable one in pool
|
||||
eltwise_backward = static_cast<MklEltwiseBwdPrimitive<T>*>(
|
||||
MklEltwiseBwdPrimitiveFactory<T>::GetInstance().GetEltwiseBwd(
|
||||
bwdParams, src_fmt, diff_dst_fmt));
|
||||
|
||||
if (eltwise_backward == nullptr) {
|
||||
eltwise_backward = new MklEltwiseBwdPrimitive<T>(bwdParams);
|
||||
MklEltwiseBwdPrimitiveFactory<T>::GetInstance().SetEltwiseBwd(
|
||||
bwdParams, src_fmt, diff_dst_fmt, eltwise_backward);
|
||||
}
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
|
||||
return eltwise_backward;
|
||||
}
|
||||
|
||||
@ -546,7 +478,6 @@ class MklEltwiseBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
|
||||
}
|
||||
|
||||
private:
|
||||
#ifdef ENABLE_MKLDNN_V1
|
||||
static string CreateKey(const MklEltwiseBwdParams<T>& bwdParams) {
|
||||
string prefix = "eltwise_bwd";
|
||||
FactoryKeyCreator key_creator;
|
||||
@ -555,6 +486,9 @@ class MklEltwiseBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
|
||||
key_creator.AddAsKey(static_cast<int>(bwdParams.alg_kind));
|
||||
key_creator.AddAsKey(static_cast<float>(bwdParams.alpha));
|
||||
key_creator.AddAsKey(static_cast<float>(bwdParams.beta));
|
||||
#ifndef ENABLE_MKLDNN_V1
|
||||
key_creator.AddAsKey(static_cast<int>(bwdParams.common_md.data.format));
|
||||
#endif // !ENABLE_MKLDNN_V1
|
||||
return key_creator.GetKey();
|
||||
}
|
||||
|
||||
@ -568,36 +502,6 @@ class MklEltwiseBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
|
||||
string key = CreateKey(bwdParams);
|
||||
this->SetOp(key, op);
|
||||
}
|
||||
#else
|
||||
static string CreateKey(const MklEltwiseBwdParams<T>& bwdParams,
|
||||
const memory::format& src_fmt,
|
||||
const memory::format& diff_dst_fmt) {
|
||||
string prefix = "eltwise_bwd";
|
||||
FactoryKeyCreator key_creator;
|
||||
key_creator.AddAsKey(prefix);
|
||||
key_creator.AddAsKey(bwdParams.src_dims);
|
||||
key_creator.AddAsKey(static_cast<int>(bwdParams.alg_kind));
|
||||
key_creator.AddAsKey(static_cast<float>(bwdParams.alpha));
|
||||
key_creator.AddAsKey(static_cast<float>(bwdParams.beta));
|
||||
key_creator.AddAsKey(static_cast<int>(src_fmt));
|
||||
key_creator.AddAsKey(static_cast<int>(diff_dst_fmt));
|
||||
return key_creator.GetKey();
|
||||
}
|
||||
|
||||
MklPrimitive* GetEltwiseBwd(const MklEltwiseBwdParams<T>& bwdParams,
|
||||
const memory::format& src_fmt,
|
||||
const memory::format& diff_dst_fmt) {
|
||||
string key = CreateKey(bwdParams, src_fmt, diff_dst_fmt);
|
||||
return this->GetOp(key);
|
||||
}
|
||||
|
||||
void SetEltwiseBwd(const MklEltwiseBwdParams<T>& bwdParams,
|
||||
const memory::format& src_fmt,
|
||||
const memory::format& diff_dst_fmt, MklPrimitive* op) {
|
||||
string key = CreateKey(bwdParams, src_fmt, diff_dst_fmt);
|
||||
this->SetOp(key, op);
|
||||
}
|
||||
#endif // ENABLE_MKLDNN_V1
|
||||
};
|
||||
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
|
@ -77,6 +77,8 @@ namespace tensorflow {
|
||||
memory::desc({dims}, MklDnnType<type>(), fm)
|
||||
#define MEMORY_PD_WITHOUT_DATA(md, engine) md, engine
|
||||
#define MEMORY_PRIMITIVE_DESC memory::desc
|
||||
#define MEMORY_PD_CONSTRUCTOR_2_PARAMS(md, engine) \
|
||||
MEMORY_PRIMITIVE_DESC(md)
|
||||
#define MKL_FMT_TAG mkl_fmt_tag
|
||||
#define MKL_TENSOR_FORMAT MklTensorFormat
|
||||
#define MKL_TENSOR_FORMAT_BLOCKED MklTensorFormat::FORMAT_BLOCKED
|
||||
@ -169,6 +171,8 @@ namespace tensorflow {
|
||||
memory::primitive_desc(GET_MEMORY_DESC_CONSTRUCTOR(dims, type, fm), engine)
|
||||
#define MEMORY_PD_WITHOUT_DATA(pd, engine) pd
|
||||
#define MEMORY_PRIMITIVE_DESC memory::primitive_desc
|
||||
#define MEMORY_PD_CONSTRUCTOR_2_PARAMS(md, engine) \
|
||||
MEMORY_PRIMITIVE_DESC(md, engine)
|
||||
#define MKL_FMT_TAG tf_fmt
|
||||
#define MKL_TENSOR_FORMAT memory::format
|
||||
#define MKL_TENSOR_FORMAT_BLOCKED memory::format::blocked
|
||||
|
Loading…
Reference in New Issue
Block a user