Addressing review comments
This commit is contained in:
parent
b3b18d5e50
commit
eaa9fed5a1
@ -164,11 +164,7 @@ class MklEltwiseFwdPrimitive : public MklPrimitive {
|
|||||||
context_.src_md.reset(new memory::desc(fwdParams.src_md.data));
|
context_.src_md.reset(new memory::desc(fwdParams.src_md.data));
|
||||||
|
|
||||||
context_.src_mpd.reset(
|
context_.src_mpd.reset(
|
||||||
#ifdef ENABLE_MKLDNN_V1
|
new MEMORY_PD_CONSTRUCTOR_2_PARAMS(*context_.src_md, cpu_engine_));
|
||||||
new MEMORY_PRIMITIVE_DESC(*context_.src_md));
|
|
||||||
#else
|
|
||||||
new MEMORY_PRIMITIVE_DESC(*context_.src_md, cpu_engine_));
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// Create an eltwise forward descriptor and primitive descriptor
|
// Create an eltwise forward descriptor and primitive descriptor
|
||||||
context_.fwd_desc.reset(new eltwise_forward::desc(
|
context_.fwd_desc.reset(new eltwise_forward::desc(
|
||||||
@ -210,7 +206,6 @@ class MklEltwiseFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
|
|||||||
const MklEltwiseFwdParams<T>& fwdParams) {
|
const MklEltwiseFwdParams<T>& fwdParams) {
|
||||||
MklEltwiseFwdPrimitive<T>* eltwise_forward = nullptr;
|
MklEltwiseFwdPrimitive<T>* eltwise_forward = nullptr;
|
||||||
|
|
||||||
#ifdef ENABLE_MKLDNN_V1
|
|
||||||
// Get a eltwise fwd primitive from the cached pool
|
// Get a eltwise fwd primitive from the cached pool
|
||||||
eltwise_forward = static_cast<MklEltwiseFwdPrimitive<T>*>(
|
eltwise_forward = static_cast<MklEltwiseFwdPrimitive<T>*>(
|
||||||
MklEltwiseFwdPrimitiveFactory<T>::GetInstance().GetEltwiseFwd(
|
MklEltwiseFwdPrimitiveFactory<T>::GetInstance().GetEltwiseFwd(
|
||||||
@ -220,20 +215,6 @@ class MklEltwiseFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
|
|||||||
MklEltwiseFwdPrimitiveFactory<T>::GetInstance().SetEltwiseFwd(
|
MklEltwiseFwdPrimitiveFactory<T>::GetInstance().SetEltwiseFwd(
|
||||||
fwdParams, eltwise_forward);
|
fwdParams, eltwise_forward);
|
||||||
}
|
}
|
||||||
#else
|
|
||||||
auto src_fmt =
|
|
||||||
static_cast<mkldnn::memory::format>(fwdParams.src_md.data.format);
|
|
||||||
|
|
||||||
// Get a eltwise fwd primitive from the cached pool
|
|
||||||
eltwise_forward = static_cast<MklEltwiseFwdPrimitive<T>*>(
|
|
||||||
MklEltwiseFwdPrimitiveFactory<T>::GetInstance().GetEltwiseFwd(fwdParams,
|
|
||||||
src_fmt));
|
|
||||||
if (eltwise_forward == nullptr) {
|
|
||||||
eltwise_forward = new MklEltwiseFwdPrimitive<T>(fwdParams);
|
|
||||||
MklEltwiseFwdPrimitiveFactory<T>::GetInstance().SetEltwiseFwd(
|
|
||||||
fwdParams, src_fmt, eltwise_forward);
|
|
||||||
}
|
|
||||||
#endif // ENABLE_MKLDNN_V1
|
|
||||||
|
|
||||||
return eltwise_forward;
|
return eltwise_forward;
|
||||||
}
|
}
|
||||||
@ -247,7 +228,6 @@ class MklEltwiseFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
|
|||||||
MklEltwiseFwdPrimitiveFactory() {}
|
MklEltwiseFwdPrimitiveFactory() {}
|
||||||
~MklEltwiseFwdPrimitiveFactory() {}
|
~MklEltwiseFwdPrimitiveFactory() {}
|
||||||
|
|
||||||
#ifdef ENABLE_MKLDNN_V1
|
|
||||||
static string CreateKey(const MklEltwiseFwdParams<T>& fwdParams) {
|
static string CreateKey(const MklEltwiseFwdParams<T>& fwdParams) {
|
||||||
string prefix = "eltwise_fwd";
|
string prefix = "eltwise_fwd";
|
||||||
FactoryKeyCreator key_creator;
|
FactoryKeyCreator key_creator;
|
||||||
@ -256,6 +236,9 @@ class MklEltwiseFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
|
|||||||
key_creator.AddAsKey<int>(static_cast<int>(fwdParams.alg_kind));
|
key_creator.AddAsKey<int>(static_cast<int>(fwdParams.alg_kind));
|
||||||
key_creator.AddAsKey<float>(static_cast<float>(fwdParams.alpha));
|
key_creator.AddAsKey<float>(static_cast<float>(fwdParams.alpha));
|
||||||
key_creator.AddAsKey<float>(static_cast<float>(fwdParams.beta));
|
key_creator.AddAsKey<float>(static_cast<float>(fwdParams.beta));
|
||||||
|
#ifndef ENABLE_MKLDNN_V1
|
||||||
|
key_creator.AddAsKey<int>(static_cast<int>(fwdParams.src_md.data.format));
|
||||||
|
#endif // !ENABLE_MKLDNN_V1
|
||||||
return key_creator.GetKey();
|
return key_creator.GetKey();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -269,32 +252,6 @@ class MklEltwiseFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
|
|||||||
string key = CreateKey(fwdParams);
|
string key = CreateKey(fwdParams);
|
||||||
this->SetOp(key, op);
|
this->SetOp(key, op);
|
||||||
}
|
}
|
||||||
#else
|
|
||||||
static string CreateKey(const MklEltwiseFwdParams<T>& fwdParams,
|
|
||||||
memory::format src_fmt) {
|
|
||||||
string prefix = "eltwise_fwd";
|
|
||||||
FactoryKeyCreator key_creator;
|
|
||||||
key_creator.AddAsKey(prefix);
|
|
||||||
key_creator.AddAsKey(fwdParams.src_dims);
|
|
||||||
key_creator.AddAsKey<int>(static_cast<int>(fwdParams.alg_kind));
|
|
||||||
key_creator.AddAsKey<float>(static_cast<float>(fwdParams.alpha));
|
|
||||||
key_creator.AddAsKey<float>(static_cast<float>(fwdParams.beta));
|
|
||||||
key_creator.AddAsKey<int>(static_cast<int>(src_fmt));
|
|
||||||
return key_creator.GetKey();
|
|
||||||
}
|
|
||||||
|
|
||||||
MklPrimitive* GetEltwiseFwd(const MklEltwiseFwdParams<T>& fwdParams,
|
|
||||||
memory::format src_fmt) {
|
|
||||||
string key = CreateKey(fwdParams, src_fmt);
|
|
||||||
return this->GetOp(key);
|
|
||||||
}
|
|
||||||
|
|
||||||
void SetEltwiseFwd(const MklEltwiseFwdParams<T>& fwdParams,
|
|
||||||
memory::format src_fmt, MklPrimitive* op) {
|
|
||||||
string key = CreateKey(fwdParams, src_fmt);
|
|
||||||
this->SetOp(key, op);
|
|
||||||
}
|
|
||||||
#endif // ENABLE_MKLDNN_V1
|
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
@ -441,16 +398,10 @@ class MklEltwiseBwdPrimitive : public MklPrimitive {
|
|||||||
context_.src_md.reset(new memory::desc(bwdParams.common_md.data));
|
context_.src_md.reset(new memory::desc(bwdParams.common_md.data));
|
||||||
context_.diff_dst_md.reset(new memory::desc(bwdParams.common_md.data));
|
context_.diff_dst_md.reset(new memory::desc(bwdParams.common_md.data));
|
||||||
|
|
||||||
#ifdef ENABLE_MKLDNN_V1
|
|
||||||
context_.src_mpd.reset(new MEMORY_PRIMITIVE_DESC(*context_.src_md));
|
|
||||||
context_.diff_dst_mpd.reset(
|
|
||||||
new MEMORY_PRIMITIVE_DESC(*context_.diff_dst_md));
|
|
||||||
#else
|
|
||||||
context_.src_mpd.reset(
|
context_.src_mpd.reset(
|
||||||
new MEMORY_PRIMITIVE_DESC(*context_.src_md, cpu_engine_));
|
new MEMORY_PD_CONSTRUCTOR_2_PARAMS(*context_.src_md, cpu_engine_));
|
||||||
context_.diff_dst_mpd.reset(
|
context_.diff_dst_mpd.reset(
|
||||||
new MEMORY_PRIMITIVE_DESC(*context_.diff_dst_md, cpu_engine_));
|
new MEMORY_PD_CONSTRUCTOR_2_PARAMS(*context_.diff_dst_md, cpu_engine_));
|
||||||
#endif
|
|
||||||
|
|
||||||
// Create forward eltwise primitive.
|
// Create forward eltwise primitive.
|
||||||
context_.fwd_desc.reset(new mkldnn::eltwise_forward::desc(
|
context_.fwd_desc.reset(new mkldnn::eltwise_forward::desc(
|
||||||
@ -508,7 +459,6 @@ class MklEltwiseBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
|
|||||||
const MklEltwiseBwdParams<T>& bwdParams) {
|
const MklEltwiseBwdParams<T>& bwdParams) {
|
||||||
MklEltwiseBwdPrimitive<T>* eltwise_backward = nullptr;
|
MklEltwiseBwdPrimitive<T>* eltwise_backward = nullptr;
|
||||||
|
|
||||||
#ifdef ENABLE_MKLDNN_V1
|
|
||||||
// try to find a suitable one in pool
|
// try to find a suitable one in pool
|
||||||
eltwise_backward = static_cast<MklEltwiseBwdPrimitive<T>*>(
|
eltwise_backward = static_cast<MklEltwiseBwdPrimitive<T>*>(
|
||||||
MklEltwiseBwdPrimitiveFactory<T>::GetInstance().GetEltwiseBwd(
|
MklEltwiseBwdPrimitiveFactory<T>::GetInstance().GetEltwiseBwd(
|
||||||
@ -519,24 +469,6 @@ class MklEltwiseBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
|
|||||||
MklEltwiseBwdPrimitiveFactory<T>::GetInstance().SetEltwiseBwd(
|
MklEltwiseBwdPrimitiveFactory<T>::GetInstance().SetEltwiseBwd(
|
||||||
bwdParams, eltwise_backward);
|
bwdParams, eltwise_backward);
|
||||||
}
|
}
|
||||||
#else
|
|
||||||
auto src_fmt =
|
|
||||||
static_cast<mkldnn::memory::format>(bwdParams.common_md.data.format);
|
|
||||||
auto diff_dst_fmt =
|
|
||||||
static_cast<mkldnn::memory::format>(bwdParams.common_md.data.format);
|
|
||||||
|
|
||||||
// try to find a suitable one in pool
|
|
||||||
eltwise_backward = static_cast<MklEltwiseBwdPrimitive<T>*>(
|
|
||||||
MklEltwiseBwdPrimitiveFactory<T>::GetInstance().GetEltwiseBwd(
|
|
||||||
bwdParams, src_fmt, diff_dst_fmt));
|
|
||||||
|
|
||||||
if (eltwise_backward == nullptr) {
|
|
||||||
eltwise_backward = new MklEltwiseBwdPrimitive<T>(bwdParams);
|
|
||||||
MklEltwiseBwdPrimitiveFactory<T>::GetInstance().SetEltwiseBwd(
|
|
||||||
bwdParams, src_fmt, diff_dst_fmt, eltwise_backward);
|
|
||||||
}
|
|
||||||
#endif // ENABLE_MKLDNN_V1
|
|
||||||
|
|
||||||
return eltwise_backward;
|
return eltwise_backward;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -546,7 +478,6 @@ class MklEltwiseBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
#ifdef ENABLE_MKLDNN_V1
|
|
||||||
static string CreateKey(const MklEltwiseBwdParams<T>& bwdParams) {
|
static string CreateKey(const MklEltwiseBwdParams<T>& bwdParams) {
|
||||||
string prefix = "eltwise_bwd";
|
string prefix = "eltwise_bwd";
|
||||||
FactoryKeyCreator key_creator;
|
FactoryKeyCreator key_creator;
|
||||||
@ -555,6 +486,9 @@ class MklEltwiseBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
|
|||||||
key_creator.AddAsKey(static_cast<int>(bwdParams.alg_kind));
|
key_creator.AddAsKey(static_cast<int>(bwdParams.alg_kind));
|
||||||
key_creator.AddAsKey(static_cast<float>(bwdParams.alpha));
|
key_creator.AddAsKey(static_cast<float>(bwdParams.alpha));
|
||||||
key_creator.AddAsKey(static_cast<float>(bwdParams.beta));
|
key_creator.AddAsKey(static_cast<float>(bwdParams.beta));
|
||||||
|
#ifndef ENABLE_MKLDNN_V1
|
||||||
|
key_creator.AddAsKey(static_cast<int>(bwdParams.common_md.data.format));
|
||||||
|
#endif // !ENABLE_MKLDNN_V1
|
||||||
return key_creator.GetKey();
|
return key_creator.GetKey();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -568,36 +502,6 @@ class MklEltwiseBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
|
|||||||
string key = CreateKey(bwdParams);
|
string key = CreateKey(bwdParams);
|
||||||
this->SetOp(key, op);
|
this->SetOp(key, op);
|
||||||
}
|
}
|
||||||
#else
|
|
||||||
static string CreateKey(const MklEltwiseBwdParams<T>& bwdParams,
|
|
||||||
const memory::format& src_fmt,
|
|
||||||
const memory::format& diff_dst_fmt) {
|
|
||||||
string prefix = "eltwise_bwd";
|
|
||||||
FactoryKeyCreator key_creator;
|
|
||||||
key_creator.AddAsKey(prefix);
|
|
||||||
key_creator.AddAsKey(bwdParams.src_dims);
|
|
||||||
key_creator.AddAsKey(static_cast<int>(bwdParams.alg_kind));
|
|
||||||
key_creator.AddAsKey(static_cast<float>(bwdParams.alpha));
|
|
||||||
key_creator.AddAsKey(static_cast<float>(bwdParams.beta));
|
|
||||||
key_creator.AddAsKey(static_cast<int>(src_fmt));
|
|
||||||
key_creator.AddAsKey(static_cast<int>(diff_dst_fmt));
|
|
||||||
return key_creator.GetKey();
|
|
||||||
}
|
|
||||||
|
|
||||||
MklPrimitive* GetEltwiseBwd(const MklEltwiseBwdParams<T>& bwdParams,
|
|
||||||
const memory::format& src_fmt,
|
|
||||||
const memory::format& diff_dst_fmt) {
|
|
||||||
string key = CreateKey(bwdParams, src_fmt, diff_dst_fmt);
|
|
||||||
return this->GetOp(key);
|
|
||||||
}
|
|
||||||
|
|
||||||
void SetEltwiseBwd(const MklEltwiseBwdParams<T>& bwdParams,
|
|
||||||
const memory::format& src_fmt,
|
|
||||||
const memory::format& diff_dst_fmt, MklPrimitive* op) {
|
|
||||||
string key = CreateKey(bwdParams, src_fmt, diff_dst_fmt);
|
|
||||||
this->SetOp(key, op);
|
|
||||||
}
|
|
||||||
#endif // ENABLE_MKLDNN_V1
|
|
||||||
};
|
};
|
||||||
|
|
||||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||||
|
@ -77,6 +77,8 @@ namespace tensorflow {
|
|||||||
memory::desc({dims}, MklDnnType<type>(), fm)
|
memory::desc({dims}, MklDnnType<type>(), fm)
|
||||||
#define MEMORY_PD_WITHOUT_DATA(md, engine) md, engine
|
#define MEMORY_PD_WITHOUT_DATA(md, engine) md, engine
|
||||||
#define MEMORY_PRIMITIVE_DESC memory::desc
|
#define MEMORY_PRIMITIVE_DESC memory::desc
|
||||||
|
#define MEMORY_PD_CONSTRUCTOR_2_PARAMS(md, engine) \
|
||||||
|
MEMORY_PRIMITIVE_DESC(md)
|
||||||
#define MKL_FMT_TAG mkl_fmt_tag
|
#define MKL_FMT_TAG mkl_fmt_tag
|
||||||
#define MKL_TENSOR_FORMAT MklTensorFormat
|
#define MKL_TENSOR_FORMAT MklTensorFormat
|
||||||
#define MKL_TENSOR_FORMAT_BLOCKED MklTensorFormat::FORMAT_BLOCKED
|
#define MKL_TENSOR_FORMAT_BLOCKED MklTensorFormat::FORMAT_BLOCKED
|
||||||
@ -169,6 +171,8 @@ namespace tensorflow {
|
|||||||
memory::primitive_desc(GET_MEMORY_DESC_CONSTRUCTOR(dims, type, fm), engine)
|
memory::primitive_desc(GET_MEMORY_DESC_CONSTRUCTOR(dims, type, fm), engine)
|
||||||
#define MEMORY_PD_WITHOUT_DATA(pd, engine) pd
|
#define MEMORY_PD_WITHOUT_DATA(pd, engine) pd
|
||||||
#define MEMORY_PRIMITIVE_DESC memory::primitive_desc
|
#define MEMORY_PRIMITIVE_DESC memory::primitive_desc
|
||||||
|
#define MEMORY_PD_CONSTRUCTOR_2_PARAMS(md, engine) \
|
||||||
|
MEMORY_PRIMITIVE_DESC(md, engine)
|
||||||
#define MKL_FMT_TAG tf_fmt
|
#define MKL_FMT_TAG tf_fmt
|
||||||
#define MKL_TENSOR_FORMAT memory::format
|
#define MKL_TENSOR_FORMAT memory::format
|
||||||
#define MKL_TENSOR_FORMAT_BLOCKED memory::format::blocked
|
#define MKL_TENSOR_FORMAT_BLOCKED memory::format::blocked
|
||||||
|
Loading…
Reference in New Issue
Block a user