Merge pull request #36469 from Intel-tensorflow:mabuzain/relu_1.x

PiperOrigin-RevId: 294344071
Change-Id: I739ce33c0f7617e53bbc95d68f883e422e497a41
This commit is contained in:
TensorFlower Gardener 2020-02-10 17:42:22 -08:00
commit 61d19090c8
3 changed files with 222 additions and 145 deletions

View File

@ -654,6 +654,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
mkl_op_registry::GetMklOpName(csinfo_.quantize_v2), mkl_op_registry::GetMklOpName(csinfo_.quantize_v2),
CopyAttrsAll, QuantizeOpRewrite, CopyAttrsAll, QuantizeOpRewrite,
kRewriteForLayoutPropagation}); kRewriteForLayoutPropagation});
#endif // !ENABLE_MKLDNN_V1
rinfo_.push_back({csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu), rinfo_.push_back({csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu),
CopyAttrsAll, AlwaysRewrite, CopyAttrsAll, AlwaysRewrite,
kRewriteForLayoutPropagation}); kRewriteForLayoutPropagation});
@ -666,10 +667,10 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
rinfo_.push_back( rinfo_.push_back(
{csinfo_.relu6_grad, mkl_op_registry::GetMklOpName(csinfo_.relu6_grad), {csinfo_.relu6_grad, mkl_op_registry::GetMklOpName(csinfo_.relu6_grad),
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation}); CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
#ifndef ENABLE_MKLDNN_V1
rinfo_.push_back( rinfo_.push_back(
{csinfo_.requantize, mkl_op_registry::GetMklOpName(csinfo_.requantize), {csinfo_.requantize, mkl_op_registry::GetMklOpName(csinfo_.requantize),
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation}); CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
#endif // !ENABLE_MKLDNN_V1
// Disable these two MKL operators for now due to some test failures caused // Disable these two MKL operators for now due to some test failures caused
// by these two ops // by these two ops
/* /*
@ -682,7 +683,6 @@ rinfo_.push_back({csinfo_.tanh_grad,
CopyAttrsAll, AlwaysRewrite, CopyAttrsAll, AlwaysRewrite,
kRewriteForLayoutPropagation}); kRewriteForLayoutPropagation});
*/ */
#ifndef ENABLE_MKLDNN_V1
rinfo_.push_back( rinfo_.push_back(
{csinfo_.reshape, mkl_op_registry::GetMklOpName(csinfo_.reshape), {csinfo_.reshape, mkl_op_registry::GetMklOpName(csinfo_.reshape),
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation}); CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});

View File

@ -16,6 +16,8 @@ limitations under the License.
// See docs in ../ops/nn_ops.cc. // See docs in ../ops/nn_ops.cc.
#ifdef INTEL_MKL #ifdef INTEL_MKL
#include <unordered_map>
#include "mkldnn.hpp" #include "mkldnn.hpp"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/numeric_op.h"
@ -23,24 +25,24 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/mkl_types.h"
#include "tensorflow/core/util/mkl_util.h" #include "tensorflow/core/util/mkl_util.h"
using mkldnn::algorithm; using mkldnn::algorithm;
using mkldnn::eltwise_bounded_relu;
using mkldnn::eltwise_elu;
using mkldnn::eltwise_forward; using mkldnn::eltwise_forward;
using mkldnn::eltwise_relu;
using mkldnn::eltwise_tanh;
using mkldnn::memory; using mkldnn::memory;
using mkldnn::prop_kind; using mkldnn::prop_kind;
using mkldnn::stream; using mkldnn::stream;
using EltwiseFwdPd = mkldnn::eltwise_forward::primitive_desc;
using EltwiseBwdPd = mkldnn::eltwise_backward::primitive_desc;
namespace tensorflow { namespace tensorflow {
template <typename T> template <typename T>
class MklEltwiseFwdParams { class MklEltwiseFwdParams {
public: public:
memory::dims src_dims; // check if this is needed memory::dims src_dims;
memory::desc src_md; memory::desc src_md;
algorithm alg_kind; algorithm alg_kind;
float alpha; float alpha;
@ -59,11 +61,12 @@ template <typename T>
class MklEltwiseFwdPrimitive : public MklPrimitive { class MklEltwiseFwdPrimitive : public MklPrimitive {
public: public:
explicit MklEltwiseFwdPrimitive(const MklEltwiseFwdParams<T>& fwdParams) explicit MklEltwiseFwdPrimitive(const MklEltwiseFwdParams<T>& fwdParams)
: cpu_engine_(engine::cpu, 0) { : cpu_engine_(ENGINE_CPU, 0) {
// store expected format #ifndef ENABLE_MKLDNN_V1
context_.src_fmt = context_.src_fmt =
static_cast<mkldnn::memory::format>(fwdParams.src_md.data.format); static_cast<mkldnn::memory::format>(fwdParams.src_md.data.format);
context_.fwd_stream.reset(new stream(stream::kind::eager)); #endif
context_.fwd_stream.reset(new CPU_STREAM(cpu_engine_));
// create eltwise primitive // create eltwise primitive
if (context_.eltwise_fwd == nullptr) { if (context_.eltwise_fwd == nullptr) {
@ -80,24 +83,38 @@ class MklEltwiseFwdPrimitive : public MklPrimitive {
context_.src_mem->set_data_handle( context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data))); static_cast<void*>(const_cast<T*>(src_data)));
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data)); context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
context_.fwd_stream->submit(context_.fwd_primitives);
// after execution, set data handle back #ifdef ENABLE_MKLDNN_V1
DCHECK_EQ(context_.fwd_primitives.size(),
context_.fwd_primitives_args.size());
for (size_t i = 0; i < context_.fwd_primitives.size(); ++i) {
context_.fwd_primitives.at(i).execute(*context_.fwd_stream,
context_.fwd_primitives_args.at(i));
}
#else
context_.fwd_stream->submit(context_.fwd_primitives);
#endif
// After execution, set data handle back.
context_.src_mem->set_data_handle(DummyData); context_.src_mem->set_data_handle(DummyData);
context_.dst_mem->set_data_handle(DummyData); context_.dst_mem->set_data_handle(DummyData);
} }
std::shared_ptr<mkldnn::eltwise_forward::primitive_desc> GetEltwiseFwdPd() { std::shared_ptr<EltwiseFwdPd> GetEltwiseFwdPd() { return context_.fwd_pd; }
return context_.fwd_pd;
}
#ifndef ENABLE_MKLDNN_V1
// In MKL-DNN v1.x, memory format tags only provide a partial description
// of the memory layout. Hence, these functions are disabled for v1.x.
memory::format GetSrcMemoryFormat() { return context_.src_fmt; } memory::format GetSrcMemoryFormat() { return context_.src_fmt; }
#endif
private: private:
// Primitive reuse context for eltwise Fwd ops: Relu, Elu, Tanh // Primitive reuse context for eltwise Fwd ops: Relu, Elu, Tanh
struct EltwiseFwdContext { struct EltwiseFwdContext {
// expected memory format for this primitive instance #ifndef ENABLE_MKLDNN_V1
// Expected memory format for this primitive instance
mkldnn::memory::format src_fmt; mkldnn::memory::format src_fmt;
#endif
// MKLDNN memory // MKLDNN memory
std::shared_ptr<memory> src_mem; std::shared_ptr<memory> src_mem;
@ -105,14 +122,14 @@ class MklEltwiseFwdPrimitive : public MklPrimitive {
// desc & prmitive desc // desc & prmitive desc
std::shared_ptr<mkldnn::eltwise_forward::desc> fwd_desc; std::shared_ptr<mkldnn::eltwise_forward::desc> fwd_desc;
std::shared_ptr<mkldnn::eltwise_forward::primitive_desc> fwd_pd; std::shared_ptr<EltwiseFwdPd> fwd_pd;
// memory desc // memory desc
std::shared_ptr<memory::desc> src_md; std::shared_ptr<memory::desc> src_md;
std::shared_ptr<memory::desc> dst_md; std::shared_ptr<memory::desc> dst_md;
// memory primitive desc // memory primitive desc
std::shared_ptr<memory::primitive_desc> src_mpd; std::shared_ptr<MEMORY_PRIMITIVE_DESC> src_mpd;
// Eltwise primitive // Eltwise primitive
std::shared_ptr<mkldnn::primitive> eltwise_fwd; std::shared_ptr<mkldnn::primitive> eltwise_fwd;
@ -120,8 +137,15 @@ class MklEltwiseFwdPrimitive : public MklPrimitive {
std::shared_ptr<stream> fwd_stream; std::shared_ptr<stream> fwd_stream;
std::vector<mkldnn::primitive> fwd_primitives; std::vector<mkldnn::primitive> fwd_primitives;
#ifdef ENABLE_MKLDNN_V1
std::vector<std::unordered_map<int, memory>> fwd_primitives_args;
#endif
EltwiseFwdContext() EltwiseFwdContext()
: src_fmt(memory::format::any), :
#ifndef ENABLE_MKLDNN_V1
src_fmt(memory::format::any),
#endif
src_mem(nullptr), src_mem(nullptr),
dst_mem(nullptr), dst_mem(nullptr),
fwd_desc(nullptr), fwd_desc(nullptr),
@ -130,31 +154,43 @@ class MklEltwiseFwdPrimitive : public MklPrimitive {
dst_md(nullptr), dst_md(nullptr),
src_mpd(nullptr), src_mpd(nullptr),
eltwise_fwd(nullptr), eltwise_fwd(nullptr),
fwd_stream(nullptr) {} fwd_stream(nullptr) {
}
}; };
// Eltwise forward primitive setup // Eltwise forward primitive setup
void Setup(const MklEltwiseFwdParams<T>& fwdParams) { void Setup(const MklEltwiseFwdParams<T>& fwdParams) {
// create memory descriptors for eltwise data with specified format // create memory descriptors for eltwise data with specified format
context_.src_md.reset(new memory::desc(fwdParams.src_md.data)); context_.src_md.reset(new memory::desc(fwdParams.src_md.data));
context_.src_mpd.reset(
new memory::primitive_desc(*context_.src_md, cpu_engine_));
// create a eltwise context_.src_mpd.reset(
context_.fwd_desc.reset(new mkldnn::eltwise_forward::desc( new MEMORY_PD_CONSTRUCTOR_2_PARAMS(*context_.src_md, cpu_engine_));
// Create an eltwise forward descriptor and primitive descriptor
context_.fwd_desc.reset(new eltwise_forward::desc(
prop_kind::forward, fwdParams.alg_kind, *context_.src_md, prop_kind::forward, fwdParams.alg_kind, *context_.src_md,
fwdParams.alpha, fwdParams.beta)); fwdParams.alpha, fwdParams.beta));
context_.fwd_pd.reset(new mkldnn::eltwise_forward::primitive_desc( context_.fwd_pd.reset(new EltwiseFwdPd(*context_.fwd_desc, cpu_engine_));
*context_.fwd_desc, cpu_engine_)); auto fwd_pd = context_.fwd_pd.get();
// create memory primitive based on dummy data #ifdef ENABLE_MKLDNN_V1
// Create memory primitive based on dummy data
context_.src_mem.reset(new MEMORY_CONSTRUCTOR(fwd_pd->PRIMITIVE_DESC_SRC,
cpu_engine_, DummyData));
context_.dst_mem.reset(new MEMORY_CONSTRUCTOR(fwd_pd->PRIMITIVE_DESC_DST,
cpu_engine_, DummyData));
// Create eltwise primitive and add it to net
context_.eltwise_fwd.reset(new eltwise_forward(*context_.fwd_pd));
context_.fwd_primitives_args.push_back({{MKLDNN_ARG_SRC, *context_.src_mem},
{ MKLDNN_ARG_DST,
*context_.dst_mem }});
#else
context_.src_mem.reset(new memory(*context_.src_mpd, DummyData)); context_.src_mem.reset(new memory(*context_.src_mpd, DummyData));
context_.dst_mem.reset( context_.dst_mem.reset(
new memory(context_.fwd_pd.get()->dst_primitive_desc(), DummyData)); new memory(context_.fwd_pd.get()->dst_primitive_desc(), DummyData));
context_.eltwise_fwd.reset(new eltwise_forward(
// create eltwise primitive and add it to net
context_.eltwise_fwd.reset(new mkldnn::eltwise_forward(
*context_.fwd_pd, *context_.src_mem, *context_.dst_mem)); *context_.fwd_pd, *context_.src_mem, *context_.dst_mem));
#endif
context_.fwd_primitives.push_back(*context_.eltwise_fwd); context_.fwd_primitives.push_back(*context_.eltwise_fwd);
} }
@ -170,18 +206,16 @@ class MklEltwiseFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
const MklEltwiseFwdParams<T>& fwdParams) { const MklEltwiseFwdParams<T>& fwdParams) {
MklEltwiseFwdPrimitive<T>* eltwise_forward = nullptr; MklEltwiseFwdPrimitive<T>* eltwise_forward = nullptr;
auto src_fmt =
static_cast<mkldnn::memory::format>(fwdParams.src_md.data.format);
// Get a eltwise fwd primitive from the cached pool // Get a eltwise fwd primitive from the cached pool
eltwise_forward = static_cast<MklEltwiseFwdPrimitive<T>*>( eltwise_forward = static_cast<MklEltwiseFwdPrimitive<T>*>(
MklEltwiseFwdPrimitiveFactory<T>::GetInstance().GetEltwiseFwd(fwdParams, MklEltwiseFwdPrimitiveFactory<T>::GetInstance().GetEltwiseFwd(
src_fmt)); fwdParams));
if (eltwise_forward == nullptr) { if (eltwise_forward == nullptr) {
eltwise_forward = new MklEltwiseFwdPrimitive<T>(fwdParams); eltwise_forward = new MklEltwiseFwdPrimitive<T>(fwdParams);
MklEltwiseFwdPrimitiveFactory<T>::GetInstance().SetEltwiseFwd( MklEltwiseFwdPrimitiveFactory<T>::GetInstance().SetEltwiseFwd(
fwdParams, src_fmt, eltwise_forward); fwdParams, eltwise_forward);
} }
return eltwise_forward; return eltwise_forward;
} }
@ -194,8 +228,7 @@ class MklEltwiseFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
MklEltwiseFwdPrimitiveFactory() {} MklEltwiseFwdPrimitiveFactory() {}
~MklEltwiseFwdPrimitiveFactory() {} ~MklEltwiseFwdPrimitiveFactory() {}
static string CreateKey(const MklEltwiseFwdParams<T>& fwdParams, static string CreateKey(const MklEltwiseFwdParams<T>& fwdParams) {
memory::format src_fmt) {
string prefix = "eltwise_fwd"; string prefix = "eltwise_fwd";
FactoryKeyCreator key_creator; FactoryKeyCreator key_creator;
key_creator.AddAsKey(prefix); key_creator.AddAsKey(prefix);
@ -203,19 +236,20 @@ class MklEltwiseFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
key_creator.AddAsKey<int>(static_cast<int>(fwdParams.alg_kind)); key_creator.AddAsKey<int>(static_cast<int>(fwdParams.alg_kind));
key_creator.AddAsKey<float>(static_cast<float>(fwdParams.alpha)); key_creator.AddAsKey<float>(static_cast<float>(fwdParams.alpha));
key_creator.AddAsKey<float>(static_cast<float>(fwdParams.beta)); key_creator.AddAsKey<float>(static_cast<float>(fwdParams.beta));
key_creator.AddAsKey<int>(static_cast<int>(src_fmt)); #ifndef ENABLE_MKLDNN_V1
key_creator.AddAsKey<int>(static_cast<int>(fwdParams.src_md.data.format));
#endif // !ENABLE_MKLDNN_V1
return key_creator.GetKey(); return key_creator.GetKey();
} }
MklPrimitive* GetEltwiseFwd(const MklEltwiseFwdParams<T>& fwdParams, MklPrimitive* GetEltwiseFwd(const MklEltwiseFwdParams<T>& fwdParams) {
memory::format src_fmt) { string key = CreateKey(fwdParams);
string key = CreateKey(fwdParams, src_fmt);
return this->GetOp(key); return this->GetOp(key);
} }
void SetEltwiseFwd(const MklEltwiseFwdParams<T>& fwdParams, void SetEltwiseFwd(const MklEltwiseFwdParams<T>& fwdParams,
memory::format src_fmt, MklPrimitive* op) { MklPrimitive* op) {
string key = CreateKey(fwdParams, src_fmt); string key = CreateKey(fwdParams);
this->SetOp(key, op); this->SetOp(key, op);
} }
}; };
@ -243,12 +277,14 @@ template <typename T>
class MklEltwiseBwdPrimitive : public MklPrimitive { class MklEltwiseBwdPrimitive : public MklPrimitive {
public: public:
explicit MklEltwiseBwdPrimitive(const MklEltwiseBwdParams<T>& bwdParams) explicit MklEltwiseBwdPrimitive(const MklEltwiseBwdParams<T>& bwdParams)
: cpu_engine_(engine::cpu, 0) { : cpu_engine_(ENGINE_CPU, 0) {
#ifndef ENABLE_MKLDNN_V1
context_.src_fmt = context_.src_fmt =
static_cast<mkldnn::memory::format>(bwdParams.common_md.data.format); static_cast<mkldnn::memory::format>(bwdParams.common_md.data.format);
context_.diff_dst_fmt = context_.diff_dst_fmt =
static_cast<mkldnn::memory::format>(bwdParams.common_md.data.format); static_cast<mkldnn::memory::format>(bwdParams.common_md.data.format);
context_.bwd_stream.reset(new stream(stream::kind::eager)); #endif
context_.bwd_stream.reset(new stream(CPU_STREAM(cpu_engine_)));
// create eltwise primitive // create eltwise primitive
if (context_.eltwise_bwd == nullptr) { if (context_.eltwise_bwd == nullptr) {
Setup(bwdParams); Setup(bwdParams);
@ -267,7 +303,17 @@ class MklEltwiseBwdPrimitive : public MklPrimitive {
context_.diff_dst_mem->set_data_handle( context_.diff_dst_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(diff_dst_data))); static_cast<void*>(const_cast<T*>(diff_dst_data)));
context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data)); context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data));
#ifdef ENABLE_MKLDNN_V1
DCHECK_EQ(context_.bwd_primitives.size(),
context_.bwd_primitives_args.size());
for (size_t i = 0; i < context_.bwd_primitives.size(); ++i) {
context_.bwd_primitives.at(i).execute(*context_.bwd_stream,
context_.bwd_primitives_args.at(i));
}
#else
context_.bwd_stream->submit(context_.bwd_primitives); context_.bwd_stream->submit(context_.bwd_primitives);
#endif // ENABLE_MKLDNN_V1
// after execution, set data handle back // after execution, set data handle back
context_.src_mem->set_data_handle(DummyData); context_.src_mem->set_data_handle(DummyData);
@ -275,52 +321,61 @@ class MklEltwiseBwdPrimitive : public MklPrimitive {
context_.diff_src_mem->set_data_handle(DummyData); context_.diff_src_mem->set_data_handle(DummyData);
} }
std::shared_ptr<mkldnn::eltwise_backward::primitive_desc> GetEltwiseBwdPd() { std::shared_ptr<EltwiseBwdPd> GetEltwiseBwdPd() { return context_.bwd_pd; }
return context_.bwd_pd;
}
#ifndef ENABLE_MKLDNN_V1
memory::format GetSrcMemoryFormat() { return context_.src_fmt; } memory::format GetSrcMemoryFormat() { return context_.src_fmt; }
memory::format GetDiffDstMemoryFormat() { return context_.diff_dst_fmt; } memory::format GetDiffDstMemoryFormat() { return context_.diff_dst_fmt; }
#endif // !ENABLE_MKLDNN_V1
private: private:
// Primitive reuse context for eltwise Bwd ops: Relu, Elu, Tanh // Primitive reuse context for eltwise Bwd ops: Relu, Elu, Tanh
struct EltwiseBwdContext { struct EltwiseBwdContext {
// expected memory format for this primitive instance #ifndef ENABLE_MKLDNN_V1
memory::format src_fmt; memory::format src_fmt;
memory::format diff_dst_fmt; memory::format diff_dst_fmt;
#endif
// MKLDNN memory // MKLDNN memory
std::shared_ptr<memory> src_mem; std::shared_ptr<memory> src_mem;
std::shared_ptr<memory> diff_dst_mem; std::shared_ptr<memory> diff_dst_mem;
std::shared_ptr<memory> diff_src_mem; std::shared_ptr<memory> diff_src_mem;
// desc & prmitive desc // Backward Eltwise descriptor.
std::shared_ptr<mkldnn::eltwise_backward::desc> bwd_desc; std::shared_ptr<mkldnn::eltwise_backward::desc> bwd_desc;
// memory desc // Memory descriptors.
std::shared_ptr<memory::desc> src_md; std::shared_ptr<memory::desc> src_md;
std::shared_ptr<memory::desc> diff_dst_md; std::shared_ptr<memory::desc> diff_dst_md;
std::shared_ptr<memory::desc> common_md; std::shared_ptr<memory::desc> common_md;
// memory primitive desc // Memory primitive descriptor.
std::shared_ptr<memory::primitive_desc> src_mpd; // TODO(gzmkl): for MKL-DNN 1.0, src_mpd is same as src_md
std::shared_ptr<memory::primitive_desc> diff_dst_mpd; // So it should be removed once MKL-DNN 0.x is cleaned.
std::shared_ptr<MEMORY_PRIMITIVE_DESC> src_mpd;
std::shared_ptr<MEMORY_PRIMITIVE_DESC> diff_dst_mpd;
// fwd primitive desc // Forward and backward descriptors and primitive descriptors.
std::shared_ptr<mkldnn::eltwise_forward::desc> fwd_desc; std::shared_ptr<mkldnn::eltwise_forward::desc> fwd_desc;
std::shared_ptr<mkldnn::eltwise_forward::primitive_desc> fwd_pd; std::shared_ptr<EltwiseFwdPd> fwd_pd;
std::shared_ptr<mkldnn::eltwise_backward::primitive_desc> bwd_pd; std::shared_ptr<EltwiseBwdPd> bwd_pd;
// Eltwise primitive // Eltwise primitive.
std::shared_ptr<mkldnn::primitive> eltwise_bwd; std::shared_ptr<mkldnn::primitive> eltwise_bwd;
std::shared_ptr<stream> bwd_stream; std::shared_ptr<stream> bwd_stream;
std::vector<mkldnn::primitive> bwd_primitives; std::vector<mkldnn::primitive> bwd_primitives;
#ifdef ENABLE_MKLDNN_V1
std::vector<MemoryArgsMap> bwd_primitives_args;
#endif // ENABLE_MKLDNN_V1
EltwiseBwdContext() EltwiseBwdContext()
: src_fmt(memory::format::any), :
#ifndef ENABLE_MKLDNN_V1
src_fmt(memory::format::any),
diff_dst_fmt(memory::format::any), diff_dst_fmt(memory::format::any),
#endif // !ENABLE_MKLDNN_V1
src_mem(nullptr), src_mem(nullptr),
diff_dst_mem(nullptr), diff_dst_mem(nullptr),
diff_src_mem(nullptr), diff_src_mem(nullptr),
@ -333,42 +388,58 @@ class MklEltwiseBwdPrimitive : public MklPrimitive {
fwd_pd(nullptr), fwd_pd(nullptr),
bwd_pd(nullptr), bwd_pd(nullptr),
eltwise_bwd(nullptr), eltwise_bwd(nullptr),
bwd_stream(nullptr) {} bwd_stream(nullptr) {
}
}; };
// Eltwise backward primitive setup // Eltwise backward primitive setup
void Setup(const MklEltwiseBwdParams<T>& bwdParams) { void Setup(const MklEltwiseBwdParams<T>& bwdParams) {
// create memory descriptors for eltwise data w/ no specified format // Create memory descriptors for eltwise data w/ no specified format
context_.src_md.reset(new memory::desc(bwdParams.common_md.data)); context_.src_md.reset(new memory::desc(bwdParams.common_md.data));
context_.diff_dst_md.reset(new memory::desc(bwdParams.common_md.data)); context_.diff_dst_md.reset(new memory::desc(bwdParams.common_md.data));
context_.src_mpd.reset( context_.src_mpd.reset(
new memory::primitive_desc(*context_.src_md, cpu_engine_)); new MEMORY_PD_CONSTRUCTOR_2_PARAMS(*context_.src_md, cpu_engine_));
context_.diff_dst_mpd.reset( context_.diff_dst_mpd.reset(
new memory::primitive_desc(*context_.diff_dst_md, cpu_engine_)); new MEMORY_PD_CONSTRUCTOR_2_PARAMS(*context_.diff_dst_md, cpu_engine_));
// create forward eltwise primitive // Create forward eltwise primitive.
context_.fwd_desc.reset(new mkldnn::eltwise_forward::desc( context_.fwd_desc.reset(new mkldnn::eltwise_forward::desc(
prop_kind::forward_training, bwdParams.alg_kind, *context_.src_md, prop_kind::forward_training, bwdParams.alg_kind, *context_.src_md,
bwdParams.alpha, bwdParams.beta)); bwdParams.alpha, bwdParams.beta));
context_.fwd_pd.reset(new mkldnn::eltwise_forward::primitive_desc( context_.fwd_pd.reset(new EltwiseFwdPd(*context_.fwd_desc, cpu_engine_));
*context_.fwd_desc, cpu_engine_));
context_.bwd_desc.reset(new mkldnn::eltwise_backward::desc( context_.bwd_desc.reset(new mkldnn::eltwise_backward::desc(
bwdParams.alg_kind, *context_.diff_dst_md, *context_.src_md, bwdParams.alg_kind, *context_.diff_dst_md, *context_.src_md,
bwdParams.alpha, bwdParams.beta)); bwdParams.alpha, bwdParams.beta));
context_.bwd_pd.reset(new mkldnn::eltwise_backward::primitive_desc( context_.bwd_pd.reset(
*context_.bwd_desc, cpu_engine_, *context_.fwd_pd)); new EltwiseBwdPd(*context_.bwd_desc, cpu_engine_, *context_.fwd_pd));
// create memory primitive based on dummy data auto bwd_pd = context_.bwd_pd.get();
#ifdef ENABLE_MKLDNN_V1
// Create memory primitive based on dummy data.
context_.src_mem.reset(new MEMORY_CONSTRUCTOR(bwd_pd->PRIMITIVE_DESC_SRC,
cpu_engine_, DummyData));
context_.diff_dst_mem.reset(new MEMORY_CONSTRUCTOR(
bwd_pd->PRIMITIVE_DESC_DIFF_DST, cpu_engine_, DummyData));
context_.diff_src_mem.reset(new MEMORY_CONSTRUCTOR(
bwd_pd->PRIMITIVE_DESC_DIFF_SRC, cpu_engine_, DummyData));
// Create eltwise primitive and add it to net.
context_.eltwise_bwd.reset(new mkldnn::eltwise_backward(*context_.bwd_pd));
context_.bwd_primitives_args.push_back(
{{MKLDNN_ARG_SRC, *context_.src_mem},
{MKLDNN_ARG_DIFF_DST, *context_.diff_dst_mem},
{ MKLDNN_ARG_DIFF_SRC,
*context_.diff_src_mem }});
#else
context_.src_mem.reset(new memory(*context_.src_mpd, DummyData)); context_.src_mem.reset(new memory(*context_.src_mpd, DummyData));
context_.diff_dst_mem.reset(new memory(*context_.diff_dst_mpd, DummyData)); context_.diff_dst_mem.reset(new memory(*context_.diff_dst_mpd, DummyData));
context_.diff_src_mem.reset(new memory( context_.diff_src_mem.reset(new memory(
context_.bwd_pd.get()->diff_src_primitive_desc(), DummyData)); context_.bwd_pd.get()->diff_src_primitive_desc(), DummyData));
// create eltwise primitive and add it to net
context_.eltwise_bwd.reset(new mkldnn::eltwise_backward( context_.eltwise_bwd.reset(new mkldnn::eltwise_backward(
*context_.bwd_pd, *context_.src_mem, *context_.diff_dst_mem, *context_.bwd_pd, *context_.src_mem, *context_.diff_dst_mem,
*context_.diff_src_mem)); *context_.diff_src_mem));
#endif // ENABLE_MKLDNN_V1
context_.bwd_primitives.push_back(*context_.eltwise_bwd); context_.bwd_primitives.push_back(*context_.eltwise_bwd);
} }
@ -388,20 +459,15 @@ class MklEltwiseBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
const MklEltwiseBwdParams<T>& bwdParams) { const MklEltwiseBwdParams<T>& bwdParams) {
MklEltwiseBwdPrimitive<T>* eltwise_backward = nullptr; MklEltwiseBwdPrimitive<T>* eltwise_backward = nullptr;
auto src_fmt =
static_cast<mkldnn::memory::format>(bwdParams.common_md.data.format);
auto diff_dst_fmt =
static_cast<mkldnn::memory::format>(bwdParams.common_md.data.format);
// try to find a suitable one in pool // try to find a suitable one in pool
eltwise_backward = static_cast<MklEltwiseBwdPrimitive<T>*>( eltwise_backward = static_cast<MklEltwiseBwdPrimitive<T>*>(
MklEltwiseBwdPrimitiveFactory<T>::GetInstance().GetEltwiseBwd( MklEltwiseBwdPrimitiveFactory<T>::GetInstance().GetEltwiseBwd(
bwdParams, src_fmt, diff_dst_fmt)); bwdParams));
if (eltwise_backward == nullptr) { if (eltwise_backward == nullptr) {
eltwise_backward = new MklEltwiseBwdPrimitive<T>(bwdParams); eltwise_backward = new MklEltwiseBwdPrimitive<T>(bwdParams);
MklEltwiseBwdPrimitiveFactory<T>::GetInstance().SetEltwiseBwd( MklEltwiseBwdPrimitiveFactory<T>::GetInstance().SetEltwiseBwd(
bwdParams, src_fmt, diff_dst_fmt, eltwise_backward); bwdParams, eltwise_backward);
} }
return eltwise_backward; return eltwise_backward;
} }
@ -412,9 +478,7 @@ class MklEltwiseBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
} }
private: private:
static string CreateKey(const MklEltwiseBwdParams<T>& bwdParams, static string CreateKey(const MklEltwiseBwdParams<T>& bwdParams) {
const memory::format& src_fmt,
const memory::format& diff_dst_fmt) {
string prefix = "eltwise_bwd"; string prefix = "eltwise_bwd";
FactoryKeyCreator key_creator; FactoryKeyCreator key_creator;
key_creator.AddAsKey(prefix); key_creator.AddAsKey(prefix);
@ -422,22 +486,20 @@ class MklEltwiseBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
key_creator.AddAsKey(static_cast<int>(bwdParams.alg_kind)); key_creator.AddAsKey(static_cast<int>(bwdParams.alg_kind));
key_creator.AddAsKey(static_cast<float>(bwdParams.alpha)); key_creator.AddAsKey(static_cast<float>(bwdParams.alpha));
key_creator.AddAsKey(static_cast<float>(bwdParams.beta)); key_creator.AddAsKey(static_cast<float>(bwdParams.beta));
key_creator.AddAsKey(static_cast<int>(src_fmt)); #ifndef ENABLE_MKLDNN_V1
key_creator.AddAsKey(static_cast<int>(diff_dst_fmt)); key_creator.AddAsKey(static_cast<int>(bwdParams.common_md.data.format));
#endif // !ENABLE_MKLDNN_V1
return key_creator.GetKey(); return key_creator.GetKey();
} }
MklPrimitive* GetEltwiseBwd(const MklEltwiseBwdParams<T>& bwdParams, MklPrimitive* GetEltwiseBwd(const MklEltwiseBwdParams<T>& bwdParams) {
const memory::format& src_fmt, string key = CreateKey(bwdParams);
const memory::format& diff_dst_fmt) {
string key = CreateKey(bwdParams, src_fmt, diff_dst_fmt);
return this->GetOp(key); return this->GetOp(key);
} }
void SetEltwiseBwd(const MklEltwiseBwdParams<T>& bwdParams, void SetEltwiseBwd(const MklEltwiseBwdParams<T>& bwdParams,
const memory::format& src_fmt, MklPrimitive* op) {
const memory::format& diff_dst_fmt, MklPrimitive* op) { string key = CreateKey(bwdParams);
string key = CreateKey(bwdParams, src_fmt, diff_dst_fmt);
this->SetOp(key, op); this->SetOp(key, op);
} }
}; };
@ -481,7 +543,7 @@ class MklReluOpBase : public OpKernel {
// Set DNN primitive - src // Set DNN primitive - src
MklDnnData<T> src(&cpu_engine); MklDnnData<T> src(&cpu_engine);
memory::dims src_dims; memory::dims src_dims;
memory::desc src_md({}, memory::data_undef, memory::format_undef); memory::desc src_md({}, MEMORY_DATA_TYPE_UNDEF, MEMORY_FORMAT_UNDEF);
if (dnn_shape_src.IsMklTensor()) { if (dnn_shape_src.IsMklTensor()) {
src_md = dnn_shape_src.GetMklLayout(); src_md = dnn_shape_src.GetMklLayout();
src_dims = dnn_shape_src.GetSizesAsMklDnnDims(); src_dims = dnn_shape_src.GetSizesAsMklDnnDims();
@ -492,31 +554,29 @@ class MklReluOpBase : public OpKernel {
src_md = MklDnnData<T>::CreateBlockedMemDesc(src_dims, src_strides); src_md = MklDnnData<T>::CreateBlockedMemDesc(src_dims, src_strides);
} }
// get a eltwise fwd from primitive pool // Try to get an eltwise forward primitive from caching pool
MklEltwiseFwdParams<T> fwdParams(src_dims, src_md, alg_kind, alpha_, MklEltwiseFwdParams<T> fwdParams(src_dims, src_md, alg_kind, alpha_,
beta_); beta_);
MklEltwiseFwdPrimitive<T>* eltwise_fwd = MklEltwiseFwdPrimitive<T>* eltwise_fwd =
MklEltwiseFwdPrimitiveFactory<T>::Get(fwdParams); MklEltwiseFwdPrimitiveFactory<T>::Get(fwdParams);
// prepare for execuation auto eltwise_fwd_pd = eltwise_fwd->GetEltwiseFwdPd();
// Check if src needs to be reordered
const T* src_data = src_tensor.flat<T>().data(); const T* src_data = src_tensor.flat<T>().data();
// check wehther src need to reorder if (IS_SRC_REORDER_NEEDED(src_md, eltwise_fwd_pd, eltwise_fwd)) {
if (src_md.data.format != eltwise_fwd->GetSrcMemoryFormat()) {
src.SetUsrMem(src_md, &src_tensor); src.SetUsrMem(src_md, &src_tensor);
auto src_target_pd = memory::primitive_desc( src.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
{{src_dims}, MklDnnType<T>(), eltwise_fwd->GetSrcMemoryFormat()}, eltwise_fwd_pd->PRIMITIVE_DESC_SRC, cpu_engine));
cpu_engine);
src.CheckReorderToOpMem(src_target_pd);
src_data = const_cast<T*>( src_data = const_cast<T*>(
reinterpret_cast<T*>(src.GetOpMem().get_data_handle())); reinterpret_cast<T*>(src.GetOpMem().get_data_handle()));
} }
// allocate dst tensor, always set it as MKL-DNN layout // Allocate dst tensor, always set it as MKL-DNN layout
std::shared_ptr<mkldnn::eltwise_forward::primitive_desc> eltwise_fwd_pd =
eltwise_fwd->GetEltwiseFwdPd();
if (dnn_shape_src.IsMklTensor()) { if (dnn_shape_src.IsMklTensor()) {
dnn_shape_dst.SetMklTensor(true); dnn_shape_dst.SetMklTensor(true);
auto dst_pd = eltwise_fwd_pd->dst_primitive_desc(); auto dst_pd = eltwise_fwd_pd->PRIMITIVE_DESC_DST;
dnn_shape_dst.SetMklLayout(&dst_pd); dnn_shape_dst.SetMklLayout(&dst_pd);
dnn_shape_dst.SetElemType(MklDnnType<T>()); dnn_shape_dst.SetElemType(MklDnnType<T>());
dnn_shape_dst.SetTfLayout(dnn_shape_src.GetDimension(), dnn_shape_dst.SetTfLayout(dnn_shape_src.GetDimension(),
@ -549,8 +609,8 @@ class MklReluOpBase : public OpKernel {
} }
private: private:
engine cpu_engine = engine(engine::cpu, 0); engine cpu_engine = engine(ENGINE_CPU, 0);
std::shared_ptr<eltwise_forward::primitive_desc> relu_fwd_pd; std::shared_ptr<EltwiseFwdPd> relu_fwd_pd;
protected: protected:
float alpha_; float alpha_;
@ -604,8 +664,8 @@ class MklReluGradOpBase : public OpKernel {
// get a eltwise bwd from primitive pool // get a eltwise bwd from primitive pool
memory::dims src_dims = {}; memory::dims src_dims = {};
memory::desc src_md({}, memory::data_undef, memory::format_undef); memory::desc src_md({}, MEMORY_DATA_TYPE_UNDEF, MEMORY_FORMAT_UNDEF);
memory::desc diff_dst_md({}, memory::data_undef, memory::format_undef); memory::desc diff_dst_md({}, MEMORY_DATA_TYPE_UNDEF, MEMORY_FORMAT_UNDEF);
if (!dnn_shape_src.IsMklTensor() && !dnn_shape_diff_dst.IsMklTensor()) { if (!dnn_shape_src.IsMklTensor() && !dnn_shape_diff_dst.IsMklTensor()) {
src_dims = TFShapeToMklDnnDims(src_tensor.shape()); src_dims = TFShapeToMklDnnDims(src_tensor.shape());
auto src_strides = CalculateTFStrides(src_dims); auto src_strides = CalculateTFStrides(src_dims);
@ -616,18 +676,18 @@ class MklReluGradOpBase : public OpKernel {
src_md = dnn_shape_src.GetMklLayout(); src_md = dnn_shape_src.GetMklLayout();
src_dims = dnn_shape_src.GetSizesAsMklDnnDims(); src_dims = dnn_shape_src.GetSizesAsMklDnnDims();
memory::format src_mkl_data_format = dnn_shape_src.GetTfDataFormat(); MKL_TENSOR_FORMAT src_mkl_data_format = dnn_shape_src.GetTfDataFormat();
auto src_tf_data_format = auto src_tf_data_format =
MklDnnDataFormatToTFDataFormat(src_mkl_data_format); MklDnnDataFormatToTFDataFormat(src_mkl_data_format);
auto diff_dst_dims = TFShapeToMklDnnDimsInNCHW(diff_dst_tensor.shape(), auto diff_dst_dims = TFShapeToMklDnnDimsInNCHW(diff_dst_tensor.shape(),
src_tf_data_format); src_tf_data_format);
diff_dst_md = diff_dst_md = memory::desc(diff_dst_dims, MklDnnType<T>(),
memory::desc(diff_dst_dims, MklDnnType<T>(), src_mkl_data_format); GET_TENSOR_FORMAT(src_mkl_data_format));
} else if (!dnn_shape_src.IsMklTensor() && } else if (!dnn_shape_src.IsMklTensor() &&
dnn_shape_diff_dst.IsMklTensor()) { dnn_shape_diff_dst.IsMklTensor()) {
diff_dst_md = dnn_shape_diff_dst.GetMklLayout(); diff_dst_md = dnn_shape_diff_dst.GetMklLayout();
memory::format diff_dst_mkl_data_format = MKL_TENSOR_FORMAT diff_dst_mkl_data_format =
dnn_shape_diff_dst.GetTfDataFormat(); dnn_shape_diff_dst.GetTfDataFormat();
auto diff_dst_tf_data_format = auto diff_dst_tf_data_format =
MklDnnDataFormatToTFDataFormat(diff_dst_mkl_data_format); MklDnnDataFormatToTFDataFormat(diff_dst_mkl_data_format);
@ -637,8 +697,8 @@ class MklReluGradOpBase : public OpKernel {
diff_dst_tf_data_format) diff_dst_tf_data_format)
: TFShapeToMklDnnDimsInNCDHW(src_tensor.shape(), : TFShapeToMklDnnDimsInNCDHW(src_tensor.shape(),
diff_dst_tf_data_format); diff_dst_tf_data_format);
src_md = src_md = memory::desc(src_dims, MklDnnType<T>(),
memory::desc(src_dims, MklDnnType<T>(), diff_dst_mkl_data_format); GET_TENSOR_FORMAT(diff_dst_mkl_data_format));
} else { } else {
src_md = dnn_shape_src.GetMklLayout(); src_md = dnn_shape_src.GetMklLayout();
diff_dst_md = dnn_shape_diff_dst.GetMklLayout(); diff_dst_md = dnn_shape_diff_dst.GetMklLayout();
@ -649,7 +709,7 @@ class MklReluGradOpBase : public OpKernel {
// format. So we set common memory descriptor in MKL format, if any of the // format. So we set common memory descriptor in MKL format, if any of the
// inputs are in MKL format. Let's get memory descriptor that we will use // inputs are in MKL format. Let's get memory descriptor that we will use
// for both the inputs. // for both the inputs.
memory::desc common_md({}, memory::data_undef, memory::format_undef); memory::desc common_md({}, MEMORY_DATA_TYPE_UNDEF, MEMORY_FORMAT_UNDEF);
if (dnn_shape_src.IsMklTensor() || dnn_shape_diff_dst.IsMklTensor()) { if (dnn_shape_src.IsMklTensor() || dnn_shape_diff_dst.IsMklTensor()) {
common_md = dnn_shape_src.IsMklTensor() ? src_md : diff_dst_md; common_md = dnn_shape_src.IsMklTensor() ? src_md : diff_dst_md;
} else { } else {
@ -662,30 +722,32 @@ class MklReluGradOpBase : public OpKernel {
beta_); beta_);
MklEltwiseBwdPrimitive<T>* eltwise_bwd = MklEltwiseBwdPrimitive<T>* eltwise_bwd =
MklEltwiseBwdPrimitiveFactory<T>::Get(bwdParams); MklEltwiseBwdPrimitiveFactory<T>::Get(bwdParams);
auto eltwise_bwd_pd = eltwise_bwd->GetEltwiseBwdPd(); auto eltwise_bwd_pd = eltwise_bwd->GetEltwiseBwdPd();
// check whether need reorder for src / diff_dst // check whether need reorder for src / diff_dst
const T* src_data = src_tensor.flat<T>().data(); const T* src_data = src_tensor.flat<T>().data();
if (src_md.data.format != eltwise_bwd->GetSrcMemoryFormat()) { if (IS_SRC_REORDER_NEEDED(src_md, eltwise_bwd_pd, eltwise_bwd)) {
src.SetUsrMem(src_md, &src_tensor); src.SetUsrMem(src_md, &src_tensor);
src.CheckReorderToOpMem( src.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
eltwise_bwd_pd.get()->diff_src_primitive_desc()); eltwise_bwd_pd.get()->PRIMITIVE_DESC_DIFF_SRC, cpu_engine));
src_data = const_cast<T*>( src_data = const_cast<T*>(
reinterpret_cast<T*>(src.GetOpMem().get_data_handle())); reinterpret_cast<T*>(src.GetOpMem().get_data_handle()));
} }
const T* diff_dst_data = diff_dst_tensor.flat<T>().data(); const T* diff_dst_data = diff_dst_tensor.flat<T>().data();
if (diff_dst_md.data.format != eltwise_bwd->GetDiffDstMemoryFormat()) { if (IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, eltwise_bwd_pd,
eltwise_bwd)) {
diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor); diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
diff_dst.CheckReorderToOpMem( diff_dst.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
eltwise_bwd_pd.get()->diff_src_primitive_desc()); eltwise_bwd_pd.get()->PRIMITIVE_DESC_DIFF_SRC, cpu_engine));
diff_dst_data = const_cast<T*>( diff_dst_data = const_cast<T*>(
reinterpret_cast<T*>(diff_dst.GetOpMem().get_data_handle())); reinterpret_cast<T*>(diff_dst.GetOpMem().get_data_handle()));
} }
// allocate diff_src tensor // allocate diff_src tensor
if (dnn_shape_src.IsMklTensor() || dnn_shape_diff_dst.IsMklTensor()) { if (dnn_shape_src.IsMklTensor() || dnn_shape_diff_dst.IsMklTensor()) {
auto diff_src_pd = eltwise_bwd_pd->diff_src_primitive_desc(); auto diff_src_pd = eltwise_bwd_pd->PRIMITIVE_DESC_DIFF_SRC;
dnn_shape_diff_src.SetMklTensor(true); dnn_shape_diff_src.SetMklTensor(true);
dnn_shape_diff_src.SetMklLayout(&diff_src_pd); dnn_shape_diff_src.SetMklLayout(&diff_src_pd);
dnn_shape_diff_src.SetElemType(MklDnnType<T>()); dnn_shape_diff_src.SetElemType(MklDnnType<T>());
@ -726,8 +788,8 @@ class MklReluGradOpBase : public OpKernel {
} }
private: private:
engine cpu_engine = engine(engine::cpu, 0); engine cpu_engine = engine(ENGINE_CPU, 0);
std::shared_ptr<eltwise_forward::primitive_desc> relu_fwd_pd; std::shared_ptr<EltwiseFwdPd> relu_fwd_pd;
protected: protected:
float alpha_; float alpha_;
@ -735,12 +797,13 @@ class MklReluGradOpBase : public OpKernel {
}; };
template <typename Device, typename T> template <typename Device, typename T>
class MklReluOp : public MklReluOpBase<Device, T, eltwise_relu> { class MklReluOp : public MklReluOpBase<Device, T, ALGORITHM::eltwise_relu> {
public: public:
~MklReluOp() {} ~MklReluOp() {}
explicit MklReluOp(OpKernelConstruction* context) explicit MklReluOp(OpKernelConstruction* context)
: MklReluOpBase<Device, T, eltwise_relu>(context, 0.0f, 0.0f) {} : MklReluOpBase<Device, T, ALGORITHM::eltwise_relu>(context, 0.0f, 0.0f) {
}
virtual void Compute_Scalar(OpKernelContext* context) { virtual void Compute_Scalar(OpKernelContext* context) {
const size_t src_index = 0; // index of src input tensor const size_t src_index = 0; // index of src input tensor
@ -764,12 +827,14 @@ class MklReluOp : public MklReluOpBase<Device, T, eltwise_relu> {
}; };
template <typename Device, typename T> template <typename Device, typename T>
class MklReluGradOp : public MklReluGradOpBase<Device, T, eltwise_relu> { class MklReluGradOp
: public MklReluGradOpBase<Device, T, ALGORITHM::eltwise_relu> {
public: public:
~MklReluGradOp() {} ~MklReluGradOp() {}
explicit MklReluGradOp(OpKernelConstruction* context) explicit MklReluGradOp(OpKernelConstruction* context)
: MklReluGradOpBase<Device, T, eltwise_relu>(context, 0.0f, 0.0f) {} : MklReluGradOpBase<Device, T, ALGORITHM::eltwise_relu>(context, 0.0f,
0.0f) {}
virtual void Compute_Scalar(OpKernelContext* context) { virtual void Compute_Scalar(OpKernelContext* context) {
const size_t diff_dst_index = 0; // index of diff_dst input tensor const size_t diff_dst_index = 0; // index of diff_dst input tensor
@ -799,12 +864,12 @@ class MklReluGradOp : public MklReluGradOpBase<Device, T, eltwise_relu> {
}; };
template <typename Device, typename T> template <typename Device, typename T>
class MklEluOp : public MklReluOpBase<Device, T, eltwise_elu> { class MklEluOp : public MklReluOpBase<Device, T, ALGORITHM::eltwise_elu> {
public: public:
~MklEluOp() {} ~MklEluOp() {}
explicit MklEluOp(OpKernelConstruction* context) explicit MklEluOp(OpKernelConstruction* context)
: MklReluOpBase<Device, T, eltwise_elu>(context, 0.0f, 0.0f) {} : MklReluOpBase<Device, T, ALGORITHM::eltwise_elu>(context, 0.0f, 0.0f) {}
virtual void Compute_Scalar(OpKernelContext* context) { virtual void Compute_Scalar(OpKernelContext* context) {
const size_t src_index = 0; // index of src input tensor const size_t src_index = 0; // index of src input tensor
@ -832,12 +897,14 @@ class MklEluOp : public MklReluOpBase<Device, T, eltwise_elu> {
}; };
template <typename Device, typename T> template <typename Device, typename T>
class MklEluGradOp : public MklReluGradOpBase<Device, T, eltwise_elu> { class MklEluGradOp
: public MklReluGradOpBase<Device, T, ALGORITHM::eltwise_elu> {
public: public:
~MklEluGradOp() {} ~MklEluGradOp() {}
explicit MklEluGradOp(OpKernelConstruction* context) explicit MklEluGradOp(OpKernelConstruction* context)
: MklReluGradOpBase<Device, T, eltwise_elu>(context, 0.0f, 0.0f) {} : MklReluGradOpBase<Device, T, ALGORITHM::eltwise_elu>(context, 0.0f,
0.0f) {}
virtual void Compute_Scalar(OpKernelContext* context) { virtual void Compute_Scalar(OpKernelContext* context) {
const size_t diff_dst_index = 0; // index of diff_dst input tensor const size_t diff_dst_index = 0; // index of diff_dst input tensor
@ -872,12 +939,13 @@ class MklEluGradOp : public MklReluGradOpBase<Device, T, eltwise_elu> {
}; };
template <typename Device, typename T> template <typename Device, typename T>
class MklTanhOp : public MklReluOpBase<Device, T, eltwise_tanh> { class MklTanhOp : public MklReluOpBase<Device, T, ALGORITHM::eltwise_tanh> {
public: public:
~MklTanhOp() {} ~MklTanhOp() {}
explicit MklTanhOp(OpKernelConstruction* context) explicit MklTanhOp(OpKernelConstruction* context)
: MklReluOpBase<Device, T, eltwise_tanh>(context, 0.0f, 0.0f) {} : MklReluOpBase<Device, T, ALGORITHM::eltwise_tanh>(context, 0.0f, 0.0f) {
}
virtual void Compute_Scalar(OpKernelContext* context) { virtual void Compute_Scalar(OpKernelContext* context) {
const size_t src_index = 0; // index of src input tensor const size_t src_index = 0; // index of src input tensor
@ -904,12 +972,14 @@ class MklTanhOp : public MklReluOpBase<Device, T, eltwise_tanh> {
}; };
template <typename Device, typename T> template <typename Device, typename T>
class MklTanhGradOp : public MklReluGradOpBase<Device, T, eltwise_tanh> { class MklTanhGradOp
: public MklReluGradOpBase<Device, T, ALGORITHM::eltwise_tanh> {
public: public:
~MklTanhGradOp() {} ~MklTanhGradOp() {}
explicit MklTanhGradOp(OpKernelConstruction* context) explicit MklTanhGradOp(OpKernelConstruction* context)
: MklReluGradOpBase<Device, T, eltwise_tanh>(context, 0.0f, 0.0f) {} : MklReluGradOpBase<Device, T, ALGORITHM::eltwise_tanh>(context, 0.0f,
0.0f) {}
virtual void Compute_Scalar(OpKernelContext* context) { virtual void Compute_Scalar(OpKernelContext* context) {
const size_t diff_dst_index = 0; // index of diff_dst input tensor const size_t diff_dst_index = 0; // index of diff_dst input tensor
@ -943,12 +1013,13 @@ class MklTanhGradOp : public MklReluGradOpBase<Device, T, eltwise_tanh> {
#define RELU6_UPPER_BOUND 6.0f #define RELU6_UPPER_BOUND 6.0f
template <typename Device, typename T> template <typename Device, typename T>
class MklRelu6Op : public MklReluOpBase<Device, T, eltwise_bounded_relu> { class MklRelu6Op
: public MklReluOpBase<Device, T, ALGORITHM::eltwise_bounded_relu> {
public: public:
~MklRelu6Op() {} ~MklRelu6Op() {}
explicit MklRelu6Op(OpKernelConstruction* context) explicit MklRelu6Op(OpKernelConstruction* context)
: MklReluOpBase<Device, T, eltwise_bounded_relu>( : MklReluOpBase<Device, T, ALGORITHM::eltwise_bounded_relu>(
context, RELU6_UPPER_BOUND, 0.0f) {} context, RELU6_UPPER_BOUND, 0.0f) {}
virtual void Compute_Scalar(OpKernelContext* context) { virtual void Compute_Scalar(OpKernelContext* context) {
@ -973,12 +1044,12 @@ class MklRelu6Op : public MklReluOpBase<Device, T, eltwise_bounded_relu> {
template <typename Device, typename T> template <typename Device, typename T>
class MklRelu6GradOp class MklRelu6GradOp
: public MklReluGradOpBase<Device, T, eltwise_bounded_relu> { : public MklReluGradOpBase<Device, T, ALGORITHM::eltwise_bounded_relu> {
public: public:
~MklRelu6GradOp() {} ~MklRelu6GradOp() {}
explicit MklRelu6GradOp(OpKernelConstruction* context) explicit MklRelu6GradOp(OpKernelConstruction* context)
: MklReluGradOpBase<Device, T, eltwise_bounded_relu>( : MklReluGradOpBase<Device, T, ALGORITHM::eltwise_bounded_relu>(
context, RELU6_UPPER_BOUND, 0.0f) {} context, RELU6_UPPER_BOUND, 0.0f) {}
virtual void Compute_Scalar(OpKernelContext* context) { virtual void Compute_Scalar(OpKernelContext* context) {
@ -1007,12 +1078,13 @@ class MklRelu6GradOp
}; };
template <typename Device, typename T> template <typename Device, typename T>
class MklLeakyReluOp : public MklReluOpBase<Device, T, eltwise_relu> { class MklLeakyReluOp
: public MklReluOpBase<Device, T, ALGORITHM::eltwise_relu> {
public: public:
~MklLeakyReluOp() {} ~MklLeakyReluOp() {}
explicit MklLeakyReluOp(OpKernelConstruction* context) explicit MklLeakyReluOp(OpKernelConstruction* context)
: MklReluOpBase<Device, T, eltwise_relu>(context, 0.0f, 0.0f) { : MklReluOpBase<Device, T, ALGORITHM::eltwise_relu>(context, 0.0f, 0.0f) {
float alpha; float alpha;
OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha)); OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha));
OP_REQUIRES( OP_REQUIRES(
@ -1044,12 +1116,14 @@ class MklLeakyReluOp : public MklReluOpBase<Device, T, eltwise_relu> {
}; };
template <typename Device, typename T> template <typename Device, typename T>
class MklLeakyReluGradOp : public MklReluGradOpBase<Device, T, eltwise_relu> { class MklLeakyReluGradOp
: public MklReluGradOpBase<Device, T, ALGORITHM::eltwise_relu> {
public: public:
~MklLeakyReluGradOp() {} ~MklLeakyReluGradOp() {}
explicit MklLeakyReluGradOp(OpKernelConstruction* context) explicit MklLeakyReluGradOp(OpKernelConstruction* context)
: MklReluGradOpBase<Device, T, eltwise_relu>(context, 0.0f, 0.0f) { : MklReluGradOpBase<Device, T, ALGORITHM::eltwise_relu>(context, 0.0f,
0.0f) {
float alpha; float alpha;
OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha)); OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha));
OP_REQUIRES( OP_REQUIRES(

View File

@ -77,6 +77,7 @@ namespace tensorflow {
memory::desc({dims}, MklDnnType<type>(), fm) memory::desc({dims}, MklDnnType<type>(), fm)
#define MEMORY_PD_WITHOUT_DATA(md, engine) md, engine #define MEMORY_PD_WITHOUT_DATA(md, engine) md, engine
#define MEMORY_PRIMITIVE_DESC memory::desc #define MEMORY_PRIMITIVE_DESC memory::desc
#define MEMORY_PD_CONSTRUCTOR_2_PARAMS(md, engine) MEMORY_PRIMITIVE_DESC(md)
#define MKL_FMT_TAG mkl_fmt_tag #define MKL_FMT_TAG mkl_fmt_tag
#define MKL_TENSOR_FORMAT MklTensorFormat #define MKL_TENSOR_FORMAT MklTensorFormat
#define MKL_TENSOR_FORMAT_BLOCKED MklTensorFormat::FORMAT_BLOCKED #define MKL_TENSOR_FORMAT_BLOCKED MklTensorFormat::FORMAT_BLOCKED
@ -170,6 +171,8 @@ namespace tensorflow {
memory::primitive_desc(GET_MEMORY_DESC_CONSTRUCTOR(dims, type, fm), engine) memory::primitive_desc(GET_MEMORY_DESC_CONSTRUCTOR(dims, type, fm), engine)
#define MEMORY_PD_WITHOUT_DATA(pd, engine) pd #define MEMORY_PD_WITHOUT_DATA(pd, engine) pd
#define MEMORY_PRIMITIVE_DESC memory::primitive_desc #define MEMORY_PRIMITIVE_DESC memory::primitive_desc
#define MEMORY_PD_CONSTRUCTOR_2_PARAMS(md, engine) \
MEMORY_PRIMITIVE_DESC(md, engine)
#define MKL_FMT_TAG tf_fmt #define MKL_FMT_TAG tf_fmt
#define MKL_TENSOR_FORMAT memory::format #define MKL_TENSOR_FORMAT memory::format
#define MKL_TENSOR_FORMAT_BLOCKED memory::format::blocked #define MKL_TENSOR_FORMAT_BLOCKED memory::format::blocked