Merge pull request #36469 from Intel-tensorflow:mabuzain/relu_1.x
PiperOrigin-RevId: 294344071 Change-Id: I739ce33c0f7617e53bbc95d68f883e422e497a41
This commit is contained in:
commit
61d19090c8
@ -654,6 +654,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
|||||||
mkl_op_registry::GetMklOpName(csinfo_.quantize_v2),
|
mkl_op_registry::GetMklOpName(csinfo_.quantize_v2),
|
||||||
CopyAttrsAll, QuantizeOpRewrite,
|
CopyAttrsAll, QuantizeOpRewrite,
|
||||||
kRewriteForLayoutPropagation});
|
kRewriteForLayoutPropagation});
|
||||||
|
#endif // !ENABLE_MKLDNN_V1
|
||||||
rinfo_.push_back({csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu),
|
rinfo_.push_back({csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu),
|
||||||
CopyAttrsAll, AlwaysRewrite,
|
CopyAttrsAll, AlwaysRewrite,
|
||||||
kRewriteForLayoutPropagation});
|
kRewriteForLayoutPropagation});
|
||||||
@ -666,10 +667,10 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
|
|||||||
rinfo_.push_back(
|
rinfo_.push_back(
|
||||||
{csinfo_.relu6_grad, mkl_op_registry::GetMklOpName(csinfo_.relu6_grad),
|
{csinfo_.relu6_grad, mkl_op_registry::GetMklOpName(csinfo_.relu6_grad),
|
||||||
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
|
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
|
||||||
|
#ifndef ENABLE_MKLDNN_V1
|
||||||
rinfo_.push_back(
|
rinfo_.push_back(
|
||||||
{csinfo_.requantize, mkl_op_registry::GetMklOpName(csinfo_.requantize),
|
{csinfo_.requantize, mkl_op_registry::GetMklOpName(csinfo_.requantize),
|
||||||
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
|
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
|
||||||
#endif // !ENABLE_MKLDNN_V1
|
|
||||||
// Disable these two MKL operators for now due to some test failures caused
|
// Disable these two MKL operators for now due to some test failures caused
|
||||||
// by these two ops
|
// by these two ops
|
||||||
/*
|
/*
|
||||||
@ -682,7 +683,6 @@ rinfo_.push_back({csinfo_.tanh_grad,
|
|||||||
CopyAttrsAll, AlwaysRewrite,
|
CopyAttrsAll, AlwaysRewrite,
|
||||||
kRewriteForLayoutPropagation});
|
kRewriteForLayoutPropagation});
|
||||||
*/
|
*/
|
||||||
#ifndef ENABLE_MKLDNN_V1
|
|
||||||
rinfo_.push_back(
|
rinfo_.push_back(
|
||||||
{csinfo_.reshape, mkl_op_registry::GetMklOpName(csinfo_.reshape),
|
{csinfo_.reshape, mkl_op_registry::GetMklOpName(csinfo_.reshape),
|
||||||
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
|
CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
|
||||||
|
|||||||
@ -16,6 +16,8 @@ limitations under the License.
|
|||||||
// See docs in ../ops/nn_ops.cc.
|
// See docs in ../ops/nn_ops.cc.
|
||||||
#ifdef INTEL_MKL
|
#ifdef INTEL_MKL
|
||||||
|
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
#include "mkldnn.hpp"
|
#include "mkldnn.hpp"
|
||||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||||
#include "tensorflow/core/framework/numeric_op.h"
|
#include "tensorflow/core/framework/numeric_op.h"
|
||||||
@ -23,24 +25,24 @@ limitations under the License.
|
|||||||
#include "tensorflow/core/framework/register_types.h"
|
#include "tensorflow/core/framework/register_types.h"
|
||||||
#include "tensorflow/core/framework/tensor.h"
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
#include "tensorflow/core/lib/core/errors.h"
|
#include "tensorflow/core/lib/core/errors.h"
|
||||||
|
#include "tensorflow/core/util/mkl_types.h"
|
||||||
#include "tensorflow/core/util/mkl_util.h"
|
#include "tensorflow/core/util/mkl_util.h"
|
||||||
|
|
||||||
using mkldnn::algorithm;
|
using mkldnn::algorithm;
|
||||||
using mkldnn::eltwise_bounded_relu;
|
|
||||||
using mkldnn::eltwise_elu;
|
|
||||||
using mkldnn::eltwise_forward;
|
using mkldnn::eltwise_forward;
|
||||||
using mkldnn::eltwise_relu;
|
|
||||||
using mkldnn::eltwise_tanh;
|
|
||||||
using mkldnn::memory;
|
using mkldnn::memory;
|
||||||
using mkldnn::prop_kind;
|
using mkldnn::prop_kind;
|
||||||
using mkldnn::stream;
|
using mkldnn::stream;
|
||||||
|
|
||||||
|
using EltwiseFwdPd = mkldnn::eltwise_forward::primitive_desc;
|
||||||
|
using EltwiseBwdPd = mkldnn::eltwise_backward::primitive_desc;
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
class MklEltwiseFwdParams {
|
class MklEltwiseFwdParams {
|
||||||
public:
|
public:
|
||||||
memory::dims src_dims; // check if this is needed
|
memory::dims src_dims;
|
||||||
memory::desc src_md;
|
memory::desc src_md;
|
||||||
algorithm alg_kind;
|
algorithm alg_kind;
|
||||||
float alpha;
|
float alpha;
|
||||||
@ -59,11 +61,12 @@ template <typename T>
|
|||||||
class MklEltwiseFwdPrimitive : public MklPrimitive {
|
class MklEltwiseFwdPrimitive : public MklPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit MklEltwiseFwdPrimitive(const MklEltwiseFwdParams<T>& fwdParams)
|
explicit MklEltwiseFwdPrimitive(const MklEltwiseFwdParams<T>& fwdParams)
|
||||||
: cpu_engine_(engine::cpu, 0) {
|
: cpu_engine_(ENGINE_CPU, 0) {
|
||||||
// store expected format
|
#ifndef ENABLE_MKLDNN_V1
|
||||||
context_.src_fmt =
|
context_.src_fmt =
|
||||||
static_cast<mkldnn::memory::format>(fwdParams.src_md.data.format);
|
static_cast<mkldnn::memory::format>(fwdParams.src_md.data.format);
|
||||||
context_.fwd_stream.reset(new stream(stream::kind::eager));
|
#endif
|
||||||
|
context_.fwd_stream.reset(new CPU_STREAM(cpu_engine_));
|
||||||
|
|
||||||
// create eltwise primitive
|
// create eltwise primitive
|
||||||
if (context_.eltwise_fwd == nullptr) {
|
if (context_.eltwise_fwd == nullptr) {
|
||||||
@ -80,24 +83,38 @@ class MklEltwiseFwdPrimitive : public MklPrimitive {
|
|||||||
context_.src_mem->set_data_handle(
|
context_.src_mem->set_data_handle(
|
||||||
static_cast<void*>(const_cast<T*>(src_data)));
|
static_cast<void*>(const_cast<T*>(src_data)));
|
||||||
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
|
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
|
||||||
context_.fwd_stream->submit(context_.fwd_primitives);
|
|
||||||
|
|
||||||
// after execution, set data handle back
|
#ifdef ENABLE_MKLDNN_V1
|
||||||
|
DCHECK_EQ(context_.fwd_primitives.size(),
|
||||||
|
context_.fwd_primitives_args.size());
|
||||||
|
for (size_t i = 0; i < context_.fwd_primitives.size(); ++i) {
|
||||||
|
context_.fwd_primitives.at(i).execute(*context_.fwd_stream,
|
||||||
|
context_.fwd_primitives_args.at(i));
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
context_.fwd_stream->submit(context_.fwd_primitives);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// After execution, set data handle back.
|
||||||
context_.src_mem->set_data_handle(DummyData);
|
context_.src_mem->set_data_handle(DummyData);
|
||||||
context_.dst_mem->set_data_handle(DummyData);
|
context_.dst_mem->set_data_handle(DummyData);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<mkldnn::eltwise_forward::primitive_desc> GetEltwiseFwdPd() {
|
std::shared_ptr<EltwiseFwdPd> GetEltwiseFwdPd() { return context_.fwd_pd; }
|
||||||
return context_.fwd_pd;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
#ifndef ENABLE_MKLDNN_V1
|
||||||
|
// In MKL-DNN v1.x, memory format tags only provide a partial description
|
||||||
|
// of the memory layout. Hence, these functions are disabled for v1.x.
|
||||||
memory::format GetSrcMemoryFormat() { return context_.src_fmt; }
|
memory::format GetSrcMemoryFormat() { return context_.src_fmt; }
|
||||||
|
#endif
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Primitive reuse context for eltwise Fwd ops: Relu, Elu, Tanh
|
// Primitive reuse context for eltwise Fwd ops: Relu, Elu, Tanh
|
||||||
struct EltwiseFwdContext {
|
struct EltwiseFwdContext {
|
||||||
// expected memory format for this primitive instance
|
#ifndef ENABLE_MKLDNN_V1
|
||||||
|
// Expected memory format for this primitive instance
|
||||||
mkldnn::memory::format src_fmt;
|
mkldnn::memory::format src_fmt;
|
||||||
|
#endif
|
||||||
|
|
||||||
// MKLDNN memory
|
// MKLDNN memory
|
||||||
std::shared_ptr<memory> src_mem;
|
std::shared_ptr<memory> src_mem;
|
||||||
@ -105,14 +122,14 @@ class MklEltwiseFwdPrimitive : public MklPrimitive {
|
|||||||
|
|
||||||
// desc & primitive desc
|
// desc & primitive desc
|
||||||
std::shared_ptr<mkldnn::eltwise_forward::desc> fwd_desc;
|
std::shared_ptr<mkldnn::eltwise_forward::desc> fwd_desc;
|
||||||
std::shared_ptr<mkldnn::eltwise_forward::primitive_desc> fwd_pd;
|
std::shared_ptr<EltwiseFwdPd> fwd_pd;
|
||||||
|
|
||||||
// memory desc
|
// memory desc
|
||||||
std::shared_ptr<memory::desc> src_md;
|
std::shared_ptr<memory::desc> src_md;
|
||||||
std::shared_ptr<memory::desc> dst_md;
|
std::shared_ptr<memory::desc> dst_md;
|
||||||
|
|
||||||
// memory primitive desc
|
// memory primitive desc
|
||||||
std::shared_ptr<memory::primitive_desc> src_mpd;
|
std::shared_ptr<MEMORY_PRIMITIVE_DESC> src_mpd;
|
||||||
|
|
||||||
// Eltwise primitive
|
// Eltwise primitive
|
||||||
std::shared_ptr<mkldnn::primitive> eltwise_fwd;
|
std::shared_ptr<mkldnn::primitive> eltwise_fwd;
|
||||||
@ -120,8 +137,15 @@ class MklEltwiseFwdPrimitive : public MklPrimitive {
|
|||||||
std::shared_ptr<stream> fwd_stream;
|
std::shared_ptr<stream> fwd_stream;
|
||||||
std::vector<mkldnn::primitive> fwd_primitives;
|
std::vector<mkldnn::primitive> fwd_primitives;
|
||||||
|
|
||||||
|
#ifdef ENABLE_MKLDNN_V1
|
||||||
|
std::vector<std::unordered_map<int, memory>> fwd_primitives_args;
|
||||||
|
#endif
|
||||||
|
|
||||||
EltwiseFwdContext()
|
EltwiseFwdContext()
|
||||||
: src_fmt(memory::format::any),
|
:
|
||||||
|
#ifndef ENABLE_MKLDNN_V1
|
||||||
|
src_fmt(memory::format::any),
|
||||||
|
#endif
|
||||||
src_mem(nullptr),
|
src_mem(nullptr),
|
||||||
dst_mem(nullptr),
|
dst_mem(nullptr),
|
||||||
fwd_desc(nullptr),
|
fwd_desc(nullptr),
|
||||||
@ -130,31 +154,43 @@ class MklEltwiseFwdPrimitive : public MklPrimitive {
|
|||||||
dst_md(nullptr),
|
dst_md(nullptr),
|
||||||
src_mpd(nullptr),
|
src_mpd(nullptr),
|
||||||
eltwise_fwd(nullptr),
|
eltwise_fwd(nullptr),
|
||||||
fwd_stream(nullptr) {}
|
fwd_stream(nullptr) {
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Eltwise forward primitive setup
|
// Eltwise forward primitive setup
|
||||||
void Setup(const MklEltwiseFwdParams<T>& fwdParams) {
|
void Setup(const MklEltwiseFwdParams<T>& fwdParams) {
|
||||||
// create memory descriptors for eltwise data with specified format
|
// create memory descriptors for eltwise data with specified format
|
||||||
context_.src_md.reset(new memory::desc(fwdParams.src_md.data));
|
context_.src_md.reset(new memory::desc(fwdParams.src_md.data));
|
||||||
context_.src_mpd.reset(
|
|
||||||
new memory::primitive_desc(*context_.src_md, cpu_engine_));
|
|
||||||
|
|
||||||
// create a eltwise
|
context_.src_mpd.reset(
|
||||||
context_.fwd_desc.reset(new mkldnn::eltwise_forward::desc(
|
new MEMORY_PD_CONSTRUCTOR_2_PARAMS(*context_.src_md, cpu_engine_));
|
||||||
|
|
||||||
|
// Create an eltwise forward descriptor and primitive descriptor
|
||||||
|
context_.fwd_desc.reset(new eltwise_forward::desc(
|
||||||
prop_kind::forward, fwdParams.alg_kind, *context_.src_md,
|
prop_kind::forward, fwdParams.alg_kind, *context_.src_md,
|
||||||
fwdParams.alpha, fwdParams.beta));
|
fwdParams.alpha, fwdParams.beta));
|
||||||
context_.fwd_pd.reset(new mkldnn::eltwise_forward::primitive_desc(
|
context_.fwd_pd.reset(new EltwiseFwdPd(*context_.fwd_desc, cpu_engine_));
|
||||||
*context_.fwd_desc, cpu_engine_));
|
auto fwd_pd = context_.fwd_pd.get();
|
||||||
|
|
||||||
// create memory primitive based on dummy data
|
#ifdef ENABLE_MKLDNN_V1
|
||||||
|
// Create memory primitive based on dummy data
|
||||||
|
context_.src_mem.reset(new MEMORY_CONSTRUCTOR(fwd_pd->PRIMITIVE_DESC_SRC,
|
||||||
|
cpu_engine_, DummyData));
|
||||||
|
context_.dst_mem.reset(new MEMORY_CONSTRUCTOR(fwd_pd->PRIMITIVE_DESC_DST,
|
||||||
|
cpu_engine_, DummyData));
|
||||||
|
// Create eltwise primitive and add it to net
|
||||||
|
context_.eltwise_fwd.reset(new eltwise_forward(*context_.fwd_pd));
|
||||||
|
context_.fwd_primitives_args.push_back({{MKLDNN_ARG_SRC, *context_.src_mem},
|
||||||
|
{ MKLDNN_ARG_DST,
|
||||||
|
*context_.dst_mem }});
|
||||||
|
#else
|
||||||
context_.src_mem.reset(new memory(*context_.src_mpd, DummyData));
|
context_.src_mem.reset(new memory(*context_.src_mpd, DummyData));
|
||||||
context_.dst_mem.reset(
|
context_.dst_mem.reset(
|
||||||
new memory(context_.fwd_pd.get()->dst_primitive_desc(), DummyData));
|
new memory(context_.fwd_pd.get()->dst_primitive_desc(), DummyData));
|
||||||
|
context_.eltwise_fwd.reset(new eltwise_forward(
|
||||||
// create eltwise primitive and add it to net
|
|
||||||
context_.eltwise_fwd.reset(new mkldnn::eltwise_forward(
|
|
||||||
*context_.fwd_pd, *context_.src_mem, *context_.dst_mem));
|
*context_.fwd_pd, *context_.src_mem, *context_.dst_mem));
|
||||||
|
#endif
|
||||||
|
|
||||||
context_.fwd_primitives.push_back(*context_.eltwise_fwd);
|
context_.fwd_primitives.push_back(*context_.eltwise_fwd);
|
||||||
}
|
}
|
||||||
@ -170,18 +206,16 @@ class MklEltwiseFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
|
|||||||
const MklEltwiseFwdParams<T>& fwdParams) {
|
const MklEltwiseFwdParams<T>& fwdParams) {
|
||||||
MklEltwiseFwdPrimitive<T>* eltwise_forward = nullptr;
|
MklEltwiseFwdPrimitive<T>* eltwise_forward = nullptr;
|
||||||
|
|
||||||
auto src_fmt =
|
|
||||||
static_cast<mkldnn::memory::format>(fwdParams.src_md.data.format);
|
|
||||||
|
|
||||||
// Get an eltwise fwd primitive from the cached pool
|
// Get an eltwise fwd primitive from the cached pool
|
||||||
eltwise_forward = static_cast<MklEltwiseFwdPrimitive<T>*>(
|
eltwise_forward = static_cast<MklEltwiseFwdPrimitive<T>*>(
|
||||||
MklEltwiseFwdPrimitiveFactory<T>::GetInstance().GetEltwiseFwd(fwdParams,
|
MklEltwiseFwdPrimitiveFactory<T>::GetInstance().GetEltwiseFwd(
|
||||||
src_fmt));
|
fwdParams));
|
||||||
if (eltwise_forward == nullptr) {
|
if (eltwise_forward == nullptr) {
|
||||||
eltwise_forward = new MklEltwiseFwdPrimitive<T>(fwdParams);
|
eltwise_forward = new MklEltwiseFwdPrimitive<T>(fwdParams);
|
||||||
MklEltwiseFwdPrimitiveFactory<T>::GetInstance().SetEltwiseFwd(
|
MklEltwiseFwdPrimitiveFactory<T>::GetInstance().SetEltwiseFwd(
|
||||||
fwdParams, src_fmt, eltwise_forward);
|
fwdParams, eltwise_forward);
|
||||||
}
|
}
|
||||||
|
|
||||||
return eltwise_forward;
|
return eltwise_forward;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -194,8 +228,7 @@ class MklEltwiseFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
|
|||||||
MklEltwiseFwdPrimitiveFactory() {}
|
MklEltwiseFwdPrimitiveFactory() {}
|
||||||
~MklEltwiseFwdPrimitiveFactory() {}
|
~MklEltwiseFwdPrimitiveFactory() {}
|
||||||
|
|
||||||
static string CreateKey(const MklEltwiseFwdParams<T>& fwdParams,
|
static string CreateKey(const MklEltwiseFwdParams<T>& fwdParams) {
|
||||||
memory::format src_fmt) {
|
|
||||||
string prefix = "eltwise_fwd";
|
string prefix = "eltwise_fwd";
|
||||||
FactoryKeyCreator key_creator;
|
FactoryKeyCreator key_creator;
|
||||||
key_creator.AddAsKey(prefix);
|
key_creator.AddAsKey(prefix);
|
||||||
@ -203,19 +236,20 @@ class MklEltwiseFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
|
|||||||
key_creator.AddAsKey<int>(static_cast<int>(fwdParams.alg_kind));
|
key_creator.AddAsKey<int>(static_cast<int>(fwdParams.alg_kind));
|
||||||
key_creator.AddAsKey<float>(static_cast<float>(fwdParams.alpha));
|
key_creator.AddAsKey<float>(static_cast<float>(fwdParams.alpha));
|
||||||
key_creator.AddAsKey<float>(static_cast<float>(fwdParams.beta));
|
key_creator.AddAsKey<float>(static_cast<float>(fwdParams.beta));
|
||||||
key_creator.AddAsKey<int>(static_cast<int>(src_fmt));
|
#ifndef ENABLE_MKLDNN_V1
|
||||||
|
key_creator.AddAsKey<int>(static_cast<int>(fwdParams.src_md.data.format));
|
||||||
|
#endif // !ENABLE_MKLDNN_V1
|
||||||
return key_creator.GetKey();
|
return key_creator.GetKey();
|
||||||
}
|
}
|
||||||
|
|
||||||
MklPrimitive* GetEltwiseFwd(const MklEltwiseFwdParams<T>& fwdParams,
|
MklPrimitive* GetEltwiseFwd(const MklEltwiseFwdParams<T>& fwdParams) {
|
||||||
memory::format src_fmt) {
|
string key = CreateKey(fwdParams);
|
||||||
string key = CreateKey(fwdParams, src_fmt);
|
|
||||||
return this->GetOp(key);
|
return this->GetOp(key);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetEltwiseFwd(const MklEltwiseFwdParams<T>& fwdParams,
|
void SetEltwiseFwd(const MklEltwiseFwdParams<T>& fwdParams,
|
||||||
memory::format src_fmt, MklPrimitive* op) {
|
MklPrimitive* op) {
|
||||||
string key = CreateKey(fwdParams, src_fmt);
|
string key = CreateKey(fwdParams);
|
||||||
this->SetOp(key, op);
|
this->SetOp(key, op);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -243,12 +277,14 @@ template <typename T>
|
|||||||
class MklEltwiseBwdPrimitive : public MklPrimitive {
|
class MklEltwiseBwdPrimitive : public MklPrimitive {
|
||||||
public:
|
public:
|
||||||
explicit MklEltwiseBwdPrimitive(const MklEltwiseBwdParams<T>& bwdParams)
|
explicit MklEltwiseBwdPrimitive(const MklEltwiseBwdParams<T>& bwdParams)
|
||||||
: cpu_engine_(engine::cpu, 0) {
|
: cpu_engine_(ENGINE_CPU, 0) {
|
||||||
|
#ifndef ENABLE_MKLDNN_V1
|
||||||
context_.src_fmt =
|
context_.src_fmt =
|
||||||
static_cast<mkldnn::memory::format>(bwdParams.common_md.data.format);
|
static_cast<mkldnn::memory::format>(bwdParams.common_md.data.format);
|
||||||
context_.diff_dst_fmt =
|
context_.diff_dst_fmt =
|
||||||
static_cast<mkldnn::memory::format>(bwdParams.common_md.data.format);
|
static_cast<mkldnn::memory::format>(bwdParams.common_md.data.format);
|
||||||
context_.bwd_stream.reset(new stream(stream::kind::eager));
|
#endif
|
||||||
|
context_.bwd_stream.reset(new stream(CPU_STREAM(cpu_engine_)));
|
||||||
// create eltwise primitive
|
// create eltwise primitive
|
||||||
if (context_.eltwise_bwd == nullptr) {
|
if (context_.eltwise_bwd == nullptr) {
|
||||||
Setup(bwdParams);
|
Setup(bwdParams);
|
||||||
@ -267,7 +303,17 @@ class MklEltwiseBwdPrimitive : public MklPrimitive {
|
|||||||
context_.diff_dst_mem->set_data_handle(
|
context_.diff_dst_mem->set_data_handle(
|
||||||
static_cast<void*>(const_cast<T*>(diff_dst_data)));
|
static_cast<void*>(const_cast<T*>(diff_dst_data)));
|
||||||
context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data));
|
context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data));
|
||||||
|
|
||||||
|
#ifdef ENABLE_MKLDNN_V1
|
||||||
|
DCHECK_EQ(context_.bwd_primitives.size(),
|
||||||
|
context_.bwd_primitives_args.size());
|
||||||
|
for (size_t i = 0; i < context_.bwd_primitives.size(); ++i) {
|
||||||
|
context_.bwd_primitives.at(i).execute(*context_.bwd_stream,
|
||||||
|
context_.bwd_primitives_args.at(i));
|
||||||
|
}
|
||||||
|
#else
|
||||||
context_.bwd_stream->submit(context_.bwd_primitives);
|
context_.bwd_stream->submit(context_.bwd_primitives);
|
||||||
|
#endif // ENABLE_MKLDNN_V1
|
||||||
|
|
||||||
// after execution, set data handle back
|
// after execution, set data handle back
|
||||||
context_.src_mem->set_data_handle(DummyData);
|
context_.src_mem->set_data_handle(DummyData);
|
||||||
@ -275,52 +321,61 @@ class MklEltwiseBwdPrimitive : public MklPrimitive {
|
|||||||
context_.diff_src_mem->set_data_handle(DummyData);
|
context_.diff_src_mem->set_data_handle(DummyData);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<mkldnn::eltwise_backward::primitive_desc> GetEltwiseBwdPd() {
|
std::shared_ptr<EltwiseBwdPd> GetEltwiseBwdPd() { return context_.bwd_pd; }
|
||||||
return context_.bwd_pd;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
#ifndef ENABLE_MKLDNN_V1
|
||||||
memory::format GetSrcMemoryFormat() { return context_.src_fmt; }
|
memory::format GetSrcMemoryFormat() { return context_.src_fmt; }
|
||||||
|
|
||||||
memory::format GetDiffDstMemoryFormat() { return context_.diff_dst_fmt; }
|
memory::format GetDiffDstMemoryFormat() { return context_.diff_dst_fmt; }
|
||||||
|
#endif // !ENABLE_MKLDNN_V1
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// Primitive reuse context for eltwise Bwd ops: Relu, Elu, Tanh
|
// Primitive reuse context for eltwise Bwd ops: Relu, Elu, Tanh
|
||||||
struct EltwiseBwdContext {
|
struct EltwiseBwdContext {
|
||||||
// expected memory format for this primitive instance
|
#ifndef ENABLE_MKLDNN_V1
|
||||||
memory::format src_fmt;
|
memory::format src_fmt;
|
||||||
memory::format diff_dst_fmt;
|
memory::format diff_dst_fmt;
|
||||||
|
#endif
|
||||||
|
|
||||||
// MKLDNN memory
|
// MKLDNN memory
|
||||||
std::shared_ptr<memory> src_mem;
|
std::shared_ptr<memory> src_mem;
|
||||||
std::shared_ptr<memory> diff_dst_mem;
|
std::shared_ptr<memory> diff_dst_mem;
|
||||||
std::shared_ptr<memory> diff_src_mem;
|
std::shared_ptr<memory> diff_src_mem;
|
||||||
|
|
||||||
// desc & prmitive desc
|
// Backward Eltwise descriptor.
|
||||||
std::shared_ptr<mkldnn::eltwise_backward::desc> bwd_desc;
|
std::shared_ptr<mkldnn::eltwise_backward::desc> bwd_desc;
|
||||||
|
|
||||||
// memory desc
|
// Memory descriptors.
|
||||||
std::shared_ptr<memory::desc> src_md;
|
std::shared_ptr<memory::desc> src_md;
|
||||||
std::shared_ptr<memory::desc> diff_dst_md;
|
std::shared_ptr<memory::desc> diff_dst_md;
|
||||||
std::shared_ptr<memory::desc> common_md;
|
std::shared_ptr<memory::desc> common_md;
|
||||||
|
|
||||||
// memory primitive desc
|
// Memory primitive descriptor.
|
||||||
std::shared_ptr<memory::primitive_desc> src_mpd;
|
// TODO(gzmkl): for MKL-DNN 1.0, src_mpd is same as src_md
|
||||||
std::shared_ptr<memory::primitive_desc> diff_dst_mpd;
|
// So it should be removed once MKL-DNN 0.x is cleaned.
|
||||||
|
std::shared_ptr<MEMORY_PRIMITIVE_DESC> src_mpd;
|
||||||
|
std::shared_ptr<MEMORY_PRIMITIVE_DESC> diff_dst_mpd;
|
||||||
|
|
||||||
// fwd primitive desc
|
// Forward and backward descriptors and primitive descriptors.
|
||||||
std::shared_ptr<mkldnn::eltwise_forward::desc> fwd_desc;
|
std::shared_ptr<mkldnn::eltwise_forward::desc> fwd_desc;
|
||||||
std::shared_ptr<mkldnn::eltwise_forward::primitive_desc> fwd_pd;
|
std::shared_ptr<EltwiseFwdPd> fwd_pd;
|
||||||
std::shared_ptr<mkldnn::eltwise_backward::primitive_desc> bwd_pd;
|
std::shared_ptr<EltwiseBwdPd> bwd_pd;
|
||||||
|
|
||||||
// Eltwise primitive
|
// Eltwise primitive.
|
||||||
std::shared_ptr<mkldnn::primitive> eltwise_bwd;
|
std::shared_ptr<mkldnn::primitive> eltwise_bwd;
|
||||||
|
|
||||||
std::shared_ptr<stream> bwd_stream;
|
std::shared_ptr<stream> bwd_stream;
|
||||||
std::vector<mkldnn::primitive> bwd_primitives;
|
std::vector<mkldnn::primitive> bwd_primitives;
|
||||||
|
|
||||||
|
#ifdef ENABLE_MKLDNN_V1
|
||||||
|
std::vector<MemoryArgsMap> bwd_primitives_args;
|
||||||
|
#endif // ENABLE_MKLDNN_V1
|
||||||
|
|
||||||
EltwiseBwdContext()
|
EltwiseBwdContext()
|
||||||
: src_fmt(memory::format::any),
|
:
|
||||||
|
#ifndef ENABLE_MKLDNN_V1
|
||||||
|
src_fmt(memory::format::any),
|
||||||
diff_dst_fmt(memory::format::any),
|
diff_dst_fmt(memory::format::any),
|
||||||
|
#endif // !ENABLE_MKLDNN_V1
|
||||||
src_mem(nullptr),
|
src_mem(nullptr),
|
||||||
diff_dst_mem(nullptr),
|
diff_dst_mem(nullptr),
|
||||||
diff_src_mem(nullptr),
|
diff_src_mem(nullptr),
|
||||||
@ -333,42 +388,58 @@ class MklEltwiseBwdPrimitive : public MklPrimitive {
|
|||||||
fwd_pd(nullptr),
|
fwd_pd(nullptr),
|
||||||
bwd_pd(nullptr),
|
bwd_pd(nullptr),
|
||||||
eltwise_bwd(nullptr),
|
eltwise_bwd(nullptr),
|
||||||
bwd_stream(nullptr) {}
|
bwd_stream(nullptr) {
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Eltwise backward primitive setup
|
// Eltwise backward primitive setup
|
||||||
void Setup(const MklEltwiseBwdParams<T>& bwdParams) {
|
void Setup(const MklEltwiseBwdParams<T>& bwdParams) {
|
||||||
// create memory descriptors for eltwise data w/ no specified format
|
// Create memory descriptors for eltwise data w/ no specified format
|
||||||
context_.src_md.reset(new memory::desc(bwdParams.common_md.data));
|
context_.src_md.reset(new memory::desc(bwdParams.common_md.data));
|
||||||
context_.diff_dst_md.reset(new memory::desc(bwdParams.common_md.data));
|
context_.diff_dst_md.reset(new memory::desc(bwdParams.common_md.data));
|
||||||
|
|
||||||
context_.src_mpd.reset(
|
context_.src_mpd.reset(
|
||||||
new memory::primitive_desc(*context_.src_md, cpu_engine_));
|
new MEMORY_PD_CONSTRUCTOR_2_PARAMS(*context_.src_md, cpu_engine_));
|
||||||
context_.diff_dst_mpd.reset(
|
context_.diff_dst_mpd.reset(
|
||||||
new memory::primitive_desc(*context_.diff_dst_md, cpu_engine_));
|
new MEMORY_PD_CONSTRUCTOR_2_PARAMS(*context_.diff_dst_md, cpu_engine_));
|
||||||
|
|
||||||
// create forward eltwise primitive
|
// Create forward eltwise primitive.
|
||||||
context_.fwd_desc.reset(new mkldnn::eltwise_forward::desc(
|
context_.fwd_desc.reset(new mkldnn::eltwise_forward::desc(
|
||||||
prop_kind::forward_training, bwdParams.alg_kind, *context_.src_md,
|
prop_kind::forward_training, bwdParams.alg_kind, *context_.src_md,
|
||||||
bwdParams.alpha, bwdParams.beta));
|
bwdParams.alpha, bwdParams.beta));
|
||||||
context_.fwd_pd.reset(new mkldnn::eltwise_forward::primitive_desc(
|
context_.fwd_pd.reset(new EltwiseFwdPd(*context_.fwd_desc, cpu_engine_));
|
||||||
*context_.fwd_desc, cpu_engine_));
|
|
||||||
context_.bwd_desc.reset(new mkldnn::eltwise_backward::desc(
|
context_.bwd_desc.reset(new mkldnn::eltwise_backward::desc(
|
||||||
bwdParams.alg_kind, *context_.diff_dst_md, *context_.src_md,
|
bwdParams.alg_kind, *context_.diff_dst_md, *context_.src_md,
|
||||||
bwdParams.alpha, bwdParams.beta));
|
bwdParams.alpha, bwdParams.beta));
|
||||||
context_.bwd_pd.reset(new mkldnn::eltwise_backward::primitive_desc(
|
context_.bwd_pd.reset(
|
||||||
*context_.bwd_desc, cpu_engine_, *context_.fwd_pd));
|
new EltwiseBwdPd(*context_.bwd_desc, cpu_engine_, *context_.fwd_pd));
|
||||||
|
|
||||||
// create memory primitive based on dummy data
|
auto bwd_pd = context_.bwd_pd.get();
|
||||||
|
|
||||||
|
#ifdef ENABLE_MKLDNN_V1
|
||||||
|
// Create memory primitive based on dummy data.
|
||||||
|
context_.src_mem.reset(new MEMORY_CONSTRUCTOR(bwd_pd->PRIMITIVE_DESC_SRC,
|
||||||
|
cpu_engine_, DummyData));
|
||||||
|
context_.diff_dst_mem.reset(new MEMORY_CONSTRUCTOR(
|
||||||
|
bwd_pd->PRIMITIVE_DESC_DIFF_DST, cpu_engine_, DummyData));
|
||||||
|
context_.diff_src_mem.reset(new MEMORY_CONSTRUCTOR(
|
||||||
|
bwd_pd->PRIMITIVE_DESC_DIFF_SRC, cpu_engine_, DummyData));
|
||||||
|
// Create eltwise primitive and add it to net.
|
||||||
|
context_.eltwise_bwd.reset(new mkldnn::eltwise_backward(*context_.bwd_pd));
|
||||||
|
context_.bwd_primitives_args.push_back(
|
||||||
|
{{MKLDNN_ARG_SRC, *context_.src_mem},
|
||||||
|
{MKLDNN_ARG_DIFF_DST, *context_.diff_dst_mem},
|
||||||
|
{ MKLDNN_ARG_DIFF_SRC,
|
||||||
|
*context_.diff_src_mem }});
|
||||||
|
#else
|
||||||
context_.src_mem.reset(new memory(*context_.src_mpd, DummyData));
|
context_.src_mem.reset(new memory(*context_.src_mpd, DummyData));
|
||||||
context_.diff_dst_mem.reset(new memory(*context_.diff_dst_mpd, DummyData));
|
context_.diff_dst_mem.reset(new memory(*context_.diff_dst_mpd, DummyData));
|
||||||
context_.diff_src_mem.reset(new memory(
|
context_.diff_src_mem.reset(new memory(
|
||||||
context_.bwd_pd.get()->diff_src_primitive_desc(), DummyData));
|
context_.bwd_pd.get()->diff_src_primitive_desc(), DummyData));
|
||||||
|
|
||||||
// create eltwise primitive and add it to net
|
|
||||||
context_.eltwise_bwd.reset(new mkldnn::eltwise_backward(
|
context_.eltwise_bwd.reset(new mkldnn::eltwise_backward(
|
||||||
*context_.bwd_pd, *context_.src_mem, *context_.diff_dst_mem,
|
*context_.bwd_pd, *context_.src_mem, *context_.diff_dst_mem,
|
||||||
*context_.diff_src_mem));
|
*context_.diff_src_mem));
|
||||||
|
#endif // ENABLE_MKLDNN_V1
|
||||||
|
|
||||||
context_.bwd_primitives.push_back(*context_.eltwise_bwd);
|
context_.bwd_primitives.push_back(*context_.eltwise_bwd);
|
||||||
}
|
}
|
||||||
@ -388,20 +459,15 @@ class MklEltwiseBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
|
|||||||
const MklEltwiseBwdParams<T>& bwdParams) {
|
const MklEltwiseBwdParams<T>& bwdParams) {
|
||||||
MklEltwiseBwdPrimitive<T>* eltwise_backward = nullptr;
|
MklEltwiseBwdPrimitive<T>* eltwise_backward = nullptr;
|
||||||
|
|
||||||
auto src_fmt =
|
|
||||||
static_cast<mkldnn::memory::format>(bwdParams.common_md.data.format);
|
|
||||||
auto diff_dst_fmt =
|
|
||||||
static_cast<mkldnn::memory::format>(bwdParams.common_md.data.format);
|
|
||||||
|
|
||||||
// try to find a suitable one in pool
|
// try to find a suitable one in pool
|
||||||
eltwise_backward = static_cast<MklEltwiseBwdPrimitive<T>*>(
|
eltwise_backward = static_cast<MklEltwiseBwdPrimitive<T>*>(
|
||||||
MklEltwiseBwdPrimitiveFactory<T>::GetInstance().GetEltwiseBwd(
|
MklEltwiseBwdPrimitiveFactory<T>::GetInstance().GetEltwiseBwd(
|
||||||
bwdParams, src_fmt, diff_dst_fmt));
|
bwdParams));
|
||||||
|
|
||||||
if (eltwise_backward == nullptr) {
|
if (eltwise_backward == nullptr) {
|
||||||
eltwise_backward = new MklEltwiseBwdPrimitive<T>(bwdParams);
|
eltwise_backward = new MklEltwiseBwdPrimitive<T>(bwdParams);
|
||||||
MklEltwiseBwdPrimitiveFactory<T>::GetInstance().SetEltwiseBwd(
|
MklEltwiseBwdPrimitiveFactory<T>::GetInstance().SetEltwiseBwd(
|
||||||
bwdParams, src_fmt, diff_dst_fmt, eltwise_backward);
|
bwdParams, eltwise_backward);
|
||||||
}
|
}
|
||||||
return eltwise_backward;
|
return eltwise_backward;
|
||||||
}
|
}
|
||||||
@ -412,9 +478,7 @@ class MklEltwiseBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static string CreateKey(const MklEltwiseBwdParams<T>& bwdParams,
|
static string CreateKey(const MklEltwiseBwdParams<T>& bwdParams) {
|
||||||
const memory::format& src_fmt,
|
|
||||||
const memory::format& diff_dst_fmt) {
|
|
||||||
string prefix = "eltwise_bwd";
|
string prefix = "eltwise_bwd";
|
||||||
FactoryKeyCreator key_creator;
|
FactoryKeyCreator key_creator;
|
||||||
key_creator.AddAsKey(prefix);
|
key_creator.AddAsKey(prefix);
|
||||||
@ -422,22 +486,20 @@ class MklEltwiseBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
|
|||||||
key_creator.AddAsKey(static_cast<int>(bwdParams.alg_kind));
|
key_creator.AddAsKey(static_cast<int>(bwdParams.alg_kind));
|
||||||
key_creator.AddAsKey(static_cast<float>(bwdParams.alpha));
|
key_creator.AddAsKey(static_cast<float>(bwdParams.alpha));
|
||||||
key_creator.AddAsKey(static_cast<float>(bwdParams.beta));
|
key_creator.AddAsKey(static_cast<float>(bwdParams.beta));
|
||||||
key_creator.AddAsKey(static_cast<int>(src_fmt));
|
#ifndef ENABLE_MKLDNN_V1
|
||||||
key_creator.AddAsKey(static_cast<int>(diff_dst_fmt));
|
key_creator.AddAsKey(static_cast<int>(bwdParams.common_md.data.format));
|
||||||
|
#endif // !ENABLE_MKLDNN_V1
|
||||||
return key_creator.GetKey();
|
return key_creator.GetKey();
|
||||||
}
|
}
|
||||||
|
|
||||||
MklPrimitive* GetEltwiseBwd(const MklEltwiseBwdParams<T>& bwdParams,
|
MklPrimitive* GetEltwiseBwd(const MklEltwiseBwdParams<T>& bwdParams) {
|
||||||
const memory::format& src_fmt,
|
string key = CreateKey(bwdParams);
|
||||||
const memory::format& diff_dst_fmt) {
|
|
||||||
string key = CreateKey(bwdParams, src_fmt, diff_dst_fmt);
|
|
||||||
return this->GetOp(key);
|
return this->GetOp(key);
|
||||||
}
|
}
|
||||||
|
|
||||||
void SetEltwiseBwd(const MklEltwiseBwdParams<T>& bwdParams,
|
void SetEltwiseBwd(const MklEltwiseBwdParams<T>& bwdParams,
|
||||||
const memory::format& src_fmt,
|
MklPrimitive* op) {
|
||||||
const memory::format& diff_dst_fmt, MklPrimitive* op) {
|
string key = CreateKey(bwdParams);
|
||||||
string key = CreateKey(bwdParams, src_fmt, diff_dst_fmt);
|
|
||||||
this->SetOp(key, op);
|
this->SetOp(key, op);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
@ -481,7 +543,7 @@ class MklReluOpBase : public OpKernel {
|
|||||||
// Set DNN primitive - src
|
// Set DNN primitive - src
|
||||||
MklDnnData<T> src(&cpu_engine);
|
MklDnnData<T> src(&cpu_engine);
|
||||||
memory::dims src_dims;
|
memory::dims src_dims;
|
||||||
memory::desc src_md({}, memory::data_undef, memory::format_undef);
|
memory::desc src_md({}, MEMORY_DATA_TYPE_UNDEF, MEMORY_FORMAT_UNDEF);
|
||||||
if (dnn_shape_src.IsMklTensor()) {
|
if (dnn_shape_src.IsMklTensor()) {
|
||||||
src_md = dnn_shape_src.GetMklLayout();
|
src_md = dnn_shape_src.GetMklLayout();
|
||||||
src_dims = dnn_shape_src.GetSizesAsMklDnnDims();
|
src_dims = dnn_shape_src.GetSizesAsMklDnnDims();
|
||||||
@ -492,31 +554,29 @@ class MklReluOpBase : public OpKernel {
|
|||||||
src_md = MklDnnData<T>::CreateBlockedMemDesc(src_dims, src_strides);
|
src_md = MklDnnData<T>::CreateBlockedMemDesc(src_dims, src_strides);
|
||||||
}
|
}
|
||||||
|
|
||||||
// get a eltwise fwd from primitive pool
|
// Try to get an eltwise forward primitive from caching pool
|
||||||
MklEltwiseFwdParams<T> fwdParams(src_dims, src_md, alg_kind, alpha_,
|
MklEltwiseFwdParams<T> fwdParams(src_dims, src_md, alg_kind, alpha_,
|
||||||
beta_);
|
beta_);
|
||||||
|
|
||||||
MklEltwiseFwdPrimitive<T>* eltwise_fwd =
|
MklEltwiseFwdPrimitive<T>* eltwise_fwd =
|
||||||
MklEltwiseFwdPrimitiveFactory<T>::Get(fwdParams);
|
MklEltwiseFwdPrimitiveFactory<T>::Get(fwdParams);
|
||||||
|
|
||||||
// prepare for execuation
|
auto eltwise_fwd_pd = eltwise_fwd->GetEltwiseFwdPd();
|
||||||
|
|
||||||
|
// Check if src needs to be reordered
|
||||||
const T* src_data = src_tensor.flat<T>().data();
|
const T* src_data = src_tensor.flat<T>().data();
|
||||||
// check wehther src need to reorder
|
if (IS_SRC_REORDER_NEEDED(src_md, eltwise_fwd_pd, eltwise_fwd)) {
|
||||||
if (src_md.data.format != eltwise_fwd->GetSrcMemoryFormat()) {
|
|
||||||
src.SetUsrMem(src_md, &src_tensor);
|
src.SetUsrMem(src_md, &src_tensor);
|
||||||
auto src_target_pd = memory::primitive_desc(
|
src.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
|
||||||
{{src_dims}, MklDnnType<T>(), eltwise_fwd->GetSrcMemoryFormat()},
|
eltwise_fwd_pd->PRIMITIVE_DESC_SRC, cpu_engine));
|
||||||
cpu_engine);
|
|
||||||
src.CheckReorderToOpMem(src_target_pd);
|
|
||||||
src_data = const_cast<T*>(
|
src_data = const_cast<T*>(
|
||||||
reinterpret_cast<T*>(src.GetOpMem().get_data_handle()));
|
reinterpret_cast<T*>(src.GetOpMem().get_data_handle()));
|
||||||
}
|
}
|
||||||
|
|
||||||
// allocate dst tensor, always set it as MKL-DNN layout
|
// Allocate dst tensor, always set it as MKL-DNN layout
|
||||||
std::shared_ptr<mkldnn::eltwise_forward::primitive_desc> eltwise_fwd_pd =
|
|
||||||
eltwise_fwd->GetEltwiseFwdPd();
|
|
||||||
if (dnn_shape_src.IsMklTensor()) {
|
if (dnn_shape_src.IsMklTensor()) {
|
||||||
dnn_shape_dst.SetMklTensor(true);
|
dnn_shape_dst.SetMklTensor(true);
|
||||||
auto dst_pd = eltwise_fwd_pd->dst_primitive_desc();
|
auto dst_pd = eltwise_fwd_pd->PRIMITIVE_DESC_DST;
|
||||||
dnn_shape_dst.SetMklLayout(&dst_pd);
|
dnn_shape_dst.SetMklLayout(&dst_pd);
|
||||||
dnn_shape_dst.SetElemType(MklDnnType<T>());
|
dnn_shape_dst.SetElemType(MklDnnType<T>());
|
||||||
dnn_shape_dst.SetTfLayout(dnn_shape_src.GetDimension(),
|
dnn_shape_dst.SetTfLayout(dnn_shape_src.GetDimension(),
|
||||||
@ -549,8 +609,8 @@ class MklReluOpBase : public OpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
engine cpu_engine = engine(engine::cpu, 0);
|
engine cpu_engine = engine(ENGINE_CPU, 0);
|
||||||
std::shared_ptr<eltwise_forward::primitive_desc> relu_fwd_pd;
|
std::shared_ptr<EltwiseFwdPd> relu_fwd_pd;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
float alpha_;
|
float alpha_;
|
||||||
@ -604,8 +664,8 @@ class MklReluGradOpBase : public OpKernel {
|
|||||||
|
|
||||||
// get a eltwise bwd from primitive pool
|
// get a eltwise bwd from primitive pool
|
||||||
memory::dims src_dims = {};
|
memory::dims src_dims = {};
|
||||||
memory::desc src_md({}, memory::data_undef, memory::format_undef);
|
memory::desc src_md({}, MEMORY_DATA_TYPE_UNDEF, MEMORY_FORMAT_UNDEF);
|
||||||
memory::desc diff_dst_md({}, memory::data_undef, memory::format_undef);
|
memory::desc diff_dst_md({}, MEMORY_DATA_TYPE_UNDEF, MEMORY_FORMAT_UNDEF);
|
||||||
if (!dnn_shape_src.IsMklTensor() && !dnn_shape_diff_dst.IsMklTensor()) {
|
if (!dnn_shape_src.IsMklTensor() && !dnn_shape_diff_dst.IsMklTensor()) {
|
||||||
src_dims = TFShapeToMklDnnDims(src_tensor.shape());
|
src_dims = TFShapeToMklDnnDims(src_tensor.shape());
|
||||||
auto src_strides = CalculateTFStrides(src_dims);
|
auto src_strides = CalculateTFStrides(src_dims);
|
||||||
@ -616,18 +676,18 @@ class MklReluGradOpBase : public OpKernel {
|
|||||||
src_md = dnn_shape_src.GetMklLayout();
|
src_md = dnn_shape_src.GetMklLayout();
|
||||||
src_dims = dnn_shape_src.GetSizesAsMklDnnDims();
|
src_dims = dnn_shape_src.GetSizesAsMklDnnDims();
|
||||||
|
|
||||||
memory::format src_mkl_data_format = dnn_shape_src.GetTfDataFormat();
|
MKL_TENSOR_FORMAT src_mkl_data_format = dnn_shape_src.GetTfDataFormat();
|
||||||
auto src_tf_data_format =
|
auto src_tf_data_format =
|
||||||
MklDnnDataFormatToTFDataFormat(src_mkl_data_format);
|
MklDnnDataFormatToTFDataFormat(src_mkl_data_format);
|
||||||
auto diff_dst_dims = TFShapeToMklDnnDimsInNCHW(diff_dst_tensor.shape(),
|
auto diff_dst_dims = TFShapeToMklDnnDimsInNCHW(diff_dst_tensor.shape(),
|
||||||
src_tf_data_format);
|
src_tf_data_format);
|
||||||
diff_dst_md =
|
diff_dst_md = memory::desc(diff_dst_dims, MklDnnType<T>(),
|
||||||
memory::desc(diff_dst_dims, MklDnnType<T>(), src_mkl_data_format);
|
GET_TENSOR_FORMAT(src_mkl_data_format));
|
||||||
} else if (!dnn_shape_src.IsMklTensor() &&
|
} else if (!dnn_shape_src.IsMklTensor() &&
|
||||||
dnn_shape_diff_dst.IsMklTensor()) {
|
dnn_shape_diff_dst.IsMklTensor()) {
|
||||||
diff_dst_md = dnn_shape_diff_dst.GetMklLayout();
|
diff_dst_md = dnn_shape_diff_dst.GetMklLayout();
|
||||||
|
|
||||||
memory::format diff_dst_mkl_data_format =
|
MKL_TENSOR_FORMAT diff_dst_mkl_data_format =
|
||||||
dnn_shape_diff_dst.GetTfDataFormat();
|
dnn_shape_diff_dst.GetTfDataFormat();
|
||||||
auto diff_dst_tf_data_format =
|
auto diff_dst_tf_data_format =
|
||||||
MklDnnDataFormatToTFDataFormat(diff_dst_mkl_data_format);
|
MklDnnDataFormatToTFDataFormat(diff_dst_mkl_data_format);
|
||||||
@ -637,8 +697,8 @@ class MklReluGradOpBase : public OpKernel {
|
|||||||
diff_dst_tf_data_format)
|
diff_dst_tf_data_format)
|
||||||
: TFShapeToMklDnnDimsInNCDHW(src_tensor.shape(),
|
: TFShapeToMklDnnDimsInNCDHW(src_tensor.shape(),
|
||||||
diff_dst_tf_data_format);
|
diff_dst_tf_data_format);
|
||||||
src_md =
|
src_md = memory::desc(src_dims, MklDnnType<T>(),
|
||||||
memory::desc(src_dims, MklDnnType<T>(), diff_dst_mkl_data_format);
|
GET_TENSOR_FORMAT(diff_dst_mkl_data_format));
|
||||||
} else {
|
} else {
|
||||||
src_md = dnn_shape_src.GetMklLayout();
|
src_md = dnn_shape_src.GetMklLayout();
|
||||||
diff_dst_md = dnn_shape_diff_dst.GetMklLayout();
|
diff_dst_md = dnn_shape_diff_dst.GetMklLayout();
|
||||||
@ -649,7 +709,7 @@ class MklReluGradOpBase : public OpKernel {
|
|||||||
// format. So we set common memory descriptor in MKL format, if any of the
|
// format. So we set common memory descriptor in MKL format, if any of the
|
||||||
// inputs are in MKL format. Let's get memory descriptor that we will use
|
// inputs are in MKL format. Let's get memory descriptor that we will use
|
||||||
// for both the inputs.
|
// for both the inputs.
|
||||||
memory::desc common_md({}, memory::data_undef, memory::format_undef);
|
memory::desc common_md({}, MEMORY_DATA_TYPE_UNDEF, MEMORY_FORMAT_UNDEF);
|
||||||
if (dnn_shape_src.IsMklTensor() || dnn_shape_diff_dst.IsMklTensor()) {
|
if (dnn_shape_src.IsMklTensor() || dnn_shape_diff_dst.IsMklTensor()) {
|
||||||
common_md = dnn_shape_src.IsMklTensor() ? src_md : diff_dst_md;
|
common_md = dnn_shape_src.IsMklTensor() ? src_md : diff_dst_md;
|
||||||
} else {
|
} else {
|
||||||
@ -662,30 +722,32 @@ class MklReluGradOpBase : public OpKernel {
|
|||||||
beta_);
|
beta_);
|
||||||
MklEltwiseBwdPrimitive<T>* eltwise_bwd =
|
MklEltwiseBwdPrimitive<T>* eltwise_bwd =
|
||||||
MklEltwiseBwdPrimitiveFactory<T>::Get(bwdParams);
|
MklEltwiseBwdPrimitiveFactory<T>::Get(bwdParams);
|
||||||
|
|
||||||
auto eltwise_bwd_pd = eltwise_bwd->GetEltwiseBwdPd();
|
auto eltwise_bwd_pd = eltwise_bwd->GetEltwiseBwdPd();
|
||||||
|
|
||||||
// check whether need reorder for src / diff_dst
|
// check whether need reorder for src / diff_dst
|
||||||
const T* src_data = src_tensor.flat<T>().data();
|
const T* src_data = src_tensor.flat<T>().data();
|
||||||
if (src_md.data.format != eltwise_bwd->GetSrcMemoryFormat()) {
|
if (IS_SRC_REORDER_NEEDED(src_md, eltwise_bwd_pd, eltwise_bwd)) {
|
||||||
src.SetUsrMem(src_md, &src_tensor);
|
src.SetUsrMem(src_md, &src_tensor);
|
||||||
src.CheckReorderToOpMem(
|
src.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
|
||||||
eltwise_bwd_pd.get()->diff_src_primitive_desc());
|
eltwise_bwd_pd.get()->PRIMITIVE_DESC_DIFF_SRC, cpu_engine));
|
||||||
src_data = const_cast<T*>(
|
src_data = const_cast<T*>(
|
||||||
reinterpret_cast<T*>(src.GetOpMem().get_data_handle()));
|
reinterpret_cast<T*>(src.GetOpMem().get_data_handle()));
|
||||||
}
|
}
|
||||||
|
|
||||||
const T* diff_dst_data = diff_dst_tensor.flat<T>().data();
|
const T* diff_dst_data = diff_dst_tensor.flat<T>().data();
|
||||||
if (diff_dst_md.data.format != eltwise_bwd->GetDiffDstMemoryFormat()) {
|
if (IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, eltwise_bwd_pd,
|
||||||
|
eltwise_bwd)) {
|
||||||
diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
|
diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
|
||||||
diff_dst.CheckReorderToOpMem(
|
diff_dst.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
|
||||||
eltwise_bwd_pd.get()->diff_src_primitive_desc());
|
eltwise_bwd_pd.get()->PRIMITIVE_DESC_DIFF_SRC, cpu_engine));
|
||||||
diff_dst_data = const_cast<T*>(
|
diff_dst_data = const_cast<T*>(
|
||||||
reinterpret_cast<T*>(diff_dst.GetOpMem().get_data_handle()));
|
reinterpret_cast<T*>(diff_dst.GetOpMem().get_data_handle()));
|
||||||
}
|
}
|
||||||
|
|
||||||
// allocate diff_src tensor
|
// allocate diff_src tensor
|
||||||
if (dnn_shape_src.IsMklTensor() || dnn_shape_diff_dst.IsMklTensor()) {
|
if (dnn_shape_src.IsMklTensor() || dnn_shape_diff_dst.IsMklTensor()) {
|
||||||
auto diff_src_pd = eltwise_bwd_pd->diff_src_primitive_desc();
|
auto diff_src_pd = eltwise_bwd_pd->PRIMITIVE_DESC_DIFF_SRC;
|
||||||
dnn_shape_diff_src.SetMklTensor(true);
|
dnn_shape_diff_src.SetMklTensor(true);
|
||||||
dnn_shape_diff_src.SetMklLayout(&diff_src_pd);
|
dnn_shape_diff_src.SetMklLayout(&diff_src_pd);
|
||||||
dnn_shape_diff_src.SetElemType(MklDnnType<T>());
|
dnn_shape_diff_src.SetElemType(MklDnnType<T>());
|
||||||
@ -726,8 +788,8 @@ class MklReluGradOpBase : public OpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
engine cpu_engine = engine(engine::cpu, 0);
|
engine cpu_engine = engine(ENGINE_CPU, 0);
|
||||||
std::shared_ptr<eltwise_forward::primitive_desc> relu_fwd_pd;
|
std::shared_ptr<EltwiseFwdPd> relu_fwd_pd;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
float alpha_;
|
float alpha_;
|
||||||
@ -735,12 +797,13 @@ class MklReluGradOpBase : public OpKernel {
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <typename Device, typename T>
|
template <typename Device, typename T>
|
||||||
class MklReluOp : public MklReluOpBase<Device, T, eltwise_relu> {
|
class MklReluOp : public MklReluOpBase<Device, T, ALGORITHM::eltwise_relu> {
|
||||||
public:
|
public:
|
||||||
~MklReluOp() {}
|
~MklReluOp() {}
|
||||||
|
|
||||||
explicit MklReluOp(OpKernelConstruction* context)
|
explicit MklReluOp(OpKernelConstruction* context)
|
||||||
: MklReluOpBase<Device, T, eltwise_relu>(context, 0.0f, 0.0f) {}
|
: MklReluOpBase<Device, T, ALGORITHM::eltwise_relu>(context, 0.0f, 0.0f) {
|
||||||
|
}
|
||||||
|
|
||||||
virtual void Compute_Scalar(OpKernelContext* context) {
|
virtual void Compute_Scalar(OpKernelContext* context) {
|
||||||
const size_t src_index = 0; // index of src input tensor
|
const size_t src_index = 0; // index of src input tensor
|
||||||
@ -764,12 +827,14 @@ class MklReluOp : public MklReluOpBase<Device, T, eltwise_relu> {
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <typename Device, typename T>
|
template <typename Device, typename T>
|
||||||
class MklReluGradOp : public MklReluGradOpBase<Device, T, eltwise_relu> {
|
class MklReluGradOp
|
||||||
|
: public MklReluGradOpBase<Device, T, ALGORITHM::eltwise_relu> {
|
||||||
public:
|
public:
|
||||||
~MklReluGradOp() {}
|
~MklReluGradOp() {}
|
||||||
|
|
||||||
explicit MklReluGradOp(OpKernelConstruction* context)
|
explicit MklReluGradOp(OpKernelConstruction* context)
|
||||||
: MklReluGradOpBase<Device, T, eltwise_relu>(context, 0.0f, 0.0f) {}
|
: MklReluGradOpBase<Device, T, ALGORITHM::eltwise_relu>(context, 0.0f,
|
||||||
|
0.0f) {}
|
||||||
|
|
||||||
virtual void Compute_Scalar(OpKernelContext* context) {
|
virtual void Compute_Scalar(OpKernelContext* context) {
|
||||||
const size_t diff_dst_index = 0; // index of diff_dst input tensor
|
const size_t diff_dst_index = 0; // index of diff_dst input tensor
|
||||||
@ -799,12 +864,12 @@ class MklReluGradOp : public MklReluGradOpBase<Device, T, eltwise_relu> {
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <typename Device, typename T>
|
template <typename Device, typename T>
|
||||||
class MklEluOp : public MklReluOpBase<Device, T, eltwise_elu> {
|
class MklEluOp : public MklReluOpBase<Device, T, ALGORITHM::eltwise_elu> {
|
||||||
public:
|
public:
|
||||||
~MklEluOp() {}
|
~MklEluOp() {}
|
||||||
|
|
||||||
explicit MklEluOp(OpKernelConstruction* context)
|
explicit MklEluOp(OpKernelConstruction* context)
|
||||||
: MklReluOpBase<Device, T, eltwise_elu>(context, 0.0f, 0.0f) {}
|
: MklReluOpBase<Device, T, ALGORITHM::eltwise_elu>(context, 0.0f, 0.0f) {}
|
||||||
|
|
||||||
virtual void Compute_Scalar(OpKernelContext* context) {
|
virtual void Compute_Scalar(OpKernelContext* context) {
|
||||||
const size_t src_index = 0; // index of src input tensor
|
const size_t src_index = 0; // index of src input tensor
|
||||||
@ -832,12 +897,14 @@ class MklEluOp : public MklReluOpBase<Device, T, eltwise_elu> {
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <typename Device, typename T>
|
template <typename Device, typename T>
|
||||||
class MklEluGradOp : public MklReluGradOpBase<Device, T, eltwise_elu> {
|
class MklEluGradOp
|
||||||
|
: public MklReluGradOpBase<Device, T, ALGORITHM::eltwise_elu> {
|
||||||
public:
|
public:
|
||||||
~MklEluGradOp() {}
|
~MklEluGradOp() {}
|
||||||
|
|
||||||
explicit MklEluGradOp(OpKernelConstruction* context)
|
explicit MklEluGradOp(OpKernelConstruction* context)
|
||||||
: MklReluGradOpBase<Device, T, eltwise_elu>(context, 0.0f, 0.0f) {}
|
: MklReluGradOpBase<Device, T, ALGORITHM::eltwise_elu>(context, 0.0f,
|
||||||
|
0.0f) {}
|
||||||
|
|
||||||
virtual void Compute_Scalar(OpKernelContext* context) {
|
virtual void Compute_Scalar(OpKernelContext* context) {
|
||||||
const size_t diff_dst_index = 0; // index of diff_dst input tensor
|
const size_t diff_dst_index = 0; // index of diff_dst input tensor
|
||||||
@ -872,12 +939,13 @@ class MklEluGradOp : public MklReluGradOpBase<Device, T, eltwise_elu> {
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <typename Device, typename T>
|
template <typename Device, typename T>
|
||||||
class MklTanhOp : public MklReluOpBase<Device, T, eltwise_tanh> {
|
class MklTanhOp : public MklReluOpBase<Device, T, ALGORITHM::eltwise_tanh> {
|
||||||
public:
|
public:
|
||||||
~MklTanhOp() {}
|
~MklTanhOp() {}
|
||||||
|
|
||||||
explicit MklTanhOp(OpKernelConstruction* context)
|
explicit MklTanhOp(OpKernelConstruction* context)
|
||||||
: MklReluOpBase<Device, T, eltwise_tanh>(context, 0.0f, 0.0f) {}
|
: MklReluOpBase<Device, T, ALGORITHM::eltwise_tanh>(context, 0.0f, 0.0f) {
|
||||||
|
}
|
||||||
|
|
||||||
virtual void Compute_Scalar(OpKernelContext* context) {
|
virtual void Compute_Scalar(OpKernelContext* context) {
|
||||||
const size_t src_index = 0; // index of src input tensor
|
const size_t src_index = 0; // index of src input tensor
|
||||||
@ -904,12 +972,14 @@ class MklTanhOp : public MklReluOpBase<Device, T, eltwise_tanh> {
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <typename Device, typename T>
|
template <typename Device, typename T>
|
||||||
class MklTanhGradOp : public MklReluGradOpBase<Device, T, eltwise_tanh> {
|
class MklTanhGradOp
|
||||||
|
: public MklReluGradOpBase<Device, T, ALGORITHM::eltwise_tanh> {
|
||||||
public:
|
public:
|
||||||
~MklTanhGradOp() {}
|
~MklTanhGradOp() {}
|
||||||
|
|
||||||
explicit MklTanhGradOp(OpKernelConstruction* context)
|
explicit MklTanhGradOp(OpKernelConstruction* context)
|
||||||
: MklReluGradOpBase<Device, T, eltwise_tanh>(context, 0.0f, 0.0f) {}
|
: MklReluGradOpBase<Device, T, ALGORITHM::eltwise_tanh>(context, 0.0f,
|
||||||
|
0.0f) {}
|
||||||
|
|
||||||
virtual void Compute_Scalar(OpKernelContext* context) {
|
virtual void Compute_Scalar(OpKernelContext* context) {
|
||||||
const size_t diff_dst_index = 0; // index of diff_dst input tensor
|
const size_t diff_dst_index = 0; // index of diff_dst input tensor
|
||||||
@ -943,12 +1013,13 @@ class MklTanhGradOp : public MklReluGradOpBase<Device, T, eltwise_tanh> {
|
|||||||
|
|
||||||
#define RELU6_UPPER_BOUND 6.0f
|
#define RELU6_UPPER_BOUND 6.0f
|
||||||
template <typename Device, typename T>
|
template <typename Device, typename T>
|
||||||
class MklRelu6Op : public MklReluOpBase<Device, T, eltwise_bounded_relu> {
|
class MklRelu6Op
|
||||||
|
: public MklReluOpBase<Device, T, ALGORITHM::eltwise_bounded_relu> {
|
||||||
public:
|
public:
|
||||||
~MklRelu6Op() {}
|
~MklRelu6Op() {}
|
||||||
|
|
||||||
explicit MklRelu6Op(OpKernelConstruction* context)
|
explicit MklRelu6Op(OpKernelConstruction* context)
|
||||||
: MklReluOpBase<Device, T, eltwise_bounded_relu>(
|
: MklReluOpBase<Device, T, ALGORITHM::eltwise_bounded_relu>(
|
||||||
context, RELU6_UPPER_BOUND, 0.0f) {}
|
context, RELU6_UPPER_BOUND, 0.0f) {}
|
||||||
|
|
||||||
virtual void Compute_Scalar(OpKernelContext* context) {
|
virtual void Compute_Scalar(OpKernelContext* context) {
|
||||||
@ -973,12 +1044,12 @@ class MklRelu6Op : public MklReluOpBase<Device, T, eltwise_bounded_relu> {
|
|||||||
|
|
||||||
template <typename Device, typename T>
|
template <typename Device, typename T>
|
||||||
class MklRelu6GradOp
|
class MklRelu6GradOp
|
||||||
: public MklReluGradOpBase<Device, T, eltwise_bounded_relu> {
|
: public MklReluGradOpBase<Device, T, ALGORITHM::eltwise_bounded_relu> {
|
||||||
public:
|
public:
|
||||||
~MklRelu6GradOp() {}
|
~MklRelu6GradOp() {}
|
||||||
|
|
||||||
explicit MklRelu6GradOp(OpKernelConstruction* context)
|
explicit MklRelu6GradOp(OpKernelConstruction* context)
|
||||||
: MklReluGradOpBase<Device, T, eltwise_bounded_relu>(
|
: MklReluGradOpBase<Device, T, ALGORITHM::eltwise_bounded_relu>(
|
||||||
context, RELU6_UPPER_BOUND, 0.0f) {}
|
context, RELU6_UPPER_BOUND, 0.0f) {}
|
||||||
|
|
||||||
virtual void Compute_Scalar(OpKernelContext* context) {
|
virtual void Compute_Scalar(OpKernelContext* context) {
|
||||||
@ -1007,12 +1078,13 @@ class MklRelu6GradOp
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <typename Device, typename T>
|
template <typename Device, typename T>
|
||||||
class MklLeakyReluOp : public MklReluOpBase<Device, T, eltwise_relu> {
|
class MklLeakyReluOp
|
||||||
|
: public MklReluOpBase<Device, T, ALGORITHM::eltwise_relu> {
|
||||||
public:
|
public:
|
||||||
~MklLeakyReluOp() {}
|
~MklLeakyReluOp() {}
|
||||||
|
|
||||||
explicit MklLeakyReluOp(OpKernelConstruction* context)
|
explicit MklLeakyReluOp(OpKernelConstruction* context)
|
||||||
: MklReluOpBase<Device, T, eltwise_relu>(context, 0.0f, 0.0f) {
|
: MklReluOpBase<Device, T, ALGORITHM::eltwise_relu>(context, 0.0f, 0.0f) {
|
||||||
float alpha;
|
float alpha;
|
||||||
OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha));
|
OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha));
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
@ -1044,12 +1116,14 @@ class MklLeakyReluOp : public MklReluOpBase<Device, T, eltwise_relu> {
|
|||||||
};
|
};
|
||||||
|
|
||||||
template <typename Device, typename T>
|
template <typename Device, typename T>
|
||||||
class MklLeakyReluGradOp : public MklReluGradOpBase<Device, T, eltwise_relu> {
|
class MklLeakyReluGradOp
|
||||||
|
: public MklReluGradOpBase<Device, T, ALGORITHM::eltwise_relu> {
|
||||||
public:
|
public:
|
||||||
~MklLeakyReluGradOp() {}
|
~MklLeakyReluGradOp() {}
|
||||||
|
|
||||||
explicit MklLeakyReluGradOp(OpKernelConstruction* context)
|
explicit MklLeakyReluGradOp(OpKernelConstruction* context)
|
||||||
: MklReluGradOpBase<Device, T, eltwise_relu>(context, 0.0f, 0.0f) {
|
: MklReluGradOpBase<Device, T, ALGORITHM::eltwise_relu>(context, 0.0f,
|
||||||
|
0.0f) {
|
||||||
float alpha;
|
float alpha;
|
||||||
OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha));
|
OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha));
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
|
|||||||
@ -77,6 +77,7 @@ namespace tensorflow {
|
|||||||
memory::desc({dims}, MklDnnType<type>(), fm)
|
memory::desc({dims}, MklDnnType<type>(), fm)
|
||||||
#define MEMORY_PD_WITHOUT_DATA(md, engine) md, engine
|
#define MEMORY_PD_WITHOUT_DATA(md, engine) md, engine
|
||||||
#define MEMORY_PRIMITIVE_DESC memory::desc
|
#define MEMORY_PRIMITIVE_DESC memory::desc
|
||||||
|
#define MEMORY_PD_CONSTRUCTOR_2_PARAMS(md, engine) MEMORY_PRIMITIVE_DESC(md)
|
||||||
#define MKL_FMT_TAG mkl_fmt_tag
|
#define MKL_FMT_TAG mkl_fmt_tag
|
||||||
#define MKL_TENSOR_FORMAT MklTensorFormat
|
#define MKL_TENSOR_FORMAT MklTensorFormat
|
||||||
#define MKL_TENSOR_FORMAT_BLOCKED MklTensorFormat::FORMAT_BLOCKED
|
#define MKL_TENSOR_FORMAT_BLOCKED MklTensorFormat::FORMAT_BLOCKED
|
||||||
@ -170,6 +171,8 @@ namespace tensorflow {
|
|||||||
memory::primitive_desc(GET_MEMORY_DESC_CONSTRUCTOR(dims, type, fm), engine)
|
memory::primitive_desc(GET_MEMORY_DESC_CONSTRUCTOR(dims, type, fm), engine)
|
||||||
#define MEMORY_PD_WITHOUT_DATA(pd, engine) pd
|
#define MEMORY_PD_WITHOUT_DATA(pd, engine) pd
|
||||||
#define MEMORY_PRIMITIVE_DESC memory::primitive_desc
|
#define MEMORY_PRIMITIVE_DESC memory::primitive_desc
|
||||||
|
#define MEMORY_PD_CONSTRUCTOR_2_PARAMS(md, engine) \
|
||||||
|
MEMORY_PRIMITIVE_DESC(md, engine)
|
||||||
#define MKL_FMT_TAG tf_fmt
|
#define MKL_FMT_TAG tf_fmt
|
||||||
#define MKL_TENSOR_FORMAT memory::format
|
#define MKL_TENSOR_FORMAT memory::format
|
||||||
#define MKL_TENSOR_FORMAT_BLOCKED memory::format::blocked
|
#define MKL_TENSOR_FORMAT_BLOCKED memory::format::blocked
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user