Merge pull request #36469 from Intel-tensorflow:mabuzain/relu_1.x

PiperOrigin-RevId: 294344071
Change-Id: I739ce33c0f7617e53bbc95d68f883e422e497a41
This commit is contained in:
TensorFlower Gardener 2020-02-10 17:42:22 -08:00
commit 61d19090c8
3 changed files with 222 additions and 145 deletions

View File

@ -654,6 +654,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
mkl_op_registry::GetMklOpName(csinfo_.quantize_v2),
CopyAttrsAll, QuantizeOpRewrite,
kRewriteForLayoutPropagation});
#endif // !ENABLE_MKLDNN_V1
rinfo_.push_back({csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu),
CopyAttrsAll, AlwaysRewrite,
kRewriteForLayoutPropagation});
@ -666,10 +667,10 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
rinfo_.push_back(
{csinfo_.relu6_grad, mkl_op_registry::GetMklOpName(csinfo_.relu6_grad),
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
#ifndef ENABLE_MKLDNN_V1
rinfo_.push_back(
{csinfo_.requantize, mkl_op_registry::GetMklOpName(csinfo_.requantize),
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
#endif // !ENABLE_MKLDNN_V1
// Disable these two MKL operators for now due to some test failures caused
// by these two ops
/*
@ -682,7 +683,6 @@ rinfo_.push_back({csinfo_.tanh_grad,
CopyAttrsAll, AlwaysRewrite,
kRewriteForLayoutPropagation});
*/
#ifndef ENABLE_MKLDNN_V1
rinfo_.push_back(
{csinfo_.reshape, mkl_op_registry::GetMklOpName(csinfo_.reshape),
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});

View File

@ -16,6 +16,8 @@ limitations under the License.
// See docs in ../ops/nn_ops.cc.
#ifdef INTEL_MKL
#include <unordered_map>
#include "mkldnn.hpp"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
@ -23,24 +25,24 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/mkl_types.h"
#include "tensorflow/core/util/mkl_util.h"
using mkldnn::algorithm;
using mkldnn::eltwise_bounded_relu;
using mkldnn::eltwise_elu;
using mkldnn::eltwise_forward;
using mkldnn::eltwise_relu;
using mkldnn::eltwise_tanh;
using mkldnn::memory;
using mkldnn::prop_kind;
using mkldnn::stream;
using EltwiseFwdPd = mkldnn::eltwise_forward::primitive_desc;
using EltwiseBwdPd = mkldnn::eltwise_backward::primitive_desc;
namespace tensorflow {
template <typename T>
class MklEltwiseFwdParams {
public:
memory::dims src_dims; // check if this is needed
memory::dims src_dims;
memory::desc src_md;
algorithm alg_kind;
float alpha;
@ -59,11 +61,12 @@ template <typename T>
class MklEltwiseFwdPrimitive : public MklPrimitive {
public:
explicit MklEltwiseFwdPrimitive(const MklEltwiseFwdParams<T>& fwdParams)
: cpu_engine_(engine::cpu, 0) {
// store expected format
: cpu_engine_(ENGINE_CPU, 0) {
#ifndef ENABLE_MKLDNN_V1
context_.src_fmt =
static_cast<mkldnn::memory::format>(fwdParams.src_md.data.format);
context_.fwd_stream.reset(new stream(stream::kind::eager));
#endif
context_.fwd_stream.reset(new CPU_STREAM(cpu_engine_));
// create eltwise primitive
if (context_.eltwise_fwd == nullptr) {
@ -80,24 +83,38 @@ class MklEltwiseFwdPrimitive : public MklPrimitive {
context_.src_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(src_data)));
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
context_.fwd_stream->submit(context_.fwd_primitives);
// after execution, set data handle back
#ifdef ENABLE_MKLDNN_V1
DCHECK_EQ(context_.fwd_primitives.size(),
context_.fwd_primitives_args.size());
for (size_t i = 0; i < context_.fwd_primitives.size(); ++i) {
context_.fwd_primitives.at(i).execute(*context_.fwd_stream,
context_.fwd_primitives_args.at(i));
}
#else
context_.fwd_stream->submit(context_.fwd_primitives);
#endif
// After execution, set data handle back.
context_.src_mem->set_data_handle(DummyData);
context_.dst_mem->set_data_handle(DummyData);
}
std::shared_ptr<mkldnn::eltwise_forward::primitive_desc> GetEltwiseFwdPd() {
return context_.fwd_pd;
}
std::shared_ptr<EltwiseFwdPd> GetEltwiseFwdPd() { return context_.fwd_pd; }
#ifndef ENABLE_MKLDNN_V1
// In MKL-DNN v1.x, memory format tags only provide a partial description
// of the memory layout. Hence, these functions are disabled for v1.x.
memory::format GetSrcMemoryFormat() { return context_.src_fmt; }
#endif
private:
// Primitive reuse context for eltwise Fwd ops: Relu, Elu, Tanh
struct EltwiseFwdContext {
// expected memory format for this primitive instance
#ifndef ENABLE_MKLDNN_V1
// Expected memory format for this primitive instance
mkldnn::memory::format src_fmt;
#endif
// MKLDNN memory
std::shared_ptr<memory> src_mem;
@ -105,14 +122,14 @@ class MklEltwiseFwdPrimitive : public MklPrimitive {
// desc & primitive desc
std::shared_ptr<mkldnn::eltwise_forward::desc> fwd_desc;
std::shared_ptr<mkldnn::eltwise_forward::primitive_desc> fwd_pd;
std::shared_ptr<EltwiseFwdPd> fwd_pd;
// memory desc
std::shared_ptr<memory::desc> src_md;
std::shared_ptr<memory::desc> dst_md;
// memory primitive desc
std::shared_ptr<memory::primitive_desc> src_mpd;
std::shared_ptr<MEMORY_PRIMITIVE_DESC> src_mpd;
// Eltwise primitive
std::shared_ptr<mkldnn::primitive> eltwise_fwd;
@ -120,8 +137,15 @@ class MklEltwiseFwdPrimitive : public MklPrimitive {
std::shared_ptr<stream> fwd_stream;
std::vector<mkldnn::primitive> fwd_primitives;
#ifdef ENABLE_MKLDNN_V1
std::vector<std::unordered_map<int, memory>> fwd_primitives_args;
#endif
EltwiseFwdContext()
: src_fmt(memory::format::any),
:
#ifndef ENABLE_MKLDNN_V1
src_fmt(memory::format::any),
#endif
src_mem(nullptr),
dst_mem(nullptr),
fwd_desc(nullptr),
@ -130,31 +154,43 @@ class MklEltwiseFwdPrimitive : public MklPrimitive {
dst_md(nullptr),
src_mpd(nullptr),
eltwise_fwd(nullptr),
fwd_stream(nullptr) {}
fwd_stream(nullptr) {
}
};
// Eltwise forward primitive setup
void Setup(const MklEltwiseFwdParams<T>& fwdParams) {
// create memory descriptors for eltwise data with specified format
context_.src_md.reset(new memory::desc(fwdParams.src_md.data));
context_.src_mpd.reset(
new memory::primitive_desc(*context_.src_md, cpu_engine_));
// create a eltwise
context_.fwd_desc.reset(new mkldnn::eltwise_forward::desc(
context_.src_mpd.reset(
new MEMORY_PD_CONSTRUCTOR_2_PARAMS(*context_.src_md, cpu_engine_));
// Create an eltwise forward descriptor and primitive descriptor
context_.fwd_desc.reset(new eltwise_forward::desc(
prop_kind::forward, fwdParams.alg_kind, *context_.src_md,
fwdParams.alpha, fwdParams.beta));
context_.fwd_pd.reset(new mkldnn::eltwise_forward::primitive_desc(
*context_.fwd_desc, cpu_engine_));
context_.fwd_pd.reset(new EltwiseFwdPd(*context_.fwd_desc, cpu_engine_));
auto fwd_pd = context_.fwd_pd.get();
// create memory primitive based on dummy data
#ifdef ENABLE_MKLDNN_V1
// Create memory primitive based on dummy data
context_.src_mem.reset(new MEMORY_CONSTRUCTOR(fwd_pd->PRIMITIVE_DESC_SRC,
cpu_engine_, DummyData));
context_.dst_mem.reset(new MEMORY_CONSTRUCTOR(fwd_pd->PRIMITIVE_DESC_DST,
cpu_engine_, DummyData));
// Create eltwise primitive and add it to net
context_.eltwise_fwd.reset(new eltwise_forward(*context_.fwd_pd));
context_.fwd_primitives_args.push_back({{MKLDNN_ARG_SRC, *context_.src_mem},
{ MKLDNN_ARG_DST,
*context_.dst_mem }});
#else
context_.src_mem.reset(new memory(*context_.src_mpd, DummyData));
context_.dst_mem.reset(
new memory(context_.fwd_pd.get()->dst_primitive_desc(), DummyData));
// create eltwise primitive and add it to net
context_.eltwise_fwd.reset(new mkldnn::eltwise_forward(
context_.eltwise_fwd.reset(new eltwise_forward(
*context_.fwd_pd, *context_.src_mem, *context_.dst_mem));
#endif
context_.fwd_primitives.push_back(*context_.eltwise_fwd);
}
@ -170,18 +206,16 @@ class MklEltwiseFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
const MklEltwiseFwdParams<T>& fwdParams) {
MklEltwiseFwdPrimitive<T>* eltwise_forward = nullptr;
auto src_fmt =
static_cast<mkldnn::memory::format>(fwdParams.src_md.data.format);
// Get an eltwise fwd primitive from the cached pool
eltwise_forward = static_cast<MklEltwiseFwdPrimitive<T>*>(
MklEltwiseFwdPrimitiveFactory<T>::GetInstance().GetEltwiseFwd(fwdParams,
src_fmt));
MklEltwiseFwdPrimitiveFactory<T>::GetInstance().GetEltwiseFwd(
fwdParams));
if (eltwise_forward == nullptr) {
eltwise_forward = new MklEltwiseFwdPrimitive<T>(fwdParams);
MklEltwiseFwdPrimitiveFactory<T>::GetInstance().SetEltwiseFwd(
fwdParams, src_fmt, eltwise_forward);
fwdParams, eltwise_forward);
}
return eltwise_forward;
}
@ -194,8 +228,7 @@ class MklEltwiseFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
MklEltwiseFwdPrimitiveFactory() {}
~MklEltwiseFwdPrimitiveFactory() {}
static string CreateKey(const MklEltwiseFwdParams<T>& fwdParams,
memory::format src_fmt) {
static string CreateKey(const MklEltwiseFwdParams<T>& fwdParams) {
string prefix = "eltwise_fwd";
FactoryKeyCreator key_creator;
key_creator.AddAsKey(prefix);
@ -203,19 +236,20 @@ class MklEltwiseFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
key_creator.AddAsKey<int>(static_cast<int>(fwdParams.alg_kind));
key_creator.AddAsKey<float>(static_cast<float>(fwdParams.alpha));
key_creator.AddAsKey<float>(static_cast<float>(fwdParams.beta));
key_creator.AddAsKey<int>(static_cast<int>(src_fmt));
#ifndef ENABLE_MKLDNN_V1
key_creator.AddAsKey<int>(static_cast<int>(fwdParams.src_md.data.format));
#endif // !ENABLE_MKLDNN_V1
return key_creator.GetKey();
}
MklPrimitive* GetEltwiseFwd(const MklEltwiseFwdParams<T>& fwdParams,
memory::format src_fmt) {
string key = CreateKey(fwdParams, src_fmt);
MklPrimitive* GetEltwiseFwd(const MklEltwiseFwdParams<T>& fwdParams) {
string key = CreateKey(fwdParams);
return this->GetOp(key);
}
void SetEltwiseFwd(const MklEltwiseFwdParams<T>& fwdParams,
memory::format src_fmt, MklPrimitive* op) {
string key = CreateKey(fwdParams, src_fmt);
MklPrimitive* op) {
string key = CreateKey(fwdParams);
this->SetOp(key, op);
}
};
@ -243,12 +277,14 @@ template <typename T>
class MklEltwiseBwdPrimitive : public MklPrimitive {
public:
explicit MklEltwiseBwdPrimitive(const MklEltwiseBwdParams<T>& bwdParams)
: cpu_engine_(engine::cpu, 0) {
: cpu_engine_(ENGINE_CPU, 0) {
#ifndef ENABLE_MKLDNN_V1
context_.src_fmt =
static_cast<mkldnn::memory::format>(bwdParams.common_md.data.format);
context_.diff_dst_fmt =
static_cast<mkldnn::memory::format>(bwdParams.common_md.data.format);
context_.bwd_stream.reset(new stream(stream::kind::eager));
#endif
context_.bwd_stream.reset(new stream(CPU_STREAM(cpu_engine_)));
// create eltwise primitive
if (context_.eltwise_bwd == nullptr) {
Setup(bwdParams);
@ -267,7 +303,17 @@ class MklEltwiseBwdPrimitive : public MklPrimitive {
context_.diff_dst_mem->set_data_handle(
static_cast<void*>(const_cast<T*>(diff_dst_data)));
context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data));
#ifdef ENABLE_MKLDNN_V1
DCHECK_EQ(context_.bwd_primitives.size(),
context_.bwd_primitives_args.size());
for (size_t i = 0; i < context_.bwd_primitives.size(); ++i) {
context_.bwd_primitives.at(i).execute(*context_.bwd_stream,
context_.bwd_primitives_args.at(i));
}
#else
context_.bwd_stream->submit(context_.bwd_primitives);
#endif // ENABLE_MKLDNN_V1
// after execution, set data handle back
context_.src_mem->set_data_handle(DummyData);
@ -275,52 +321,61 @@ class MklEltwiseBwdPrimitive : public MklPrimitive {
context_.diff_src_mem->set_data_handle(DummyData);
}
std::shared_ptr<mkldnn::eltwise_backward::primitive_desc> GetEltwiseBwdPd() {
return context_.bwd_pd;
}
std::shared_ptr<EltwiseBwdPd> GetEltwiseBwdPd() { return context_.bwd_pd; }
#ifndef ENABLE_MKLDNN_V1
memory::format GetSrcMemoryFormat() { return context_.src_fmt; }
memory::format GetDiffDstMemoryFormat() { return context_.diff_dst_fmt; }
#endif // !ENABLE_MKLDNN_V1
private:
// Primitive reuse context for eltwise Bwd ops: Relu, Elu, Tanh
struct EltwiseBwdContext {
// expected memory format for this primitive instance
#ifndef ENABLE_MKLDNN_V1
memory::format src_fmt;
memory::format diff_dst_fmt;
#endif
// MKLDNN memory
std::shared_ptr<memory> src_mem;
std::shared_ptr<memory> diff_dst_mem;
std::shared_ptr<memory> diff_src_mem;
// desc & primitive desc
// Backward Eltwise descriptor.
std::shared_ptr<mkldnn::eltwise_backward::desc> bwd_desc;
// memory desc
// Memory descriptors.
std::shared_ptr<memory::desc> src_md;
std::shared_ptr<memory::desc> diff_dst_md;
std::shared_ptr<memory::desc> common_md;
// memory primitive desc
std::shared_ptr<memory::primitive_desc> src_mpd;
std::shared_ptr<memory::primitive_desc> diff_dst_mpd;
// Memory primitive descriptor.
// TODO(gzmkl): for MKL-DNN 1.0, src_mpd is same as src_md
// So it should be removed once MKL-DNN 0.x is cleaned.
std::shared_ptr<MEMORY_PRIMITIVE_DESC> src_mpd;
std::shared_ptr<MEMORY_PRIMITIVE_DESC> diff_dst_mpd;
// fwd primitive desc
// Forward and backward descriptors and primitive descriptors.
std::shared_ptr<mkldnn::eltwise_forward::desc> fwd_desc;
std::shared_ptr<mkldnn::eltwise_forward::primitive_desc> fwd_pd;
std::shared_ptr<mkldnn::eltwise_backward::primitive_desc> bwd_pd;
std::shared_ptr<EltwiseFwdPd> fwd_pd;
std::shared_ptr<EltwiseBwdPd> bwd_pd;
// Eltwise primitive
// Eltwise primitive.
std::shared_ptr<mkldnn::primitive> eltwise_bwd;
std::shared_ptr<stream> bwd_stream;
std::vector<mkldnn::primitive> bwd_primitives;
#ifdef ENABLE_MKLDNN_V1
std::vector<MemoryArgsMap> bwd_primitives_args;
#endif // ENABLE_MKLDNN_V1
EltwiseBwdContext()
: src_fmt(memory::format::any),
:
#ifndef ENABLE_MKLDNN_V1
src_fmt(memory::format::any),
diff_dst_fmt(memory::format::any),
#endif // !ENABLE_MKLDNN_V1
src_mem(nullptr),
diff_dst_mem(nullptr),
diff_src_mem(nullptr),
@ -333,42 +388,58 @@ class MklEltwiseBwdPrimitive : public MklPrimitive {
fwd_pd(nullptr),
bwd_pd(nullptr),
eltwise_bwd(nullptr),
bwd_stream(nullptr) {}
bwd_stream(nullptr) {
}
};
// Eltwise backward primitive setup
void Setup(const MklEltwiseBwdParams<T>& bwdParams) {
// create memory descriptors for eltwise data w/ no specified format
// Create memory descriptors for eltwise data w/ no specified format
context_.src_md.reset(new memory::desc(bwdParams.common_md.data));
context_.diff_dst_md.reset(new memory::desc(bwdParams.common_md.data));
context_.src_mpd.reset(
new memory::primitive_desc(*context_.src_md, cpu_engine_));
new MEMORY_PD_CONSTRUCTOR_2_PARAMS(*context_.src_md, cpu_engine_));
context_.diff_dst_mpd.reset(
new memory::primitive_desc(*context_.diff_dst_md, cpu_engine_));
new MEMORY_PD_CONSTRUCTOR_2_PARAMS(*context_.diff_dst_md, cpu_engine_));
// create forward eltwise primitive
// Create forward eltwise primitive.
context_.fwd_desc.reset(new mkldnn::eltwise_forward::desc(
prop_kind::forward_training, bwdParams.alg_kind, *context_.src_md,
bwdParams.alpha, bwdParams.beta));
context_.fwd_pd.reset(new mkldnn::eltwise_forward::primitive_desc(
*context_.fwd_desc, cpu_engine_));
context_.fwd_pd.reset(new EltwiseFwdPd(*context_.fwd_desc, cpu_engine_));
context_.bwd_desc.reset(new mkldnn::eltwise_backward::desc(
bwdParams.alg_kind, *context_.diff_dst_md, *context_.src_md,
bwdParams.alpha, bwdParams.beta));
context_.bwd_pd.reset(new mkldnn::eltwise_backward::primitive_desc(
*context_.bwd_desc, cpu_engine_, *context_.fwd_pd));
context_.bwd_pd.reset(
new EltwiseBwdPd(*context_.bwd_desc, cpu_engine_, *context_.fwd_pd));
// create memory primitive based on dummy data
auto bwd_pd = context_.bwd_pd.get();
#ifdef ENABLE_MKLDNN_V1
// Create memory primitive based on dummy data.
context_.src_mem.reset(new MEMORY_CONSTRUCTOR(bwd_pd->PRIMITIVE_DESC_SRC,
cpu_engine_, DummyData));
context_.diff_dst_mem.reset(new MEMORY_CONSTRUCTOR(
bwd_pd->PRIMITIVE_DESC_DIFF_DST, cpu_engine_, DummyData));
context_.diff_src_mem.reset(new MEMORY_CONSTRUCTOR(
bwd_pd->PRIMITIVE_DESC_DIFF_SRC, cpu_engine_, DummyData));
// Create eltwise primitive and add it to net.
context_.eltwise_bwd.reset(new mkldnn::eltwise_backward(*context_.bwd_pd));
context_.bwd_primitives_args.push_back(
{{MKLDNN_ARG_SRC, *context_.src_mem},
{MKLDNN_ARG_DIFF_DST, *context_.diff_dst_mem},
{ MKLDNN_ARG_DIFF_SRC,
*context_.diff_src_mem }});
#else
context_.src_mem.reset(new memory(*context_.src_mpd, DummyData));
context_.diff_dst_mem.reset(new memory(*context_.diff_dst_mpd, DummyData));
context_.diff_src_mem.reset(new memory(
context_.bwd_pd.get()->diff_src_primitive_desc(), DummyData));
// create eltwise primitive and add it to net
context_.eltwise_bwd.reset(new mkldnn::eltwise_backward(
*context_.bwd_pd, *context_.src_mem, *context_.diff_dst_mem,
*context_.diff_src_mem));
#endif // ENABLE_MKLDNN_V1
context_.bwd_primitives.push_back(*context_.eltwise_bwd);
}
@ -388,20 +459,15 @@ class MklEltwiseBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
const MklEltwiseBwdParams<T>& bwdParams) {
MklEltwiseBwdPrimitive<T>* eltwise_backward = nullptr;
auto src_fmt =
static_cast<mkldnn::memory::format>(bwdParams.common_md.data.format);
auto diff_dst_fmt =
static_cast<mkldnn::memory::format>(bwdParams.common_md.data.format);
// try to find a suitable one in pool
eltwise_backward = static_cast<MklEltwiseBwdPrimitive<T>*>(
MklEltwiseBwdPrimitiveFactory<T>::GetInstance().GetEltwiseBwd(
bwdParams, src_fmt, diff_dst_fmt));
bwdParams));
if (eltwise_backward == nullptr) {
eltwise_backward = new MklEltwiseBwdPrimitive<T>(bwdParams);
MklEltwiseBwdPrimitiveFactory<T>::GetInstance().SetEltwiseBwd(
bwdParams, src_fmt, diff_dst_fmt, eltwise_backward);
bwdParams, eltwise_backward);
}
return eltwise_backward;
}
@ -412,9 +478,7 @@ class MklEltwiseBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
}
private:
static string CreateKey(const MklEltwiseBwdParams<T>& bwdParams,
const memory::format& src_fmt,
const memory::format& diff_dst_fmt) {
static string CreateKey(const MklEltwiseBwdParams<T>& bwdParams) {
string prefix = "eltwise_bwd";
FactoryKeyCreator key_creator;
key_creator.AddAsKey(prefix);
@ -422,22 +486,20 @@ class MklEltwiseBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
key_creator.AddAsKey(static_cast<int>(bwdParams.alg_kind));
key_creator.AddAsKey(static_cast<float>(bwdParams.alpha));
key_creator.AddAsKey(static_cast<float>(bwdParams.beta));
key_creator.AddAsKey(static_cast<int>(src_fmt));
key_creator.AddAsKey(static_cast<int>(diff_dst_fmt));
#ifndef ENABLE_MKLDNN_V1
key_creator.AddAsKey(static_cast<int>(bwdParams.common_md.data.format));
#endif // !ENABLE_MKLDNN_V1
return key_creator.GetKey();
}
MklPrimitive* GetEltwiseBwd(const MklEltwiseBwdParams<T>& bwdParams,
const memory::format& src_fmt,
const memory::format& diff_dst_fmt) {
string key = CreateKey(bwdParams, src_fmt, diff_dst_fmt);
MklPrimitive* GetEltwiseBwd(const MklEltwiseBwdParams<T>& bwdParams) {
string key = CreateKey(bwdParams);
return this->GetOp(key);
}
void SetEltwiseBwd(const MklEltwiseBwdParams<T>& bwdParams,
const memory::format& src_fmt,
const memory::format& diff_dst_fmt, MklPrimitive* op) {
string key = CreateKey(bwdParams, src_fmt, diff_dst_fmt);
MklPrimitive* op) {
string key = CreateKey(bwdParams);
this->SetOp(key, op);
}
};
@ -481,7 +543,7 @@ class MklReluOpBase : public OpKernel {
// Set DNN primitive - src
MklDnnData<T> src(&cpu_engine);
memory::dims src_dims;
memory::desc src_md({}, memory::data_undef, memory::format_undef);
memory::desc src_md({}, MEMORY_DATA_TYPE_UNDEF, MEMORY_FORMAT_UNDEF);
if (dnn_shape_src.IsMklTensor()) {
src_md = dnn_shape_src.GetMklLayout();
src_dims = dnn_shape_src.GetSizesAsMklDnnDims();
@ -492,31 +554,29 @@ class MklReluOpBase : public OpKernel {
src_md = MklDnnData<T>::CreateBlockedMemDesc(src_dims, src_strides);
}
// get a eltwise fwd from primitive pool
// Try to get an eltwise forward primitive from caching pool
MklEltwiseFwdParams<T> fwdParams(src_dims, src_md, alg_kind, alpha_,
beta_);
MklEltwiseFwdPrimitive<T>* eltwise_fwd =
MklEltwiseFwdPrimitiveFactory<T>::Get(fwdParams);
// Prepare for execution
auto eltwise_fwd_pd = eltwise_fwd->GetEltwiseFwdPd();
// Check if src needs to be reordered
const T* src_data = src_tensor.flat<T>().data();
// check whether src needs to be reordered
if (src_md.data.format != eltwise_fwd->GetSrcMemoryFormat()) {
if (IS_SRC_REORDER_NEEDED(src_md, eltwise_fwd_pd, eltwise_fwd)) {
src.SetUsrMem(src_md, &src_tensor);
auto src_target_pd = memory::primitive_desc(
{{src_dims}, MklDnnType<T>(), eltwise_fwd->GetSrcMemoryFormat()},
cpu_engine);
src.CheckReorderToOpMem(src_target_pd);
src.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
eltwise_fwd_pd->PRIMITIVE_DESC_SRC, cpu_engine));
src_data = const_cast<T*>(
reinterpret_cast<T*>(src.GetOpMem().get_data_handle()));
}
// allocate dst tensor, always set it as MKL-DNN layout
std::shared_ptr<mkldnn::eltwise_forward::primitive_desc> eltwise_fwd_pd =
eltwise_fwd->GetEltwiseFwdPd();
// Allocate dst tensor, always set it as MKL-DNN layout
if (dnn_shape_src.IsMklTensor()) {
dnn_shape_dst.SetMklTensor(true);
auto dst_pd = eltwise_fwd_pd->dst_primitive_desc();
auto dst_pd = eltwise_fwd_pd->PRIMITIVE_DESC_DST;
dnn_shape_dst.SetMklLayout(&dst_pd);
dnn_shape_dst.SetElemType(MklDnnType<T>());
dnn_shape_dst.SetTfLayout(dnn_shape_src.GetDimension(),
@ -549,8 +609,8 @@ class MklReluOpBase : public OpKernel {
}
private:
engine cpu_engine = engine(engine::cpu, 0);
std::shared_ptr<eltwise_forward::primitive_desc> relu_fwd_pd;
engine cpu_engine = engine(ENGINE_CPU, 0);
std::shared_ptr<EltwiseFwdPd> relu_fwd_pd;
protected:
float alpha_;
@ -604,8 +664,8 @@ class MklReluGradOpBase : public OpKernel {
// get an eltwise bwd primitive from the primitive pool
memory::dims src_dims = {};
memory::desc src_md({}, memory::data_undef, memory::format_undef);
memory::desc diff_dst_md({}, memory::data_undef, memory::format_undef);
memory::desc src_md({}, MEMORY_DATA_TYPE_UNDEF, MEMORY_FORMAT_UNDEF);
memory::desc diff_dst_md({}, MEMORY_DATA_TYPE_UNDEF, MEMORY_FORMAT_UNDEF);
if (!dnn_shape_src.IsMklTensor() && !dnn_shape_diff_dst.IsMklTensor()) {
src_dims = TFShapeToMklDnnDims(src_tensor.shape());
auto src_strides = CalculateTFStrides(src_dims);
@ -616,18 +676,18 @@ class MklReluGradOpBase : public OpKernel {
src_md = dnn_shape_src.GetMklLayout();
src_dims = dnn_shape_src.GetSizesAsMklDnnDims();
memory::format src_mkl_data_format = dnn_shape_src.GetTfDataFormat();
MKL_TENSOR_FORMAT src_mkl_data_format = dnn_shape_src.GetTfDataFormat();
auto src_tf_data_format =
MklDnnDataFormatToTFDataFormat(src_mkl_data_format);
auto diff_dst_dims = TFShapeToMklDnnDimsInNCHW(diff_dst_tensor.shape(),
src_tf_data_format);
diff_dst_md =
memory::desc(diff_dst_dims, MklDnnType<T>(), src_mkl_data_format);
diff_dst_md = memory::desc(diff_dst_dims, MklDnnType<T>(),
GET_TENSOR_FORMAT(src_mkl_data_format));
} else if (!dnn_shape_src.IsMklTensor() &&
dnn_shape_diff_dst.IsMklTensor()) {
diff_dst_md = dnn_shape_diff_dst.GetMklLayout();
memory::format diff_dst_mkl_data_format =
MKL_TENSOR_FORMAT diff_dst_mkl_data_format =
dnn_shape_diff_dst.GetTfDataFormat();
auto diff_dst_tf_data_format =
MklDnnDataFormatToTFDataFormat(diff_dst_mkl_data_format);
@ -637,8 +697,8 @@ class MklReluGradOpBase : public OpKernel {
diff_dst_tf_data_format)
: TFShapeToMklDnnDimsInNCDHW(src_tensor.shape(),
diff_dst_tf_data_format);
src_md =
memory::desc(src_dims, MklDnnType<T>(), diff_dst_mkl_data_format);
src_md = memory::desc(src_dims, MklDnnType<T>(),
GET_TENSOR_FORMAT(diff_dst_mkl_data_format));
} else {
src_md = dnn_shape_src.GetMklLayout();
diff_dst_md = dnn_shape_diff_dst.GetMklLayout();
@ -649,7 +709,7 @@ class MklReluGradOpBase : public OpKernel {
// format. So we set common memory descriptor in MKL format, if any of the
// inputs are in MKL format. Let's get memory descriptor that we will use
// for both the inputs.
memory::desc common_md({}, memory::data_undef, memory::format_undef);
memory::desc common_md({}, MEMORY_DATA_TYPE_UNDEF, MEMORY_FORMAT_UNDEF);
if (dnn_shape_src.IsMklTensor() || dnn_shape_diff_dst.IsMklTensor()) {
common_md = dnn_shape_src.IsMklTensor() ? src_md : diff_dst_md;
} else {
@ -662,30 +722,32 @@ class MklReluGradOpBase : public OpKernel {
beta_);
MklEltwiseBwdPrimitive<T>* eltwise_bwd =
MklEltwiseBwdPrimitiveFactory<T>::Get(bwdParams);
auto eltwise_bwd_pd = eltwise_bwd->GetEltwiseBwdPd();
// check whether reorder is needed for src / diff_dst
const T* src_data = src_tensor.flat<T>().data();
if (src_md.data.format != eltwise_bwd->GetSrcMemoryFormat()) {
if (IS_SRC_REORDER_NEEDED(src_md, eltwise_bwd_pd, eltwise_bwd)) {
src.SetUsrMem(src_md, &src_tensor);
src.CheckReorderToOpMem(
eltwise_bwd_pd.get()->diff_src_primitive_desc());
src.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
eltwise_bwd_pd.get()->PRIMITIVE_DESC_DIFF_SRC, cpu_engine));
src_data = const_cast<T*>(
reinterpret_cast<T*>(src.GetOpMem().get_data_handle()));
}
const T* diff_dst_data = diff_dst_tensor.flat<T>().data();
if (diff_dst_md.data.format != eltwise_bwd->GetDiffDstMemoryFormat()) {
if (IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, eltwise_bwd_pd,
eltwise_bwd)) {
diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
diff_dst.CheckReorderToOpMem(
eltwise_bwd_pd.get()->diff_src_primitive_desc());
diff_dst.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
eltwise_bwd_pd.get()->PRIMITIVE_DESC_DIFF_SRC, cpu_engine));
diff_dst_data = const_cast<T*>(
reinterpret_cast<T*>(diff_dst.GetOpMem().get_data_handle()));
}
// allocate diff_src tensor
if (dnn_shape_src.IsMklTensor() || dnn_shape_diff_dst.IsMklTensor()) {
auto diff_src_pd = eltwise_bwd_pd->diff_src_primitive_desc();
auto diff_src_pd = eltwise_bwd_pd->PRIMITIVE_DESC_DIFF_SRC;
dnn_shape_diff_src.SetMklTensor(true);
dnn_shape_diff_src.SetMklLayout(&diff_src_pd);
dnn_shape_diff_src.SetElemType(MklDnnType<T>());
@ -726,8 +788,8 @@ class MklReluGradOpBase : public OpKernel {
}
private:
engine cpu_engine = engine(engine::cpu, 0);
std::shared_ptr<eltwise_forward::primitive_desc> relu_fwd_pd;
engine cpu_engine = engine(ENGINE_CPU, 0);
std::shared_ptr<EltwiseFwdPd> relu_fwd_pd;
protected:
float alpha_;
@ -735,12 +797,13 @@ class MklReluGradOpBase : public OpKernel {
};
template <typename Device, typename T>
class MklReluOp : public MklReluOpBase<Device, T, eltwise_relu> {
class MklReluOp : public MklReluOpBase<Device, T, ALGORITHM::eltwise_relu> {
public:
~MklReluOp() {}
explicit MklReluOp(OpKernelConstruction* context)
: MklReluOpBase<Device, T, eltwise_relu>(context, 0.0f, 0.0f) {}
: MklReluOpBase<Device, T, ALGORITHM::eltwise_relu>(context, 0.0f, 0.0f) {
}
virtual void Compute_Scalar(OpKernelContext* context) {
const size_t src_index = 0; // index of src input tensor
@ -764,12 +827,14 @@ class MklReluOp : public MklReluOpBase<Device, T, eltwise_relu> {
};
template <typename Device, typename T>
class MklReluGradOp : public MklReluGradOpBase<Device, T, eltwise_relu> {
class MklReluGradOp
: public MklReluGradOpBase<Device, T, ALGORITHM::eltwise_relu> {
public:
~MklReluGradOp() {}
explicit MklReluGradOp(OpKernelConstruction* context)
: MklReluGradOpBase<Device, T, eltwise_relu>(context, 0.0f, 0.0f) {}
: MklReluGradOpBase<Device, T, ALGORITHM::eltwise_relu>(context, 0.0f,
0.0f) {}
virtual void Compute_Scalar(OpKernelContext* context) {
const size_t diff_dst_index = 0; // index of diff_dst input tensor
@ -799,12 +864,12 @@ class MklReluGradOp : public MklReluGradOpBase<Device, T, eltwise_relu> {
};
template <typename Device, typename T>
class MklEluOp : public MklReluOpBase<Device, T, eltwise_elu> {
class MklEluOp : public MklReluOpBase<Device, T, ALGORITHM::eltwise_elu> {
public:
~MklEluOp() {}
explicit MklEluOp(OpKernelConstruction* context)
: MklReluOpBase<Device, T, eltwise_elu>(context, 0.0f, 0.0f) {}
: MklReluOpBase<Device, T, ALGORITHM::eltwise_elu>(context, 0.0f, 0.0f) {}
virtual void Compute_Scalar(OpKernelContext* context) {
const size_t src_index = 0; // index of src input tensor
@ -832,12 +897,14 @@ class MklEluOp : public MklReluOpBase<Device, T, eltwise_elu> {
};
template <typename Device, typename T>
class MklEluGradOp : public MklReluGradOpBase<Device, T, eltwise_elu> {
class MklEluGradOp
: public MklReluGradOpBase<Device, T, ALGORITHM::eltwise_elu> {
public:
~MklEluGradOp() {}
explicit MklEluGradOp(OpKernelConstruction* context)
: MklReluGradOpBase<Device, T, eltwise_elu>(context, 0.0f, 0.0f) {}
: MklReluGradOpBase<Device, T, ALGORITHM::eltwise_elu>(context, 0.0f,
0.0f) {}
virtual void Compute_Scalar(OpKernelContext* context) {
const size_t diff_dst_index = 0; // index of diff_dst input tensor
@ -872,12 +939,13 @@ class MklEluGradOp : public MklReluGradOpBase<Device, T, eltwise_elu> {
};
template <typename Device, typename T>
class MklTanhOp : public MklReluOpBase<Device, T, eltwise_tanh> {
class MklTanhOp : public MklReluOpBase<Device, T, ALGORITHM::eltwise_tanh> {
public:
~MklTanhOp() {}
explicit MklTanhOp(OpKernelConstruction* context)
: MklReluOpBase<Device, T, eltwise_tanh>(context, 0.0f, 0.0f) {}
: MklReluOpBase<Device, T, ALGORITHM::eltwise_tanh>(context, 0.0f, 0.0f) {
}
virtual void Compute_Scalar(OpKernelContext* context) {
const size_t src_index = 0; // index of src input tensor
@ -904,12 +972,14 @@ class MklTanhOp : public MklReluOpBase<Device, T, eltwise_tanh> {
};
template <typename Device, typename T>
class MklTanhGradOp : public MklReluGradOpBase<Device, T, eltwise_tanh> {
class MklTanhGradOp
: public MklReluGradOpBase<Device, T, ALGORITHM::eltwise_tanh> {
public:
~MklTanhGradOp() {}
explicit MklTanhGradOp(OpKernelConstruction* context)
: MklReluGradOpBase<Device, T, eltwise_tanh>(context, 0.0f, 0.0f) {}
: MklReluGradOpBase<Device, T, ALGORITHM::eltwise_tanh>(context, 0.0f,
0.0f) {}
virtual void Compute_Scalar(OpKernelContext* context) {
const size_t diff_dst_index = 0; // index of diff_dst input tensor
@ -943,12 +1013,13 @@ class MklTanhGradOp : public MklReluGradOpBase<Device, T, eltwise_tanh> {
#define RELU6_UPPER_BOUND 6.0f
template <typename Device, typename T>
class MklRelu6Op : public MklReluOpBase<Device, T, eltwise_bounded_relu> {
class MklRelu6Op
: public MklReluOpBase<Device, T, ALGORITHM::eltwise_bounded_relu> {
public:
~MklRelu6Op() {}
explicit MklRelu6Op(OpKernelConstruction* context)
: MklReluOpBase<Device, T, eltwise_bounded_relu>(
: MklReluOpBase<Device, T, ALGORITHM::eltwise_bounded_relu>(
context, RELU6_UPPER_BOUND, 0.0f) {}
virtual void Compute_Scalar(OpKernelContext* context) {
@ -973,12 +1044,12 @@ class MklRelu6Op : public MklReluOpBase<Device, T, eltwise_bounded_relu> {
template <typename Device, typename T>
class MklRelu6GradOp
: public MklReluGradOpBase<Device, T, eltwise_bounded_relu> {
: public MklReluGradOpBase<Device, T, ALGORITHM::eltwise_bounded_relu> {
public:
~MklRelu6GradOp() {}
explicit MklRelu6GradOp(OpKernelConstruction* context)
: MklReluGradOpBase<Device, T, eltwise_bounded_relu>(
: MklReluGradOpBase<Device, T, ALGORITHM::eltwise_bounded_relu>(
context, RELU6_UPPER_BOUND, 0.0f) {}
virtual void Compute_Scalar(OpKernelContext* context) {
@ -1007,12 +1078,13 @@ class MklRelu6GradOp
};
template <typename Device, typename T>
class MklLeakyReluOp : public MklReluOpBase<Device, T, eltwise_relu> {
class MklLeakyReluOp
: public MklReluOpBase<Device, T, ALGORITHM::eltwise_relu> {
public:
~MklLeakyReluOp() {}
explicit MklLeakyReluOp(OpKernelConstruction* context)
: MklReluOpBase<Device, T, eltwise_relu>(context, 0.0f, 0.0f) {
: MklReluOpBase<Device, T, ALGORITHM::eltwise_relu>(context, 0.0f, 0.0f) {
float alpha;
OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha));
OP_REQUIRES(
@ -1044,12 +1116,14 @@ class MklLeakyReluOp : public MklReluOpBase<Device, T, eltwise_relu> {
};
template <typename Device, typename T>
class MklLeakyReluGradOp : public MklReluGradOpBase<Device, T, eltwise_relu> {
class MklLeakyReluGradOp
: public MklReluGradOpBase<Device, T, ALGORITHM::eltwise_relu> {
public:
~MklLeakyReluGradOp() {}
explicit MklLeakyReluGradOp(OpKernelConstruction* context)
: MklReluGradOpBase<Device, T, eltwise_relu>(context, 0.0f, 0.0f) {
: MklReluGradOpBase<Device, T, ALGORITHM::eltwise_relu>(context, 0.0f,
0.0f) {
float alpha;
OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha));
OP_REQUIRES(

View File

@ -77,6 +77,7 @@ namespace tensorflow {
memory::desc({dims}, MklDnnType<type>(), fm)
#define MEMORY_PD_WITHOUT_DATA(md, engine) md, engine
#define MEMORY_PRIMITIVE_DESC memory::desc
#define MEMORY_PD_CONSTRUCTOR_2_PARAMS(md, engine) MEMORY_PRIMITIVE_DESC(md)
#define MKL_FMT_TAG mkl_fmt_tag
#define MKL_TENSOR_FORMAT MklTensorFormat
#define MKL_TENSOR_FORMAT_BLOCKED MklTensorFormat::FORMAT_BLOCKED
@ -170,6 +171,8 @@ namespace tensorflow {
memory::primitive_desc(GET_MEMORY_DESC_CONSTRUCTOR(dims, type, fm), engine)
#define MEMORY_PD_WITHOUT_DATA(pd, engine) pd
#define MEMORY_PRIMITIVE_DESC memory::primitive_desc
#define MEMORY_PD_CONSTRUCTOR_2_PARAMS(md, engine) \
MEMORY_PRIMITIVE_DESC(md, engine)
#define MKL_FMT_TAG tf_fmt
#define MKL_TENSOR_FORMAT memory::format
#define MKL_TENSOR_FORMAT_BLOCKED memory::format::blocked