Merge pull request #36469 from Intel-tensorflow:mabuzain/relu_1.x
PiperOrigin-RevId: 294344071
Change-Id: I739ce33c0f7617e53bbc95d68f883e422e497a41
Commit: 61d19090c8
@@ -654,6 +654,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
mkl_op_registry::GetMklOpName(csinfo_.quantize_v2),
CopyAttrsAll, QuantizeOpRewrite,
kRewriteForLayoutPropagation});
#endif // !ENABLE_MKLDNN_V1
rinfo_.push_back({csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu),
CopyAttrsAll, AlwaysRewrite,
kRewriteForLayoutPropagation});
@@ -666,10 +667,10 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
rinfo_.push_back(
{csinfo_.relu6_grad, mkl_op_registry::GetMklOpName(csinfo_.relu6_grad),
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
#ifndef ENABLE_MKLDNN_V1
rinfo_.push_back(
{csinfo_.requantize, mkl_op_registry::GetMklOpName(csinfo_.requantize),
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
#endif // !ENABLE_MKLDNN_V1
// Disable these two MKL operators for now due to some test failures caused
// by these two ops
/*
@@ -682,7 +683,6 @@ rinfo_.push_back({csinfo_.tanh_grad,
CopyAttrsAll, AlwaysRewrite,
kRewriteForLayoutPropagation});
*/
#ifndef ENABLE_MKLDNN_V1
rinfo_.push_back(
{csinfo_.reshape, mkl_op_registry::GetMklOpName(csinfo_.reshape),
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
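For readers unfamiliar with the layout pass: the hunks above only adjust which rewrite registrations are compiled in, keeping ops that are not yet ported to MKL-DNN v1.x (such as requantize) behind the ENABLE_MKLDNN_V1 guard. Below is a minimal, self-contained sketch of that registration pattern; RewriteInfo, the op-name strings, and the always-true predicate are simplified stand-ins, not the real MklLayoutRewritePass types.

// Illustrative only: rewrite candidates are recorded as {TF op name, MKL op
// name, predicate} entries, and unported ops are compiled out by the guard.
#include <functional>
#include <string>
#include <vector>

struct RewriteInfo {
  std::string op_name;              // e.g. "Relu"
  std::string mkl_op_name;          // e.g. "_MklRelu"
  std::function<bool()> predicate;  // decides per node whether to rewrite
};

int main() {
  std::vector<RewriteInfo> rinfo;
  auto always = [] { return true; };

  rinfo.push_back({"Relu", "_MklRelu", always});
  rinfo.push_back({"Relu6Grad", "_MklRelu6Grad", always});
#ifndef ENABLE_MKLDNN_V1
  // Ops with only an MKL-DNN 0.x implementation stay behind the guard.
  rinfo.push_back({"Requantize", "_MklRequantize", always});
#endif
  return static_cast<int>(rinfo.size());  // number of registered rewrites
}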
@@ -16,6 +16,8 @@ limitations under the License.
// See docs in ../ops/nn_ops.cc.
#ifdef INTEL_MKL

#include <unordered_map>

#include "mkldnn.hpp"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
@@ -23,24 +25,24 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/mkl_types.h"
#include "tensorflow/core/util/mkl_util.h"

using mkldnn::algorithm;
using mkldnn::eltwise_bounded_relu;
using mkldnn::eltwise_elu;
using mkldnn::eltwise_forward;
using mkldnn::eltwise_relu;
using mkldnn::eltwise_tanh;
using mkldnn::memory;
using mkldnn::prop_kind;
using mkldnn::stream;

using EltwiseFwdPd = mkldnn::eltwise_forward::primitive_desc;
using EltwiseBwdPd = mkldnn::eltwise_backward::primitive_desc;

namespace tensorflow {

template <typename T>
class MklEltwiseFwdParams {
public:
memory::dims src_dims; // check if this is needed
memory::dims src_dims;
memory::desc src_md;
algorithm alg_kind;
float alpha;
@@ -59,11 +61,12 @@ template <typename T>
class MklEltwiseFwdPrimitive : public MklPrimitive {
public:
explicit MklEltwiseFwdPrimitive(const MklEltwiseFwdParams<T>& fwdParams)
: cpu_engine_(engine::cpu, 0) {
// store expected format
: cpu_engine_(ENGINE_CPU, 0) {
#ifndef ENABLE_MKLDNN_V1
context_.src_fmt =
static_cast<mkldnn::memory::format>(fwdParams.src_md.data.format);
context_.fwd_stream.reset(new stream(stream::kind::eager));
#endif
context_.fwd_stream.reset(new CPU_STREAM(cpu_engine_));

// create eltwise primitive
if (context_.eltwise_fwd == nullptr) {
@@ -80,24 +83,38 @@ class MklEltwiseFwdPrimitive : public MklPrimitive {
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)));
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
context_.fwd_stream->submit(context_.fwd_primitives);

// after execution, set data handle back
#ifdef ENABLE_MKLDNN_V1
DCHECK_EQ(context_.fwd_primitives.size(),
context_.fwd_primitives_args.size());
for (size_t i = 0; i < context_.fwd_primitives.size(); ++i) {
context_.fwd_primitives.at(i).execute(*context_.fwd_stream,
context_.fwd_primitives_args.at(i));
}
#else
context_.fwd_stream->submit(context_.fwd_primitives);
#endif

// After execution, set data handle back.
context_.src_mem->set_data_handle(DummyData);
context_.dst_mem->set_data_handle(DummyData);
}

std::shared_ptr<mkldnn::eltwise_forward::primitive_desc> GetEltwiseFwdPd() {
return context_.fwd_pd;
}
std::shared_ptr<EltwiseFwdPd> GetEltwiseFwdPd() { return context_.fwd_pd; }

#ifndef ENABLE_MKLDNN_V1
// In MKL-DNN v1.x, memory format tags only provide a partial description
// of the memory layout. Hence, these functions are disabled for v1.x.
memory::format GetSrcMemoryFormat() { return context_.src_fmt; }
#endif

private:
// Primitive reuse context for eltwise Fwd ops: Relu, Elu, Tanh
struct EltwiseFwdContext {
// expected memory format for this primitive instance
#ifndef ENABLE_MKLDNN_V1
// Expected memory format for this primitive instance
mkldnn::memory::format src_fmt;
#endif

// MKLDNN memory
std::shared_ptr<memory> src_mem;
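The Execute() hunk above captures the central runtime change: MKL-DNN 0.x submits the whole primitive net to a stream, while v1.x executes each primitive with its own argument map, and in both cases the cached memory objects are pointed at the current tensors before execution and reset to the DummyData placeholder afterwards so the primitive can be reused. The following is a small mock sketch of that reuse pattern; Memory, Primitive, Stream, and the integer argument ids are invented stand-ins, not the real mkldnn classes.

// Mock sketch of the cached-primitive execution pattern (not real mkldnn).
#include <cassert>
#include <cstdio>
#include <unordered_map>
#include <vector>

struct Memory {
  void* handle = nullptr;
  void set_data_handle(void* p) { handle = p; }
};

struct Stream {};

struct Primitive {
  // In MKL-DNN v1.x every primitive is executed with a map of
  // {argument id -> memory}; here the id is just an int.
  void execute(Stream&, const std::unordered_map<int, Memory*>& args) const {
    std::printf("executing with %zu args\n", args.size());
  }
};

int main() {
  constexpr int kArgSrc = 0, kArgDst = 1;
  static char dummy;  // stands in for the DummyData placeholder

  // Built once in Setup() and cached.
  Memory src_mem, dst_mem;
  src_mem.set_data_handle(&dummy);
  dst_mem.set_data_handle(&dummy);
  Primitive eltwise_fwd;
  std::vector<Primitive> primitives{eltwise_fwd};
  std::vector<std::unordered_map<int, Memory*>> args;
  args.push_back({{kArgSrc, &src_mem}, {kArgDst, &dst_mem}});

  // Execute(): point the cached memories at this call's tensors ...
  float src_data[4] = {1, 2, 3, 4}, dst_data[4] = {};
  src_mem.set_data_handle(src_data);
  dst_mem.set_data_handle(dst_data);

  Stream stream;
  assert(primitives.size() == args.size());
  for (size_t i = 0; i < primitives.size(); ++i) {
    primitives[i].execute(stream, args[i]);
  }

  // ... then detach them again so the primitive can be reused safely.
  src_mem.set_data_handle(&dummy);
  dst_mem.set_data_handle(&dummy);
  return 0;
}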
@@ -105,14 +122,14 @@ class MklEltwiseFwdPrimitive : public MklPrimitive {

// desc & prmitive desc
std::shared_ptr<mkldnn::eltwise_forward::desc> fwd_desc;
std::shared_ptr<mkldnn::eltwise_forward::primitive_desc> fwd_pd;
std::shared_ptr<EltwiseFwdPd> fwd_pd;

// memory desc
std::shared_ptr<memory::desc> src_md;
std::shared_ptr<memory::desc> dst_md;

// memory primitive desc
std::shared_ptr<memory::primitive_desc> src_mpd;
std::shared_ptr<MEMORY_PRIMITIVE_DESC> src_mpd;

// Eltwise primitive
std::shared_ptr<mkldnn::primitive> eltwise_fwd;
@@ -120,8 +137,15 @@ class MklEltwiseFwdPrimitive : public MklPrimitive {
std::shared_ptr<stream> fwd_stream;
std::vector<mkldnn::primitive> fwd_primitives;

#ifdef ENABLE_MKLDNN_V1
std::vector<std::unordered_map<int, memory>> fwd_primitives_args;
#endif

EltwiseFwdContext()
: src_fmt(memory::format::any),
:
#ifndef ENABLE_MKLDNN_V1
src_fmt(memory::format::any),
#endif
src_mem(nullptr),
dst_mem(nullptr),
fwd_desc(nullptr),
@@ -130,31 +154,43 @@ class MklEltwiseFwdPrimitive : public MklPrimitive {
dst_md(nullptr),
src_mpd(nullptr),
eltwise_fwd(nullptr),
fwd_stream(nullptr) {}
fwd_stream(nullptr) {
}
};

// Eltwise forward primitive setup
void Setup(const MklEltwiseFwdParams<T>& fwdParams) {
// create memory descriptors for eltwise data with specified format
context_.src_md.reset(new memory::desc(fwdParams.src_md.data));
context_.src_mpd.reset(
new memory::primitive_desc(*context_.src_md, cpu_engine_));

// create a eltwise
context_.fwd_desc.reset(new mkldnn::eltwise_forward::desc(
context_.src_mpd.reset(
new MEMORY_PD_CONSTRUCTOR_2_PARAMS(*context_.src_md, cpu_engine_));

// Create an eltwise forward descriptor and primitive descriptor
context_.fwd_desc.reset(new eltwise_forward::desc(
prop_kind::forward, fwdParams.alg_kind, *context_.src_md,
fwdParams.alpha, fwdParams.beta));
context_.fwd_pd.reset(new mkldnn::eltwise_forward::primitive_desc(
*context_.fwd_desc, cpu_engine_));
context_.fwd_pd.reset(new EltwiseFwdPd(*context_.fwd_desc, cpu_engine_));
auto fwd_pd = context_.fwd_pd.get();

// create memory primitive based on dummy data
#ifdef ENABLE_MKLDNN_V1
// Create memory primitive based on dummy data
context_.src_mem.reset(new MEMORY_CONSTRUCTOR(fwd_pd->PRIMITIVE_DESC_SRC,
cpu_engine_, DummyData));
context_.dst_mem.reset(new MEMORY_CONSTRUCTOR(fwd_pd->PRIMITIVE_DESC_DST,
cpu_engine_, DummyData));
// Create eltwise primitive and add it to net
context_.eltwise_fwd.reset(new eltwise_forward(*context_.fwd_pd));
context_.fwd_primitives_args.push_back({{MKLDNN_ARG_SRC, *context_.src_mem},
{ MKLDNN_ARG_DST,
*context_.dst_mem }});
#else
context_.src_mem.reset(new memory(*context_.src_mpd, DummyData));
context_.dst_mem.reset(
new memory(context_.fwd_pd.get()->dst_primitive_desc(), DummyData));

// create eltwise primitive and add it to net
context_.eltwise_fwd.reset(new mkldnn::eltwise_forward(
context_.eltwise_fwd.reset(new eltwise_forward(
*context_.fwd_pd, *context_.src_mem, *context_.dst_mem));
#endif

context_.fwd_primitives.push_back(*context_.eltwise_fwd);
}
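Setup() above is where primitive construction diverges between the two generations: the v0.x branch binds src and dst memory into the eltwise_forward constructor, while the v1.x branch constructs the primitive from the primitive descriptor alone and records an MKLDNN_ARG_SRC/MKLDNN_ARG_DST argument map for later execution. A hedged sketch of just that structural difference, with placeholder Memory/PrimitiveDesc/Primitive types standing in for the mkldnn ones:

// Simplified sketch (mock types, not mkldnn) of keeping one Setup() body for
// both generations: v1.x builds the primitive from the primitive descriptor
// and stores an argument map, v0.x builds it with the memories already bound.
#include <memory>
#include <unordered_map>
#include <vector>

struct Memory {};
struct PrimitiveDesc {};

struct Primitive {
  explicit Primitive(const PrimitiveDesc&) {}                       // v1.x style
  Primitive(const PrimitiveDesc&, const Memory&, const Memory&) {}  // v0.x style
};

struct FwdContext {
  std::shared_ptr<Memory> src_mem, dst_mem;
  std::shared_ptr<PrimitiveDesc> fwd_pd;
  std::shared_ptr<Primitive> eltwise_fwd;
  std::vector<Primitive> fwd_primitives;
#ifdef ENABLE_MKLDNN_V1
  std::vector<std::unordered_map<int, Memory>> fwd_primitives_args;
#endif
};

void Setup(FwdContext& ctx) {
  ctx.fwd_pd = std::make_shared<PrimitiveDesc>();
  ctx.src_mem = std::make_shared<Memory>();
  ctx.dst_mem = std::make_shared<Memory>();
#ifdef ENABLE_MKLDNN_V1
  ctx.eltwise_fwd = std::make_shared<Primitive>(*ctx.fwd_pd);
  ctx.fwd_primitives_args.push_back(
      {{/*MKLDNN_ARG_SRC*/ 0, *ctx.src_mem},
       {/*MKLDNN_ARG_DST*/ 1, *ctx.dst_mem}});
#else
  ctx.eltwise_fwd =
      std::make_shared<Primitive>(*ctx.fwd_pd, *ctx.src_mem, *ctx.dst_mem);
#endif
  ctx.fwd_primitives.push_back(*ctx.eltwise_fwd);
}

int main() {
  FwdContext ctx;
  Setup(ctx);
  return 0;
}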
@@ -170,18 +206,16 @@ class MklEltwiseFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
const MklEltwiseFwdParams<T>& fwdParams) {
MklEltwiseFwdPrimitive<T>* eltwise_forward = nullptr;

auto src_fmt =
static_cast<mkldnn::memory::format>(fwdParams.src_md.data.format);

// Get a eltwise fwd primitive from the cached pool
eltwise_forward = static_cast<MklEltwiseFwdPrimitive<T>*>(
MklEltwiseFwdPrimitiveFactory<T>::GetInstance().GetEltwiseFwd(fwdParams,
src_fmt));
MklEltwiseFwdPrimitiveFactory<T>::GetInstance().GetEltwiseFwd(
fwdParams));
if (eltwise_forward == nullptr) {
eltwise_forward = new MklEltwiseFwdPrimitive<T>(fwdParams);
MklEltwiseFwdPrimitiveFactory<T>::GetInstance().SetEltwiseFwd(
fwdParams, src_fmt, eltwise_forward);
fwdParams, eltwise_forward);
}

return eltwise_forward;
}

@@ -194,8 +228,7 @@ class MklEltwiseFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
MklEltwiseFwdPrimitiveFactory() {}
~MklEltwiseFwdPrimitiveFactory() {}

static string CreateKey(const MklEltwiseFwdParams<T>& fwdParams,
memory::format src_fmt) {
static string CreateKey(const MklEltwiseFwdParams<T>& fwdParams) {
string prefix = "eltwise_fwd";
FactoryKeyCreator key_creator;
key_creator.AddAsKey(prefix);
@@ -203,19 +236,20 @@ class MklEltwiseFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
key_creator.AddAsKey<int>(static_cast<int>(fwdParams.alg_kind));
key_creator.AddAsKey<float>(static_cast<float>(fwdParams.alpha));
key_creator.AddAsKey<float>(static_cast<float>(fwdParams.beta));
key_creator.AddAsKey<int>(static_cast<int>(src_fmt));
#ifndef ENABLE_MKLDNN_V1
key_creator.AddAsKey<int>(static_cast<int>(fwdParams.src_md.data.format));
#endif // !ENABLE_MKLDNN_V1
return key_creator.GetKey();
}

MklPrimitive* GetEltwiseFwd(const MklEltwiseFwdParams<T>& fwdParams,
memory::format src_fmt) {
string key = CreateKey(fwdParams, src_fmt);
MklPrimitive* GetEltwiseFwd(const MklEltwiseFwdParams<T>& fwdParams) {
string key = CreateKey(fwdParams);
return this->GetOp(key);
}

void SetEltwiseFwd(const MklEltwiseFwdParams<T>& fwdParams,
memory::format src_fmt, MklPrimitive* op) {
string key = CreateKey(fwdParams, src_fmt);
MklPrimitive* op) {
string key = CreateKey(fwdParams);
this->SetOp(key, op);
}
};
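The factory hunks above all trim the same thing: the raw memory-format integer is appended to the cache key only in the MKL-DNN 0.x build (the #ifndef ENABLE_MKLDNN_V1 branch), so the v1.x build keys cached primitives purely on the eltwise parameters. A self-contained sketch of this kind of string-keyed primitive cache; Params, Primitive, and PrimitiveFactory are hypothetical simplifications of MklEltwiseFwdParams and MklEltwiseFwdPrimitiveFactory, not the real classes:

#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>

struct Params {
  int alg_kind;
  float alpha, beta;
  int src_format;  // only meaningful for MKL-DNN 0.x
};

struct Primitive {};

class PrimitiveFactory {
 public:
  static PrimitiveFactory& GetInstance() {
    static PrimitiveFactory instance;
    return instance;
  }

  std::shared_ptr<Primitive> GetOrCreate(const Params& p) {
    const std::string key = CreateKey(p);
    auto it = cache_.find(key);
    if (it != cache_.end()) return it->second;  // reuse the cached primitive
    auto prim = std::make_shared<Primitive>();  // otherwise build it once
    cache_.emplace(key, prim);
    return prim;
  }

 private:
  static std::string CreateKey(const Params& p) {
    std::ostringstream key;
    key << "eltwise_fwd:" << p.alg_kind << ':' << p.alpha << ':' << p.beta;
#ifndef ENABLE_MKLDNN_V1
    key << ':' << p.src_format;  // format is part of the identity only in 0.x
#endif
    return key.str();
  }

  std::unordered_map<std::string, std::shared_ptr<Primitive>> cache_;
};

int main() {
  Params p{/*alg_kind=*/1, /*alpha=*/0.f, /*beta=*/0.f, /*src_format=*/0};
  auto a = PrimitiveFactory::GetInstance().GetOrCreate(p);
  auto b = PrimitiveFactory::GetInstance().GetOrCreate(p);
  return a == b ? 0 : 1;  // same key -> same cached primitive
}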
@@ -243,12 +277,14 @@ template <typename T>
class MklEltwiseBwdPrimitive : public MklPrimitive {
public:
explicit MklEltwiseBwdPrimitive(const MklEltwiseBwdParams<T>& bwdParams)
: cpu_engine_(engine::cpu, 0) {
: cpu_engine_(ENGINE_CPU, 0) {
#ifndef ENABLE_MKLDNN_V1
context_.src_fmt =
static_cast<mkldnn::memory::format>(bwdParams.common_md.data.format);
context_.diff_dst_fmt =
static_cast<mkldnn::memory::format>(bwdParams.common_md.data.format);
context_.bwd_stream.reset(new stream(stream::kind::eager));
#endif
context_.bwd_stream.reset(new stream(CPU_STREAM(cpu_engine_)));
// create eltwise primitive
if (context_.eltwise_bwd == nullptr) {
Setup(bwdParams);
@@ -267,7 +303,17 @@ class MklEltwiseBwdPrimitive : public MklPrimitive {
context_.diff_dst_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(diff_dst_data)));
context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data));

#ifdef ENABLE_MKLDNN_V1
DCHECK_EQ(context_.bwd_primitives.size(),
context_.bwd_primitives_args.size());
for (size_t i = 0; i < context_.bwd_primitives.size(); ++i) {
context_.bwd_primitives.at(i).execute(*context_.bwd_stream,
context_.bwd_primitives_args.at(i));
}
#else
context_.bwd_stream->submit(context_.bwd_primitives);
#endif // ENABLE_MKLDNN_V1

// after execution, set data handle back
context_.src_mem->set_data_handle(DummyData);
@@ -275,52 +321,61 @@ class MklEltwiseBwdPrimitive : public MklPrimitive {
context_.diff_src_mem->set_data_handle(DummyData);
}

std::shared_ptr<mkldnn::eltwise_backward::primitive_desc> GetEltwiseBwdPd() {
return context_.bwd_pd;
}
std::shared_ptr<EltwiseBwdPd> GetEltwiseBwdPd() { return context_.bwd_pd; }

#ifndef ENABLE_MKLDNN_V1
memory::format GetSrcMemoryFormat() { return context_.src_fmt; }

memory::format GetDiffDstMemoryFormat() { return context_.diff_dst_fmt; }
#endif // !ENABLE_MKLDNN_V1

private:
// Primitive reuse context for eltwise Bwd ops: Relu, Elu, Tanh
struct EltwiseBwdContext {
// expected memory format for this primitive instance
#ifndef ENABLE_MKLDNN_V1
memory::format src_fmt;
memory::format diff_dst_fmt;
#endif

// MKLDNN memory
std::shared_ptr<memory> src_mem;
std::shared_ptr<memory> diff_dst_mem;
std::shared_ptr<memory> diff_src_mem;

// desc & prmitive desc
// Backward Eltwise descriptor.
std::shared_ptr<mkldnn::eltwise_backward::desc> bwd_desc;

// memory desc
// Memory descriptors.
std::shared_ptr<memory::desc> src_md;
std::shared_ptr<memory::desc> diff_dst_md;
std::shared_ptr<memory::desc> common_md;

// memory primitive desc
std::shared_ptr<memory::primitive_desc> src_mpd;
std::shared_ptr<memory::primitive_desc> diff_dst_mpd;
// Memory primitive descriptor.
// TODO(gzmkl): for MKL-DNN 1.0, src_mpd is same as src_md
// So it should be removed once MKL-DNN 0.x is cleaned.
std::shared_ptr<MEMORY_PRIMITIVE_DESC> src_mpd;
std::shared_ptr<MEMORY_PRIMITIVE_DESC> diff_dst_mpd;

// fwd primitive desc
// Forward and backward descriptors and primitive descriptors.
std::shared_ptr<mkldnn::eltwise_forward::desc> fwd_desc;
std::shared_ptr<mkldnn::eltwise_forward::primitive_desc> fwd_pd;
std::shared_ptr<mkldnn::eltwise_backward::primitive_desc> bwd_pd;
std::shared_ptr<EltwiseFwdPd> fwd_pd;
std::shared_ptr<EltwiseBwdPd> bwd_pd;

// Eltwise primitive
// Eltwise primitive.
std::shared_ptr<mkldnn::primitive> eltwise_bwd;

std::shared_ptr<stream> bwd_stream;
std::vector<mkldnn::primitive> bwd_primitives;

#ifdef ENABLE_MKLDNN_V1
std::vector<MemoryArgsMap> bwd_primitives_args;
#endif // ENABLE_MKLDNN_V1

EltwiseBwdContext()
: src_fmt(memory::format::any),
:
#ifndef ENABLE_MKLDNN_V1
src_fmt(memory::format::any),
diff_dst_fmt(memory::format::any),
#endif // !ENABLE_MKLDNN_V1
src_mem(nullptr),
diff_dst_mem(nullptr),
diff_src_mem(nullptr),
@@ -333,42 +388,58 @@ class MklEltwiseBwdPrimitive : public MklPrimitive {
fwd_pd(nullptr),
bwd_pd(nullptr),
eltwise_bwd(nullptr),
bwd_stream(nullptr) {}
bwd_stream(nullptr) {
}
};

// Eltwise backward primitive setup
void Setup(const MklEltwiseBwdParams<T>& bwdParams) {
// create memory descriptors for eltwise data w/ no specified format
// Create memory descriptors for eltwise data w/ no specified format
context_.src_md.reset(new memory::desc(bwdParams.common_md.data));
context_.diff_dst_md.reset(new memory::desc(bwdParams.common_md.data));

context_.src_mpd.reset(
new memory::primitive_desc(*context_.src_md, cpu_engine_));
new MEMORY_PD_CONSTRUCTOR_2_PARAMS(*context_.src_md, cpu_engine_));
context_.diff_dst_mpd.reset(
new memory::primitive_desc(*context_.diff_dst_md, cpu_engine_));
new MEMORY_PD_CONSTRUCTOR_2_PARAMS(*context_.diff_dst_md, cpu_engine_));

// create forward eltwise primitive
// Create forward eltwise primitive.
context_.fwd_desc.reset(new mkldnn::eltwise_forward::desc(
prop_kind::forward_training, bwdParams.alg_kind, *context_.src_md,
bwdParams.alpha, bwdParams.beta));
context_.fwd_pd.reset(new mkldnn::eltwise_forward::primitive_desc(
*context_.fwd_desc, cpu_engine_));
context_.fwd_pd.reset(new EltwiseFwdPd(*context_.fwd_desc, cpu_engine_));
context_.bwd_desc.reset(new mkldnn::eltwise_backward::desc(
bwdParams.alg_kind, *context_.diff_dst_md, *context_.src_md,
bwdParams.alpha, bwdParams.beta));
context_.bwd_pd.reset(new mkldnn::eltwise_backward::primitive_desc(
*context_.bwd_desc, cpu_engine_, *context_.fwd_pd));
context_.bwd_pd.reset(
new EltwiseBwdPd(*context_.bwd_desc, cpu_engine_, *context_.fwd_pd));

// create memory primitive based on dummy data
auto bwd_pd = context_.bwd_pd.get();

#ifdef ENABLE_MKLDNN_V1
// Create memory primitive based on dummy data.
context_.src_mem.reset(new MEMORY_CONSTRUCTOR(bwd_pd->PRIMITIVE_DESC_SRC,
cpu_engine_, DummyData));
context_.diff_dst_mem.reset(new MEMORY_CONSTRUCTOR(
bwd_pd->PRIMITIVE_DESC_DIFF_DST, cpu_engine_, DummyData));
context_.diff_src_mem.reset(new MEMORY_CONSTRUCTOR(
bwd_pd->PRIMITIVE_DESC_DIFF_SRC, cpu_engine_, DummyData));
// Create eltwise primitive and add it to net.
context_.eltwise_bwd.reset(new mkldnn::eltwise_backward(*context_.bwd_pd));
context_.bwd_primitives_args.push_back(
{{MKLDNN_ARG_SRC, *context_.src_mem},
{MKLDNN_ARG_DIFF_DST, *context_.diff_dst_mem},
{ MKLDNN_ARG_DIFF_SRC,
*context_.diff_src_mem }});
#else
context_.src_mem.reset(new memory(*context_.src_mpd, DummyData));
context_.diff_dst_mem.reset(new memory(*context_.diff_dst_mpd, DummyData));
context_.diff_src_mem.reset(new memory(
context_.bwd_pd.get()->diff_src_primitive_desc(), DummyData));

// create eltwise primitive and add it to net
context_.eltwise_bwd.reset(new mkldnn::eltwise_backward(
*context_.bwd_pd, *context_.src_mem, *context_.diff_dst_mem,
*context_.diff_src_mem));
#endif // ENABLE_MKLDNN_V1

context_.bwd_primitives.push_back(*context_.eltwise_bwd);
}
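One detail worth noting in the backward Setup() above: in both branches the backward primitive descriptor is constructed with the forward primitive descriptor as its last argument (EltwiseBwdPd(*bwd_desc, cpu_engine_, *fwd_pd)), which MKL-DNN uses as a hint; that is why the backward context still builds a forward descriptor it never executes. A tiny mock sketch of that dependency, with invented FwdPd/BwdPd types rather than the real constructors:

#include <memory>

struct Engine {};
struct FwdDesc {};
struct BwdDesc {};

// Forward primitive descriptor: built from the forward op descriptor.
struct FwdPd {
  FwdPd(const FwdDesc&, const Engine&) {}
};

// Backward primitive descriptor: additionally takes the forward pd as a hint,
// mirroring EltwiseBwdPd(*bwd_desc, cpu_engine_, *fwd_pd) in the kernel.
struct BwdPd {
  BwdPd(const BwdDesc&, const Engine&, const FwdPd& hint) { (void)hint; }
};

int main() {
  Engine cpu_engine;
  auto fwd_pd = std::make_shared<FwdPd>(FwdDesc{}, cpu_engine);
  auto bwd_pd = std::make_shared<BwdPd>(BwdDesc{}, cpu_engine, *fwd_pd);
  return bwd_pd ? 0 : 1;
}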
@@ -388,20 +459,15 @@ class MklEltwiseBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
const MklEltwiseBwdParams<T>& bwdParams) {
MklEltwiseBwdPrimitive<T>* eltwise_backward = nullptr;

auto src_fmt =
static_cast<mkldnn::memory::format>(bwdParams.common_md.data.format);
auto diff_dst_fmt =
static_cast<mkldnn::memory::format>(bwdParams.common_md.data.format);

// try to find a suitable one in pool
eltwise_backward = static_cast<MklEltwiseBwdPrimitive<T>*>(
MklEltwiseBwdPrimitiveFactory<T>::GetInstance().GetEltwiseBwd(
bwdParams, src_fmt, diff_dst_fmt));
bwdParams));

if (eltwise_backward == nullptr) {
eltwise_backward = new MklEltwiseBwdPrimitive<T>(bwdParams);
MklEltwiseBwdPrimitiveFactory<T>::GetInstance().SetEltwiseBwd(
bwdParams, src_fmt, diff_dst_fmt, eltwise_backward);
bwdParams, eltwise_backward);
}
return eltwise_backward;
}
@@ -412,9 +478,7 @@ class MklEltwiseBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
}

private:
static string CreateKey(const MklEltwiseBwdParams<T>& bwdParams,
const memory::format& src_fmt,
const memory::format& diff_dst_fmt) {
static string CreateKey(const MklEltwiseBwdParams<T>& bwdParams) {
string prefix = "eltwise_bwd";
FactoryKeyCreator key_creator;
key_creator.AddAsKey(prefix);
@@ -422,22 +486,20 @@ class MklEltwiseBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
key_creator.AddAsKey(static_cast<int>(bwdParams.alg_kind));
key_creator.AddAsKey(static_cast<float>(bwdParams.alpha));
key_creator.AddAsKey(static_cast<float>(bwdParams.beta));
key_creator.AddAsKey(static_cast<int>(src_fmt));
key_creator.AddAsKey(static_cast<int>(diff_dst_fmt));
#ifndef ENABLE_MKLDNN_V1
key_creator.AddAsKey(static_cast<int>(bwdParams.common_md.data.format));
#endif // !ENABLE_MKLDNN_V1
return key_creator.GetKey();
}

MklPrimitive* GetEltwiseBwd(const MklEltwiseBwdParams<T>& bwdParams,
const memory::format& src_fmt,
const memory::format& diff_dst_fmt) {
string key = CreateKey(bwdParams, src_fmt, diff_dst_fmt);
MklPrimitive* GetEltwiseBwd(const MklEltwiseBwdParams<T>& bwdParams) {
string key = CreateKey(bwdParams);
return this->GetOp(key);
}

void SetEltwiseBwd(const MklEltwiseBwdParams<T>& bwdParams,
const memory::format& src_fmt,
const memory::format& diff_dst_fmt, MklPrimitive* op) {
string key = CreateKey(bwdParams, src_fmt, diff_dst_fmt);
MklPrimitive* op) {
string key = CreateKey(bwdParams);
this->SetOp(key, op);
}
};
@@ -481,7 +543,7 @@ class MklReluOpBase : public OpKernel {
// Set DNN primitive - src
MklDnnData<T> src(&cpu_engine);
memory::dims src_dims;
memory::desc src_md({}, memory::data_undef, memory::format_undef);
memory::desc src_md({}, MEMORY_DATA_TYPE_UNDEF, MEMORY_FORMAT_UNDEF);
if (dnn_shape_src.IsMklTensor()) {
src_md = dnn_shape_src.GetMklLayout();
src_dims = dnn_shape_src.GetSizesAsMklDnnDims();
@@ -492,31 +554,29 @@ class MklReluOpBase : public OpKernel {
src_md = MklDnnData<T>::CreateBlockedMemDesc(src_dims, src_strides);
}

// get a eltwise fwd from primitive pool
// Try to get an eltwise forward primitive from caching pool
MklEltwiseFwdParams<T> fwdParams(src_dims, src_md, alg_kind, alpha_,
beta_);

MklEltwiseFwdPrimitive<T>* eltwise_fwd =
MklEltwiseFwdPrimitiveFactory<T>::Get(fwdParams);

// prepare for execuation
auto eltwise_fwd_pd = eltwise_fwd->GetEltwiseFwdPd();

// Check if src needs to be reordered
const T* src_data = src_tensor.flat<T>().data();
// check wehther src need to reorder
if (src_md.data.format != eltwise_fwd->GetSrcMemoryFormat()) {
if (IS_SRC_REORDER_NEEDED(src_md, eltwise_fwd_pd, eltwise_fwd)) {
src.SetUsrMem(src_md, &src_tensor);
auto src_target_pd = memory::primitive_desc(
{{src_dims}, MklDnnType<T>(), eltwise_fwd->GetSrcMemoryFormat()},
cpu_engine);
src.CheckReorderToOpMem(src_target_pd);
src.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
eltwise_fwd_pd->PRIMITIVE_DESC_SRC, cpu_engine));
src_data = const_cast<T*>(
reinterpret_cast<T*>(src.GetOpMem().get_data_handle()));
}

// allocate dst tensor, always set it as MKL-DNN layout
std::shared_ptr<mkldnn::eltwise_forward::primitive_desc> eltwise_fwd_pd =
eltwise_fwd->GetEltwiseFwdPd();
// Allocate dst tensor, always set it as MKL-DNN layout
if (dnn_shape_src.IsMklTensor()) {
dnn_shape_dst.SetMklTensor(true);
auto dst_pd = eltwise_fwd_pd->dst_primitive_desc();
auto dst_pd = eltwise_fwd_pd->PRIMITIVE_DESC_DST;
dnn_shape_dst.SetMklLayout(&dst_pd);
dnn_shape_dst.SetElemType(MklDnnType<T>());
dnn_shape_dst.SetTfLayout(dnn_shape_src.GetDimension(),
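The Compute() hunk above swaps an explicit format-field comparison for the IS_SRC_REORDER_NEEDED macro; the intent in both cases is the same, namely to reorder the input only when the layout it arrives in differs from the layout the cached primitive was built for. Sketched below with plain structs in place of the real macros and mkldnn descriptors:

#include <cstdio>

// Stand-in for a memory descriptor: here it is just a layout tag.
struct MemDesc {
  int layout;
  bool operator!=(const MemDesc& other) const { return layout != other.layout; }
};

// Stand-in for a reorder: copy/convert data from the user layout into the
// layout the cached primitive was created for.
void Reorder(const MemDesc& from, const MemDesc& to) {
  std::printf("reordering %d -> %d\n", from.layout, to.layout);
}

int main() {
  MemDesc user_src{/*layout=*/1};      // layout the input tensor arrives in
  MemDesc expected_src{/*layout=*/2};  // layout the primitive expects

  // Equivalent in spirit to IS_SRC_REORDER_NEEDED(...): only pay for a
  // reorder when the two descriptions differ.
  if (user_src != expected_src) {
    Reorder(user_src, expected_src);
  }
  return 0;
}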
@@ -549,8 +609,8 @@ class MklReluOpBase : public OpKernel {
}

private:
engine cpu_engine = engine(engine::cpu, 0);
std::shared_ptr<eltwise_forward::primitive_desc> relu_fwd_pd;
engine cpu_engine = engine(ENGINE_CPU, 0);
std::shared_ptr<EltwiseFwdPd> relu_fwd_pd;

protected:
float alpha_;
@@ -604,8 +664,8 @@ class MklReluGradOpBase : public OpKernel {

// get a eltwise bwd from primitive pool
memory::dims src_dims = {};
memory::desc src_md({}, memory::data_undef, memory::format_undef);
memory::desc diff_dst_md({}, memory::data_undef, memory::format_undef);
memory::desc src_md({}, MEMORY_DATA_TYPE_UNDEF, MEMORY_FORMAT_UNDEF);
memory::desc diff_dst_md({}, MEMORY_DATA_TYPE_UNDEF, MEMORY_FORMAT_UNDEF);
if (!dnn_shape_src.IsMklTensor() && !dnn_shape_diff_dst.IsMklTensor()) {
src_dims = TFShapeToMklDnnDims(src_tensor.shape());
auto src_strides = CalculateTFStrides(src_dims);
@@ -616,18 +676,18 @@ class MklReluGradOpBase : public OpKernel {
src_md = dnn_shape_src.GetMklLayout();
src_dims = dnn_shape_src.GetSizesAsMklDnnDims();

memory::format src_mkl_data_format = dnn_shape_src.GetTfDataFormat();
MKL_TENSOR_FORMAT src_mkl_data_format = dnn_shape_src.GetTfDataFormat();
auto src_tf_data_format =
MklDnnDataFormatToTFDataFormat(src_mkl_data_format);
auto diff_dst_dims = TFShapeToMklDnnDimsInNCHW(diff_dst_tensor.shape(),
src_tf_data_format);
diff_dst_md =
memory::desc(diff_dst_dims, MklDnnType<T>(), src_mkl_data_format);
diff_dst_md = memory::desc(diff_dst_dims, MklDnnType<T>(),
GET_TENSOR_FORMAT(src_mkl_data_format));
} else if (!dnn_shape_src.IsMklTensor() &&
dnn_shape_diff_dst.IsMklTensor()) {
diff_dst_md = dnn_shape_diff_dst.GetMklLayout();

memory::format diff_dst_mkl_data_format =
MKL_TENSOR_FORMAT diff_dst_mkl_data_format =
dnn_shape_diff_dst.GetTfDataFormat();
auto diff_dst_tf_data_format =
MklDnnDataFormatToTFDataFormat(diff_dst_mkl_data_format);
@@ -637,8 +697,8 @@ class MklReluGradOpBase : public OpKernel {
diff_dst_tf_data_format)
: TFShapeToMklDnnDimsInNCDHW(src_tensor.shape(),
diff_dst_tf_data_format);
src_md =
memory::desc(src_dims, MklDnnType<T>(), diff_dst_mkl_data_format);
src_md = memory::desc(src_dims, MklDnnType<T>(),
GET_TENSOR_FORMAT(diff_dst_mkl_data_format));
} else {
src_md = dnn_shape_src.GetMklLayout();
diff_dst_md = dnn_shape_diff_dst.GetMklLayout();
@@ -649,7 +709,7 @@ class MklReluGradOpBase : public OpKernel {
// format. So we set common memory descriptor in MKL format, if any of the
// inputs are in MKL format. Let's get memory descriptor that we will use
// for both the inputs.
memory::desc common_md({}, memory::data_undef, memory::format_undef);
memory::desc common_md({}, MEMORY_DATA_TYPE_UNDEF, MEMORY_FORMAT_UNDEF);
if (dnn_shape_src.IsMklTensor() || dnn_shape_diff_dst.IsMklTensor()) {
common_md = dnn_shape_src.IsMklTensor() ? src_md : diff_dst_md;
} else {
@@ -662,30 +722,32 @@ class MklReluGradOpBase : public OpKernel {
beta_);
MklEltwiseBwdPrimitive<T>* eltwise_bwd =
MklEltwiseBwdPrimitiveFactory<T>::Get(bwdParams);

auto eltwise_bwd_pd = eltwise_bwd->GetEltwiseBwdPd();

// check whether need reorder for src / diff_dst
const T* src_data = src_tensor.flat<T>().data();
if (src_md.data.format != eltwise_bwd->GetSrcMemoryFormat()) {
if (IS_SRC_REORDER_NEEDED(src_md, eltwise_bwd_pd, eltwise_bwd)) {
src.SetUsrMem(src_md, &src_tensor);
src.CheckReorderToOpMem(
eltwise_bwd_pd.get()->diff_src_primitive_desc());
src.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
eltwise_bwd_pd.get()->PRIMITIVE_DESC_DIFF_SRC, cpu_engine));
src_data = const_cast<T*>(
reinterpret_cast<T*>(src.GetOpMem().get_data_handle()));
}

const T* diff_dst_data = diff_dst_tensor.flat<T>().data();
if (diff_dst_md.data.format != eltwise_bwd->GetDiffDstMemoryFormat()) {
if (IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, eltwise_bwd_pd,
eltwise_bwd)) {
diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
diff_dst.CheckReorderToOpMem(
eltwise_bwd_pd.get()->diff_src_primitive_desc());
diff_dst.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
eltwise_bwd_pd.get()->PRIMITIVE_DESC_DIFF_SRC, cpu_engine));
diff_dst_data = const_cast<T*>(
reinterpret_cast<T*>(diff_dst.GetOpMem().get_data_handle()));
}

// allocate diff_src tensor
if (dnn_shape_src.IsMklTensor() || dnn_shape_diff_dst.IsMklTensor()) {
auto diff_src_pd = eltwise_bwd_pd->diff_src_primitive_desc();
auto diff_src_pd = eltwise_bwd_pd->PRIMITIVE_DESC_DIFF_SRC;
dnn_shape_diff_src.SetMklTensor(true);
dnn_shape_diff_src.SetMklLayout(&diff_src_pd);
dnn_shape_diff_src.SetElemType(MklDnnType<T>());
@@ -726,8 +788,8 @@ class MklReluGradOpBase : public OpKernel {
}

private:
engine cpu_engine = engine(engine::cpu, 0);
std::shared_ptr<eltwise_forward::primitive_desc> relu_fwd_pd;
engine cpu_engine = engine(ENGINE_CPU, 0);
std::shared_ptr<EltwiseFwdPd> relu_fwd_pd;

protected:
float alpha_;
@@ -735,12 +797,13 @@ class MklReluGradOpBase : public OpKernel {
};

template <typename Device, typename T>
class MklReluOp : public MklReluOpBase<Device, T, eltwise_relu> {
class MklReluOp : public MklReluOpBase<Device, T, ALGORITHM::eltwise_relu> {
public:
~MklReluOp() {}

explicit MklReluOp(OpKernelConstruction* context)
: MklReluOpBase<Device, T, eltwise_relu>(context, 0.0f, 0.0f) {}
: MklReluOpBase<Device, T, ALGORITHM::eltwise_relu>(context, 0.0f, 0.0f) {
}

virtual void Compute_Scalar(OpKernelContext* context) {
const size_t src_index = 0; // index of src input tensor
@@ -764,12 +827,14 @@ class MklReluOp : public MklReluOpBase<Device, T, eltwise_relu> {
};

template <typename Device, typename T>
class MklReluGradOp : public MklReluGradOpBase<Device, T, eltwise_relu> {
class MklReluGradOp
: public MklReluGradOpBase<Device, T, ALGORITHM::eltwise_relu> {
public:
~MklReluGradOp() {}

explicit MklReluGradOp(OpKernelConstruction* context)
: MklReluGradOpBase<Device, T, eltwise_relu>(context, 0.0f, 0.0f) {}
: MklReluGradOpBase<Device, T, ALGORITHM::eltwise_relu>(context, 0.0f,
0.0f) {}

virtual void Compute_Scalar(OpKernelContext* context) {
const size_t diff_dst_index = 0; // index of diff_dst input tensor
@@ -799,12 +864,12 @@ class MklReluGradOp : public MklReluGradOpBase<Device, T, eltwise_relu> {
};

template <typename Device, typename T>
class MklEluOp : public MklReluOpBase<Device, T, eltwise_elu> {
class MklEluOp : public MklReluOpBase<Device, T, ALGORITHM::eltwise_elu> {
public:
~MklEluOp() {}

explicit MklEluOp(OpKernelConstruction* context)
: MklReluOpBase<Device, T, eltwise_elu>(context, 0.0f, 0.0f) {}
: MklReluOpBase<Device, T, ALGORITHM::eltwise_elu>(context, 0.0f, 0.0f) {}

virtual void Compute_Scalar(OpKernelContext* context) {
const size_t src_index = 0; // index of src input tensor
@@ -832,12 +897,14 @@ class MklEluOp : public MklReluOpBase<Device, T, eltwise_elu> {
};

template <typename Device, typename T>
class MklEluGradOp : public MklReluGradOpBase<Device, T, eltwise_elu> {
class MklEluGradOp
: public MklReluGradOpBase<Device, T, ALGORITHM::eltwise_elu> {
public:
~MklEluGradOp() {}

explicit MklEluGradOp(OpKernelConstruction* context)
: MklReluGradOpBase<Device, T, eltwise_elu>(context, 0.0f, 0.0f) {}
: MklReluGradOpBase<Device, T, ALGORITHM::eltwise_elu>(context, 0.0f,
0.0f) {}

virtual void Compute_Scalar(OpKernelContext* context) {
const size_t diff_dst_index = 0; // index of diff_dst input tensor
@@ -872,12 +939,13 @@ class MklEluGradOp : public MklReluGradOpBase<Device, T, eltwise_elu> {
};

template <typename Device, typename T>
class MklTanhOp : public MklReluOpBase<Device, T, eltwise_tanh> {
class MklTanhOp : public MklReluOpBase<Device, T, ALGORITHM::eltwise_tanh> {
public:
~MklTanhOp() {}

explicit MklTanhOp(OpKernelConstruction* context)
: MklReluOpBase<Device, T, eltwise_tanh>(context, 0.0f, 0.0f) {}
: MklReluOpBase<Device, T, ALGORITHM::eltwise_tanh>(context, 0.0f, 0.0f) {
}

virtual void Compute_Scalar(OpKernelContext* context) {
const size_t src_index = 0; // index of src input tensor
@@ -904,12 +972,14 @@ class MklTanhOp : public MklReluOpBase<Device, T, eltwise_tanh> {
};

template <typename Device, typename T>
class MklTanhGradOp : public MklReluGradOpBase<Device, T, eltwise_tanh> {
class MklTanhGradOp
: public MklReluGradOpBase<Device, T, ALGORITHM::eltwise_tanh> {
public:
~MklTanhGradOp() {}

explicit MklTanhGradOp(OpKernelConstruction* context)
: MklReluGradOpBase<Device, T, eltwise_tanh>(context, 0.0f, 0.0f) {}
: MklReluGradOpBase<Device, T, ALGORITHM::eltwise_tanh>(context, 0.0f,
0.0f) {}

virtual void Compute_Scalar(OpKernelContext* context) {
const size_t diff_dst_index = 0; // index of diff_dst input tensor
@@ -943,12 +1013,13 @@ class MklTanhGradOp : public MklReluGradOpBase<Device, T, eltwise_tanh> {

#define RELU6_UPPER_BOUND 6.0f
template <typename Device, typename T>
class MklRelu6Op : public MklReluOpBase<Device, T, eltwise_bounded_relu> {
class MklRelu6Op
: public MklReluOpBase<Device, T, ALGORITHM::eltwise_bounded_relu> {
public:
~MklRelu6Op() {}

explicit MklRelu6Op(OpKernelConstruction* context)
: MklReluOpBase<Device, T, eltwise_bounded_relu>(
: MklReluOpBase<Device, T, ALGORITHM::eltwise_bounded_relu>(
context, RELU6_UPPER_BOUND, 0.0f) {}

virtual void Compute_Scalar(OpKernelContext* context) {
@@ -973,12 +1044,12 @@ class MklRelu6Op : public MklReluOpBase<Device, T, eltwise_bounded_relu> {

template <typename Device, typename T>
class MklRelu6GradOp
: public MklReluGradOpBase<Device, T, eltwise_bounded_relu> {
: public MklReluGradOpBase<Device, T, ALGORITHM::eltwise_bounded_relu> {
public:
~MklRelu6GradOp() {}

explicit MklRelu6GradOp(OpKernelConstruction* context)
: MklReluGradOpBase<Device, T, eltwise_bounded_relu>(
: MklReluGradOpBase<Device, T, ALGORITHM::eltwise_bounded_relu>(
context, RELU6_UPPER_BOUND, 0.0f) {}

virtual void Compute_Scalar(OpKernelContext* context) {
@@ -1007,12 +1078,13 @@ class MklRelu6GradOp
};

template <typename Device, typename T>
class MklLeakyReluOp : public MklReluOpBase<Device, T, eltwise_relu> {
class MklLeakyReluOp
: public MklReluOpBase<Device, T, ALGORITHM::eltwise_relu> {
public:
~MklLeakyReluOp() {}

explicit MklLeakyReluOp(OpKernelConstruction* context)
: MklReluOpBase<Device, T, eltwise_relu>(context, 0.0f, 0.0f) {
: MklReluOpBase<Device, T, ALGORITHM::eltwise_relu>(context, 0.0f, 0.0f) {
float alpha;
OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha));
OP_REQUIRES(
@@ -1044,12 +1116,14 @@ class MklLeakyReluOp : public MklReluOpBase<Device, T, eltwise_relu> {
};

template <typename Device, typename T>
class MklLeakyReluGradOp : public MklReluGradOpBase<Device, T, eltwise_relu> {
class MklLeakyReluGradOp
: public MklReluGradOpBase<Device, T, ALGORITHM::eltwise_relu> {
public:
~MklLeakyReluGradOp() {}

explicit MklLeakyReluGradOp(OpKernelConstruction* context)
: MklReluGradOpBase<Device, T, eltwise_relu>(context, 0.0f, 0.0f) {
: MklReluGradOpBase<Device, T, ALGORITHM::eltwise_relu>(context, 0.0f,
0.0f) {
float alpha;
OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha));
OP_REQUIRES(
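All of the concrete kernel classes above are thin wrappers that instantiate a shared base with a different eltwise algorithm (and, for Relu6 and LeakyRelu, a non-zero alpha); the PR only respells the algorithm through the ALGORITHM macro, presumably so the same identifier resolves under both library versions. A compact sketch of that one-base, many-algorithms shape, using a hypothetical Alg enum and a scalar Compute() instead of the real mkldnn algorithm type and kernels:

#include <cmath>
#include <cstdio>

enum class Alg { kRelu, kElu, kTanh, kBoundedRelu };

// Shared base: the algorithm and its alpha/beta parameters are the only
// things the derived kernels customize.
template <typename T, Alg alg>
class EltwiseOpBase {
 public:
  EltwiseOpBase(T alpha, T beta) : alpha_(alpha), beta_(beta) {}
  T Compute(T x) const {
    switch (alg) {
      case Alg::kRelu:
        return x > 0 ? x : 0;
      case Alg::kElu:
        return x > 0 ? x : std::expm1(x);  // standard ELU with unit alpha
      case Alg::kTanh:
        return std::tanh(x);
      case Alg::kBoundedRelu:
        return x < 0 ? 0 : (x > alpha_ ? alpha_ : x);
    }
    return x;
  }

 protected:
  T alpha_, beta_;
};

template <typename T>
struct ReluOp : EltwiseOpBase<T, Alg::kRelu> {
  ReluOp() : EltwiseOpBase<T, Alg::kRelu>(0, 0) {}
};

template <typename T>
struct Relu6Op : EltwiseOpBase<T, Alg::kBoundedRelu> {
  Relu6Op() : EltwiseOpBase<T, Alg::kBoundedRelu>(6, 0) {}  // RELU6_UPPER_BOUND
};

int main() {
  ReluOp<float> relu;
  Relu6Op<float> relu6;
  std::printf("relu(-1)=%g relu6(8)=%g\n", relu.Compute(-1.f), relu6.Compute(8.f));
  return 0;
}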
@@ -77,6 +77,7 @@ namespace tensorflow {
memory::desc({dims}, MklDnnType<type>(), fm)
#define MEMORY_PD_WITHOUT_DATA(md, engine) md, engine
#define MEMORY_PRIMITIVE_DESC memory::desc
#define MEMORY_PD_CONSTRUCTOR_2_PARAMS(md, engine) MEMORY_PRIMITIVE_DESC(md)
#define MKL_FMT_TAG mkl_fmt_tag
#define MKL_TENSOR_FORMAT MklTensorFormat
#define MKL_TENSOR_FORMAT_BLOCKED MklTensorFormat::FORMAT_BLOCKED
@@ -170,6 +171,8 @@ namespace tensorflow {
memory::primitive_desc(GET_MEMORY_DESC_CONSTRUCTOR(dims, type, fm), engine)
#define MEMORY_PD_WITHOUT_DATA(pd, engine) pd
#define MEMORY_PRIMITIVE_DESC memory::primitive_desc
#define MEMORY_PD_CONSTRUCTOR_2_PARAMS(md, engine) \
MEMORY_PRIMITIVE_DESC(md, engine)
#define MKL_FMT_TAG tf_fmt
#define MKL_TENSOR_FORMAT memory::format
#define MKL_TENSOR_FORMAT_BLOCKED memory::format::blocked
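These two macro blocks are the other half of the port: the same kernel source compiles against either library generation because names such as MEMORY_PRIMITIVE_DESC, MEMORY_PD_CONSTRUCTOR_2_PARAMS, and MEMORY_PD_WITHOUT_DATA expand to v1.x constructs in one branch and v0.x constructs in the other. A toy demonstration of that technique with simplified stand-in types (not the real header):

#include <cstdio>

// Stand-ins for the two library generations' types.
struct MemoryDesc {};                 // v1.x: the memory descriptor is enough
struct Engine {};
struct MemoryPrimitiveDesc {          // v0.x: needs a desc *and* an engine
  MemoryPrimitiveDesc(const MemoryDesc&, const Engine&) {}
};

#ifdef ENABLE_MKLDNN_V1
// v1.x spellings: the "primitive descriptor" is just the memory descriptor.
#define MEMORY_PRIMITIVE_DESC MemoryDesc
#define MEMORY_PD_CONSTRUCTOR_2_PARAMS(md, engine) MEMORY_PRIMITIVE_DESC(md)
#else
// v0.x spellings: a separate memory::primitive_desc-like type exists.
#define MEMORY_PRIMITIVE_DESC MemoryPrimitiveDesc
#define MEMORY_PD_CONSTRUCTOR_2_PARAMS(md, engine) \
  MEMORY_PRIMITIVE_DESC(md, engine)
#endif

int main() {
  MemoryDesc md;
  Engine cpu_engine;
  // One line of kernel code, two possible expansions.
  MEMORY_PRIMITIVE_DESC pd = MEMORY_PD_CONSTRUCTOR_2_PARAMS(md, cpu_engine);
  (void)pd;
#ifdef ENABLE_MKLDNN_V1
  const char* gen = "v1.x";
#else
  const char* gen = "v0.x";
#endif
  std::printf("compiled against %s API spellings\n", gen);
  return 0;
}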