Added changes for DNN 0.9 to softmax, identity_op, and lrn ops.
parent 04f2870814
commit 66832a3986
@@ -16,6 +16,7 @@ limitations under the License.
 // See docs in ../ops/array_ops.cc.
 #ifdef INTEL_MKL
 
+#include "mkldnn.hpp"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
@@ -23,8 +24,6 @@ limitations under the License.
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/logging.h"
-
-#include "mkldnn.hpp"
 #include "tensorflow/core/util/mkl_util.h"
 
 namespace tensorflow {
@@ -64,4 +63,5 @@ TF_CALL_float(REGISTER_MKL_CPU);
 TF_CALL_bfloat16(REGISTER_MKL_CPU);
 #undef REGISTER_MKL_CPU
 }  // namespace tensorflow
+
 #endif  // INTEL_MKL
|
@@ -21,24 +21,26 @@ limitations under the License.
 #ifdef INTEL_MKL
 
 #define EIGEN_USE_THREADS
 
+#include <unordered_map>
 #include <vector>
 
 #include "mkldnn.hpp"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/bounds_check.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/kernels/ops_util.h"
 #include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/util/mkl_types.h"
 #include "tensorflow/core/util/mkl_util.h"
 #include "tensorflow/core/util/tensor_format.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 
 #if !defined(IS_MOBILE_PLATFORM)
 #include "tensorflow/core/util/work_sharder.h"
 #endif
 
 using mkldnn::lrn_across_channels;
 using mkldnn::lrn_backward;
 using mkldnn::lrn_forward;
 using mkldnn::prop_kind;
@@ -69,14 +71,14 @@ class MklLRNOp : public OpKernel {
  public:
   ~MklLRNOp() {}
 
-  explicit MklLRNOp(OpKernelConstruction* context) : OpKernel(context) {
+  explicit MklLRNOp(OpKernelConstruction* context)
+      : OpKernel(context), cpu_engine_(ENGINE_CPU, 0) {
     int64 depth_radius64;
     OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius64));
-    OP_REQUIRES(
-        context,
-        FastBoundsCheck(depth_radius64, std::numeric_limits<int>::max()),
-        errors::InvalidArgument("depth_radius = ", depth_radius64,
-                                " larger than int max"));
+    OP_REQUIRES(context, FastBoundsCheck(depth_radius64,
+                                         std::numeric_limits<int>::max()),
+                errors::InvalidArgument("depth_radius = ", depth_radius64,
+                                        " larger than int max"));
     depth_radius_ = static_cast<size_t>(depth_radius64);
 
     OP_REQUIRES_OK(context, context->GetAttr("bias", &bias_));
@@ -85,6 +87,7 @@ class MklLRNOp : public OpKernel {
     workspace_enabled_ = false;
     OP_REQUIRES_OK(context,
                    context->GetAttr("workspace_enabled", &workspace_enabled_));
+    fwd_stream_.reset(new CPU_STREAM(cpu_engine_));
   }
 
   void Compute(OpKernelContext* context) override {
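Note: ENGINE_CPU and CPU_STREAM come from the newly included mkl_types.h header. A hypothetical reconstruction of these compatibility macros, inferred from how they are used in this diff (the actual definitions live in tensorflow/core/util/mkl_types.h and may differ in detail):

// Sketch only -- assumed definitions, not copied from mkl_types.h.
#ifdef ENABLE_MKLDNN_V1
#define ENGINE_CPU engine::kind::cpu       // 1.x: kinds moved under engine::kind
#define CPU_STREAM(engine) stream(engine)  // 1.x: a stream binds to an engine
#define MEMORY_FORMAT memory::format_tag   // 1.x: memory::format was renamed
#else
#define ENGINE_CPU engine::cpu
#define CPU_STREAM(engine) stream(stream::kind::eager)  // 0.x: eager stream
#define MEMORY_FORMAT memory::format
#endif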
@@ -92,7 +95,6 @@ class MklLRNOp : public OpKernel {
     SanityCheckInputs(context);
     if (!context->status().ok()) return;
 
-    auto cpu_engine = engine(engine::cpu, 0);
     const Tensor& src_tensor = MklGetInput(context, kIdxInput);
     MklDnnShape src_dnn_shape;
     GetMklShape(context, kIdxInput, &src_dnn_shape);
@@ -120,9 +122,9 @@ class MklLRNOp : public OpKernel {
       // and we can enable the workspace
       workspace_enabled_ = true;
 
-      MklDnnData<T> src_dnn_data(&cpu_engine);
-      MklDnnData<T> dst_dnn_data(&cpu_engine);
-      MklDnnData<uint8> workspace_dnn_data(&cpu_engine);
+      MklDnnData<T> src_dnn_data(&cpu_engine_);
+      MklDnnData<T> dst_dnn_data(&cpu_engine_);
+      MklDnnData<uint8> workspace_dnn_data(&cpu_engine_);
 
       TensorShape tf_output_shape = src_tensor.shape();
@@ -134,39 +136,57 @@ class MklLRNOp : public OpKernel {
       // and MKL-DNN performs normalization over Channel, we tell MKL-DNN
       // that input is in NHWC layout with Channel being the last dimension.
       src_dnn_data.SetUsrMem(src_md, &src_tensor);
-      src_dnn_data.SetOpMemDesc(input_dims, memory::format::nhwc);
+      src_dnn_data.SetOpMemDesc(input_dims, MEMORY_FORMAT::nhwc);
 
-      // output_dnn_data and workspace both have the same shape as input
+      // dst_dnn_data has the same shape as input.
       dst_dnn_data.SetUsrMem(src_md);
-      dst_dnn_data.SetOpMemDesc(input_dims, memory::format::nhwc);
+      dst_dnn_data.SetOpMemDesc(input_dims, MEMORY_FORMAT::nhwc);
 
       // Create LRN primitive descriptor.
       // Tensorflow's normalization semantics is across channels.
       // MKL-DNN also supports normalization within channel.
-      auto lrn_desc = lrn_forward::desc(prop_kind::forward, lrn_across_channels,
-                                        src_dnn_data.GetUsrMemDesc(),
-                                        kernel_size, new_alpha, beta_, bias_);
-      auto lrn_prim_desc = lrn_forward::primitive_desc(lrn_desc, cpu_engine);
+      auto lrn_desc = lrn_forward::desc(
+          prop_kind::forward, ALGORITHM::lrn_across_channels,
+          src_dnn_data.GetUsrMemDesc(), kernel_size, new_alpha, beta_, bias_);
+      auto lrn_prim_desc = lrn_forward::primitive_desc(lrn_desc, cpu_engine_);
 
       // Allocate output_dnn_data tensor.
       Tensor* output_tensor = nullptr;
-      memory::format input_format = src_dnn_shape.GetTfDataFormat();
+      auto input_format = src_dnn_shape.GetTfDataFormat();
       AllocateOutputTensor(context, lrn_prim_desc, input_dims, input_format,
                            &output_tensor);
       OP_REQUIRES_OK(context, context->status());
-      CHECK_NOTNULL(output_tensor);
+      DCHECK(output_tensor != nullptr);
       dst_dnn_data.SetUsrMemDataHandle(output_tensor);
 
       // Handle workspace required for MKL-DNN.
       AllocateWorkspaceTensor(context, lrn_prim_desc, &workspace_dnn_data);
       OP_REQUIRES_OK(context, context->status());
 
-      PrepareAndExecuteNet(lrn_prim_desc, &src_dnn_data, &dst_dnn_data,
-                           &workspace_dnn_data);
+      // Check for input reorder
+      src_dnn_data.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
+          lrn_prim_desc.PRIMITIVE_DESC_SRC, cpu_engine_));
+
+      std::vector<primitive> net;
+#ifdef ENABLE_MKLDNN_V1
+      net.push_back(lrn_forward(lrn_prim_desc));
+      std::vector<std::unordered_map<int, memory>> net_args;
+      net_args.push_back({{MKLDNN_ARG_SRC, src_dnn_data.GetOpMem()},
+                          {MKLDNN_ARG_WORKSPACE, workspace_dnn_data.GetOpMem()},
+                          {MKLDNN_ARG_DST, dst_dnn_data.GetOpMem()}});
+      net.at(0).execute(*fwd_stream_, net_args.at(0));
+#else
+      net.push_back(lrn_forward(lrn_prim_desc, src_dnn_data.GetOpMem(),
+                                workspace_dnn_data.GetOpMem(),
+                                dst_dnn_data.GetOpMem()));
+      fwd_stream_->submit(net).wait();
+#endif
     } catch (mkldnn::error& e) {
-      string error_msg = "Status: " + std::to_string(e.status) +
-                         ", message: " + string(e.message) + ", in file " +
-                         string(__FILE__) + ":" + std::to_string(__LINE__);
+      string error_msg = "Status: " + std::to_string(e.status) + ", message: " +
+                         string(e.message) + ", in file " + string(__FILE__) +
+                         ":" + std::to_string(__LINE__);
       OP_REQUIRES_OK(
           context,
           errors::Aborted("Operation received an exception:", error_msg));
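The hunk above is the core of the port: under ENABLE_MKLDNN_V1 a primitive is built from its primitive_desc alone and the memories travel in an argument map at execution time, while 0.x binds memories at construction and submits a whole net to an eager stream. A hedged side-by-side sketch of the two execution models (RunLrnFwd is a hypothetical helper, not part of this commit):

#include <unordered_map>
#include <vector>
#include "mkldnn.hpp"

#ifdef ENABLE_MKLDNN_V1
void RunLrnFwd(const mkldnn::lrn_forward::primitive_desc& pd,
               mkldnn::stream& strm, mkldnn::memory& src,
               mkldnn::memory& wksp, mkldnn::memory& dst) {
  mkldnn::lrn_forward prim(pd);  // no memories bound yet
  prim.execute(strm, {{MKLDNN_ARG_SRC, src},
                      {MKLDNN_ARG_WORKSPACE, wksp},
                      {MKLDNN_ARG_DST, dst}});  // memories bound per call
}
#else
void RunLrnFwd(const mkldnn::lrn_forward::primitive_desc& pd,
               const mkldnn::memory& src, const mkldnn::memory& wksp,
               const mkldnn::memory& dst) {
  std::vector<mkldnn::primitive> net;
  net.push_back(mkldnn::lrn_forward(pd, src, wksp, dst));  // bound at build
  mkldnn::stream(mkldnn::stream::kind::eager).submit(net).wait();
}
#endif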
@@ -174,33 +194,13 @@ class MklLRNOp : public OpKernel {
     }
 
  private:
-  void PrepareAndExecuteNet(const lrn_forward::primitive_desc& lrn_fwd_desc,
-                            MklDnnData<T>* src_dnn_data,
-                            MklDnnData<T>* dst_dnn_data,
-                            MklDnnData<uint8>* wksp_dnn_data = nullptr) {
-    // Check for input reorder
-    src_dnn_data->CheckReorderToOpMem(lrn_fwd_desc.src_primitive_desc());
-
-    // Create pooling primitive and add it to net
-    std::vector<primitive> net;
-    if (wksp_dnn_data != nullptr) {
-      net.push_back(lrn_forward(lrn_fwd_desc, src_dnn_data->GetOpMem(),
-                                wksp_dnn_data->GetOpMem(),
-                                dst_dnn_data->GetOpMem()));
-    } else {
-      net.push_back(lrn_forward(lrn_fwd_desc, src_dnn_data->GetOpMem(),
-                                dst_dnn_data->GetOpMem()));
-    }
-    stream(stream::kind::eager).submit(net).wait();
-  }
-
   void AllocateOutputTensor(
       OpKernelContext* context,
       const lrn_forward::primitive_desc& lrn_fwd_prim_desc,
       const memory::dims output_dims_mkl_order,
-      const memory::format& output_tf_format, Tensor** output_tensor) {
-    CHECK_NOTNULL(output_tensor);
-    memory::primitive_desc dst_pd = lrn_fwd_prim_desc.dst_primitive_desc();
+      const MKL_TENSOR_FORMAT& output_tf_format, Tensor** output_tensor) {
+    DCHECK(output_tensor != nullptr);
+    MEMORY_PRIMITIVE_DESC dst_pd = lrn_fwd_prim_desc.PRIMITIVE_DESC_DST;
 
     MklDnnShape output_mkl_shape;
     // We only handle the case when the inputs and output are in Mkl format
@@ -231,8 +231,7 @@ class MklLRNOp : public OpKernel {
 
     auto in_shaped = input.shaped<T, 2>({nodes * batch, depth});
     // Multiplying the input with the band matrix has the effect of reducing
-    // the
-    // correct patch along the depth.
+    // the correct patch along the depth.
     Eigen::Tensor<T, 2, Eigen::RowMajor> multiplier(depth, depth);
     GetBandMatrix<T>(depth, depth_radius_, &multiplier);
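For readers unfamiliar with the Eigen fallback: GetBandMatrix fills a depth-by-depth band matrix whose band width is the normalization window, so contracting the input with it sums each depth patch. A standalone sketch of the same shape (an illustration under that assumption, not TF's GetBandMatrix):

#include <vector>

// Entry (i, j) is 1 when |i - j| <= depth_radius, else 0. For depth = 5 and
// depth_radius = 1 the matrix is tridiagonal: each output channel sums its
// own input channel plus one neighbor on each side.
std::vector<std::vector<float>> BandMatrix(int depth, int depth_radius) {
  std::vector<std::vector<float>> m(depth, std::vector<float>(depth, 0.0f));
  for (int i = 0; i < depth; ++i)
    for (int j = 0; j < depth; ++j)
      if (j >= i - depth_radius && j <= i + depth_radius) m[i][j] = 1.0f;
  return m;
}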
@@ -242,7 +241,7 @@ class MklLRNOp : public OpKernel {
     mkl_output_mkl_shape.SetDimensions(4);
     AllocateOutputSetMklShape(context, kIdxOutput, &output_dnn_data,
                               input.shape(), mkl_output_mkl_shape);
-    CHECK_NOTNULL(output_dnn_data);
+    DCHECK(output_dnn_data != nullptr);
 
     Tensor* workspace_tensor = nullptr;
     MklDnnShape workspace_mkl_shape;
@@ -251,7 +250,7 @@ class MklLRNOp : public OpKernel {
     workspace_tf_shape.AddDim(0);
     AllocateOutputSetMklShape(context, kIdxWorkspace, &workspace_tensor,
                               workspace_tf_shape, workspace_mkl_shape);
-    CHECK_NOTNULL(workspace_tensor);
+    DCHECK(workspace_tensor);
 
     auto out_shaped = output_dnn_data->shaped<T, 2>({nodes * batch, depth});
     Eigen::array<DimPair, 1> dims = {{DimPair(1, 0)}};
@@ -271,10 +270,10 @@ class MklLRNOp : public OpKernel {
       OpKernelContext* context,
       const lrn_forward::primitive_desc& lrn_fwd_prim_desc,
       MklDnnData<uint8>* dnn_data_wksp) {
-    CHECK_NOTNULL(dnn_data_wksp);
+    DCHECK(dnn_data_wksp != nullptr);
     Tensor* workspace_tensor = nullptr;
-    memory::primitive_desc workspace_pd =
-        lrn_fwd_prim_desc.workspace_primitive_desc();
+    MEMORY_PRIMITIVE_DESC workspace_pd =
+        lrn_fwd_prim_desc.PRIMITIVE_DESC_WORKSPACE;
     size_t workspace_bytes = workspace_pd.get_size();
     MklDnnShape workspace_mkl_shape;
     // the workspace tensor is a uint8 tensor that has
@@ -284,7 +283,7 @@ class MklLRNOp : public OpKernel {
     workspace_tf_shape.AddDim(workspace_bytes);
     AllocateOutputSetMklShape(context, kIdxWorkspace, &workspace_tensor,
                               workspace_tf_shape, workspace_mkl_shape);
-    CHECK_NOTNULL(workspace_tensor);
+    DCHECK(workspace_tensor != nullptr);
     dnn_data_wksp->SetUsrMem(workspace_pd, workspace_tensor);
   }
@@ -295,16 +294,14 @@ class MklLRNOp : public OpKernel {
     if (src_dnn_shape.IsMklTensor()) {
       OP_REQUIRES(context, src_dnn_shape.GetDimension() == 4,
                   errors::InvalidArgument("input must be 4-dimensional"));
-      OP_REQUIRES(context,
-                  FastBoundsCheck(src_tensor.NumElements(),
-                                  std::numeric_limits<int>::max()),
+      OP_REQUIRES(context, FastBoundsCheck(src_tensor.NumElements(),
+                                           std::numeric_limits<int>::max()),
                   errors::InvalidArgument("argument to LRN too large"));
     } else {
       OP_REQUIRES(context, src_tensor.dims() == 4,
                   errors::InvalidArgument("input must be 4-dimensional"));
-      OP_REQUIRES(context,
-                  FastBoundsCheck(src_tensor.NumElements(),
-                                  std::numeric_limits<int>::max()),
+      OP_REQUIRES(context, FastBoundsCheck(src_tensor.NumElements(),
+                                           std::numeric_limits<int>::max()),
                   errors::InvalidArgument("argument to LRN too large"));
     }
   }
@@ -316,19 +313,21 @@ class MklLRNOp : public OpKernel {
   float bias_;
   float alpha_;
   float beta_;
+  engine cpu_engine_;
+  std::shared_ptr<stream> fwd_stream_;
 };
 
 template <typename T>
 class MklLRNGradOp : public OpKernel {
  public:
-  explicit MklLRNGradOp(OpKernelConstruction* context) : OpKernel(context) {
+  explicit MklLRNGradOp(OpKernelConstruction* context)
+      : OpKernel(context), cpu_engine_(ENGINE_CPU, 0) {
     int64 depth_radius64;
     OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius64));
-    OP_REQUIRES(
-        context,
-        FastBoundsCheck(depth_radius64, std::numeric_limits<int>::max()),
-        errors::InvalidArgument("depth_radius = ", depth_radius64,
-                                " larger than int max"));
+    OP_REQUIRES(context, FastBoundsCheck(depth_radius64,
+                                         std::numeric_limits<int>::max()),
+                errors::InvalidArgument("depth_radius = ", depth_radius64,
+                                        " larger than int max"));
     depth_radius_ = static_cast<int>(depth_radius64);
     OP_REQUIRES_OK(context, context->GetAttr("bias", &bias_));
     OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_));
@@ -336,6 +335,7 @@ class MklLRNGradOp : public OpKernel {
     workspace_enabled_ = false;
     OP_REQUIRES_OK(context,
                    context->GetAttr("workspace_enabled", &workspace_enabled_));
+    bwd_stream_.reset(new CPU_STREAM(cpu_engine_));
   }
 
   void Compute(OpKernelContext* context) override {
@@ -343,11 +343,10 @@ class MklLRNGradOp : public OpKernel {
      SanityCheckInputs(context);
      if (!context->status().ok()) return;
 
-     auto cpu_engine = engine(engine::cpu, 0);
-     MklDnnData<T> input_grad_dnn_data(&cpu_engine);
-     MklDnnData<T> orig_input_dnn_data(&cpu_engine);
-     MklDnnData<T> orig_output_dnn_data(&cpu_engine);
-     MklDnnData<T> output_dnn_data(&cpu_engine);
+     MklDnnData<T> input_grad_dnn_data(&cpu_engine_);
+     MklDnnData<T> orig_input_dnn_data(&cpu_engine_);
+     MklDnnData<T> orig_output_dnn_data(&cpu_engine_);
+     MklDnnData<T> output_dnn_data(&cpu_engine_);
 
      MklDnnShape input_grad_dnn_shape, orig_input_dnn_shape,
          orig_output_dnn_shape;
@@ -389,11 +388,11 @@ class MklLRNGradOp : public OpKernel {
      memory::dims orig_input_dims =
          orig_input_dnn_shape.GetSizesAsMklDnnDims();
      orig_input_dnn_data.SetUsrMem(orig_input_md, &orig_input_tensor);
-     orig_input_dnn_data.SetOpMemDesc(orig_input_dims, memory::format::nhwc);
+     orig_input_dnn_data.SetOpMemDesc(orig_input_dims, MEMORY_FORMAT::nhwc);
 
      // output_dnn_data has the same shape as original input
      output_dnn_data.SetUsrMem(orig_input_md);
-     output_dnn_data.SetOpMemDesc(orig_input_dims, memory::format::nhwc);
+     output_dnn_data.SetOpMemDesc(orig_input_dims, MEMORY_FORMAT::nhwc);
 
      // MKL-DNN has a notion of kernel_size and not depth_radius.
      int kernel_size = 2 * depth_radius_ + 1;
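A worked example of that conversion: TensorFlow's depth_radius counts neighbors on each side of a channel, while MKL-DNN's local size counts the whole window.

// depth_radius = 2  ->  kernel_size = 2 * 2 + 1 = 5
// (2 channels left + the center channel + 2 channels right)
static_assert(2 * 2 + 1 == 5, "depth_radius 2 gives a 5-wide window");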
@@ -402,42 +401,61 @@ class MklLRNGradOp : public OpKernel {
      // Create LRN backward primitive descriptor. It requires LRN forward
      // primitive descriptor also.
-     auto lrn_fwd_desc = lrn_forward::desc(
-         prop_kind::forward, lrn_across_channels, orig_input_md, kernel_size,
-         new_alpha, beta_, bias_);
-     auto lrn_fwd_prim_desc =
-         lrn_forward::primitive_desc(lrn_fwd_desc, cpu_engine);
-     auto lrn_bwd_desc = lrn_backward::desc(
-         lrn_across_channels, original_output_md, target_diff_dst_md,
-         kernel_size, new_alpha, beta_, bias_);
+     auto lrn_fwd_desc = lrn_forward::desc(
+         prop_kind::forward, ALGORITHM::lrn_across_channels, orig_input_md,
+         kernel_size, new_alpha, beta_, bias_);
+     auto lrn_fwd_prim_desc =
+         lrn_forward::primitive_desc(lrn_fwd_desc, cpu_engine_);
+     auto lrn_bwd_desc = lrn_backward::desc(
+         ALGORITHM::lrn_across_channels, original_output_md,
+         target_diff_dst_md, kernel_size, new_alpha, beta_, bias_);
      auto lrn_bwd_prim_desc = lrn_backward::primitive_desc(
-         lrn_bwd_desc, cpu_engine, lrn_fwd_prim_desc);
+         lrn_bwd_desc, cpu_engine_, lrn_fwd_prim_desc);
 
      Tensor* output_tensor = nullptr;
-     memory::format orig_input_format = orig_input_dnn_shape.GetTfDataFormat();
+     auto orig_input_format = orig_input_dnn_shape.GetTfDataFormat();
      AllocateOutputTensor(context, lrn_bwd_prim_desc, orig_input_dims,
                           orig_input_format, &output_tensor);
      OP_REQUIRES_OK(context, context->status());
-     CHECK_NOTNULL(output_tensor);
+     DCHECK(output_tensor != nullptr);
      output_dnn_data.SetUsrMemDataHandle(output_tensor);
 
      // Create LRN primitive and add it to the net
      // At this point, workspace is enabled, so we don't need
      // to check. Pass input workspace to LRN backward primitive.
      const Tensor& workspace_tensor = MklGetInput(context, kIdxWorkspace);
-     MklDnnData<uint8> workspace_dnn_data(&cpu_engine);
+     MklDnnData<uint8> workspace_dnn_data(&cpu_engine_);
      ConfigureWorkspace(workspace_tensor,
-                        lrn_fwd_prim_desc.workspace_primitive_desc(),
+                        lrn_fwd_prim_desc.PRIMITIVE_DESC_WORKSPACE,
                         &workspace_dnn_data);
 
-     PrepareAndExecuteNet(
-         lrn_bwd_prim_desc, lrn_fwd_prim_desc, &orig_input_dnn_data,
-         &input_grad_dnn_data, &output_dnn_data,
-         memory::primitive_desc(target_diff_dst_md, cpu_engine),
-         &workspace_dnn_data);
+     // Check for input reordering on the diff dst input
+     input_grad_dnn_data.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
+         lrn_bwd_prim_desc.PRIMITIVE_DESC_DIFF_DST, cpu_engine_));
+
+     // Check for input reordering on the original input
+     orig_input_dnn_data.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
+         lrn_fwd_prim_desc.PRIMITIVE_DESC_SRC, cpu_engine_));
+
+     std::vector<primitive> net;
+#ifdef ENABLE_MKLDNN_V1
+     net.push_back(lrn_backward(lrn_bwd_prim_desc));
+     std::vector<std::unordered_map<int, memory>> net_args;
+     net_args.push_back({{MKLDNN_ARG_SRC, orig_input_dnn_data.GetOpMem()},
+                         {MKLDNN_ARG_DIFF_DST, input_grad_dnn_data.GetOpMem()},
+                         {MKLDNN_ARG_DST, output_dnn_data.GetOpMem()}});
+     net.at(0).execute(*bwd_stream_, net_args.at(0));
+#else
+     net.push_back(lrn_backward(
+         lrn_bwd_prim_desc, orig_input_dnn_data.GetOpMem(),
+         input_grad_dnn_data.GetOpMem(), output_dnn_data.GetOpMem()));
+     bwd_stream_->submit(net).wait();
+#endif
    } catch (mkldnn::error& e) {
-     string error_msg = "Status: " + std::to_string(e.status) +
-                        ", message: " + string(e.message) + ", in file " +
-                        string(__FILE__) + ":" + std::to_string(__LINE__);
+     string error_msg = "Status: " + std::to_string(e.status) + ", message: " +
+                        string(e.message) + ", in file " + string(__FILE__) +
+                        ":" + std::to_string(__LINE__);
      OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception:", error_msg));
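Worth noting for the hunk above: in both MKL-DNN 0.x and 1.x an LRN backward primitive_desc cannot be created standalone; it takes the forward primitive_desc as a hint, which is why the backward op rebuilds lrn_fwd_desc/lrn_fwd_prim_desc before creating lrn_bwd_prim_desc. A minimal sketch of that dependency (MakeLrnBwdPd is a hypothetical helper, not part of this commit):

#include "mkldnn.hpp"

mkldnn::lrn_backward::primitive_desc MakeLrnBwdPd(
    const mkldnn::engine& eng, const mkldnn::lrn_forward::desc& fwd_desc,
    const mkldnn::lrn_backward::desc& bwd_desc) {
  auto fwd_pd = mkldnn::lrn_forward::primitive_desc(fwd_desc, eng);
  // The forward primitive_desc is passed as the hint argument.
  return mkldnn::lrn_backward::primitive_desc(bwd_desc, eng, fwd_pd);
}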
@@ -448,10 +466,9 @@ class MklLRNGradOp : public OpKernel {
       OpKernelContext* context,
       const lrn_backward::primitive_desc& lrn_bkwd_prim_desc,
       const memory::dims output_dims_mkl_order,
-      const memory::format& output_tf_format, Tensor** output_tensor) {
-    CHECK_NOTNULL(output_tensor);
-    memory::primitive_desc dst_pd =
-        lrn_bkwd_prim_desc.diff_src_primitive_desc();
+      const MKL_TENSOR_FORMAT& output_tf_format, Tensor** output_tensor) {
+    DCHECK(output_tensor != nullptr);
+    MEMORY_PRIMITIVE_DESC dst_pd = lrn_bkwd_prim_desc.PRIMITIVE_DESC_DIFF_SRC;
     MklDnnShape output_mkl_shape;
 
     // We assume that all outputs at this point are MKL Tensors
@@ -472,56 +489,28 @@ class MklLRNGradOp : public OpKernel {
   memory::desc ConfigureInputGradient(const Tensor& input_grad_tensor,
                                       const MklDnnShape& input_grad_dnn_shape,
                                       MklDnnData<T>* input_grad_dnn_data) {
-    CHECK_NOTNULL(input_grad_dnn_data);
+    DCHECK(input_grad_dnn_data != nullptr);
     // This shouldn't be necessary at this point, but just in case
-    CHECK_EQ(input_grad_dnn_shape.IsMklTensor(), true);
+    DCHECK(input_grad_dnn_shape.IsMklTensor() == true);
 
     memory::desc input_grad_md = input_grad_dnn_shape.GetCurLayout();
     memory::dims orig_input_dims = input_grad_dnn_shape.GetSizesAsMklDnnDims();
     input_grad_dnn_data->SetUsrMem(input_grad_md, &input_grad_tensor);
-    input_grad_dnn_data->SetOpMemDesc(orig_input_dims, memory::format::nhwc);
+    input_grad_dnn_data->SetOpMemDesc(orig_input_dims, MEMORY_FORMAT::nhwc);
     return input_grad_md;
   }
 
-  void PrepareAndExecuteNet(
-      const lrn_backward::primitive_desc& lrn_bkwd_desc,
-      const lrn_forward::primitive_desc& lrn_fwd_desc,
-      MklDnnData<T>* src_dnn_data, MklDnnData<T>* input_gradient_diff_dst,
-      MklDnnData<T>* output_diff_src,
-      const memory::primitive_desc& target_diff_dst_pd,
-      const MklDnnData<uint8>* workspace_dnn_data = nullptr) {
-    // Check for input reordering on the diff dst input
-    input_gradient_diff_dst->CheckReorderToOpMem(
-        lrn_bkwd_desc.diff_dst_primitive_desc());
-
-    // Check for input reordering on the original input
-    src_dnn_data->CheckReorderToOpMem(lrn_fwd_desc.src_primitive_desc());
-    // Create pooling primitive and add it to net
-    std::vector<primitive> net;
-    if (nullptr == workspace_dnn_data) {
-      net.push_back(lrn_backward(lrn_bkwd_desc, src_dnn_data->GetOpMem(),
-                                 input_gradient_diff_dst->GetOpMem(),
-                                 output_diff_src->GetOpMem()));
-    } else {
-      net.push_back(lrn_backward(lrn_bkwd_desc, src_dnn_data->GetOpMem(),
-                                 input_gradient_diff_dst->GetOpMem(),
-                                 workspace_dnn_data->GetOpMem(),
-                                 output_diff_src->GetOpMem()));
-    }
-    stream(stream::kind::eager).submit(net).wait();
-  }
-
   void ConfigureWorkspace(const Tensor& workspace_tensor,
-                          memory::primitive_desc workspace_pd,
+                          MEMORY_PRIMITIVE_DESC workspace_pd,
                           MklDnnData<uint8>* workspace_dnn_data) {
-    CHECK_NOTNULL(workspace_dnn_data);
+    DCHECK(workspace_dnn_data);
 
     workspace_dnn_data->SetUsrMem(workspace_pd, &workspace_tensor);
   }
 
   // Fallback implementation - Taken from lrn_op.cc
-  // TODO(intelft) Check if we can use EigenLRNOp directly instead of making a
-  // copy.
+  // TODO(intel-tf) Check if we can use EigenLRNOp directly
+  // instead of making a copy.
   void MklDefaultToEigen(OpKernelContext* context) {
     Tensor input_gradient_tensor;
     Tensor orig_input_tensor;
@@ -676,6 +665,8 @@ class MklLRNGradOp : public OpKernel {
   float bias_;
   float alpha_;
   float beta_;
+  engine cpu_engine_;
+  std::shared_ptr<stream> bwd_stream_;
 };
 
 #define REGISTER_MKL_LRN_CPU(T) \
|
@@ -14,17 +14,19 @@ limitations under the License.
 ==============================================================================*/
 
 // See docs in ../ops/nn_ops.cc.
 
 #ifdef INTEL_MKL
 
 #include "mkldnn.hpp"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/numeric_op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/util/mkl_types.h"
 #include "tensorflow/core/util/mkl_util.h"
 #include "tensorflow/core/util/tensor_format.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 
 using mkldnn::prop_kind;
 using mkldnn::softmax_forward;
@@ -35,10 +37,10 @@ namespace tensorflow {
 class MklSoftmaxParams {
  public:
   memory::dims src_dims;
-  memory::format src_fmt;
+  MKL_TENSOR_FORMAT src_fmt;
   int axis;
 
-  MklSoftmaxParams(memory::dims src_dims, memory::format src_fmt, int axis)
+  MklSoftmaxParams(memory::dims src_dims, MKL_TENSOR_FORMAT src_fmt, int axis)
       : src_dims(src_dims), src_fmt(src_fmt), axis(axis) {}
 };
@@ -46,8 +48,8 @@ template <typename T>
 class MklSoftmaxPrimitive : public MklPrimitive {
  public:
   explicit MklSoftmaxPrimitive(const MklSoftmaxParams& fwdParams)
-      : cpu_engine_(engine::cpu, 0) {
-    context_.fwd_stream.reset(new stream(stream::kind::eager));
+      : cpu_engine_(ENGINE_CPU, 0) {
+    context_.fwd_stream.reset(new CPU_STREAM(cpu_engine_));
     Setup(fwdParams);
   }
@@ -61,9 +63,18 @@ class MklSoftmaxPrimitive : public MklPrimitive {
         static_cast<void*>(const_cast<T*>(src_data)));
     context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
 
+#ifdef ENABLE_MKLDNN_V1
+    DCHECK_EQ(context_.fwd_primitives.size(),
+              context_.fwd_net_args.size());
+    for (size_t i = 0; i < context_.fwd_primitives.size(); ++i) {
+      context_.fwd_primitives.at(i).execute(*context_.fwd_stream,
+                                            context_.fwd_net_args.at(i));
+    }
+#else
     context_.fwd_stream->submit(context_.fwd_primitives);
+#endif
 
-    // After execution, set data handle back
+    // After execution, set data handle back.
     context_.src_mem->set_data_handle(DummyData);
     context_.dst_mem->set_data_handle(DummyData);
   }
@@ -74,22 +85,23 @@ class MklSoftmaxPrimitive : public MklPrimitive {
 
  private:
  struct SoftmaxFwdContext {
-    // MKL-DNN memory
+    // MKL-DNN memory.
    std::shared_ptr<memory> src_mem;
    std::shared_ptr<memory> dst_mem;
 
-    // Primitive desc
+    // Primitive descriptor.
    std::shared_ptr<mkldnn::softmax_forward::desc> fwd_desc;
 
-    // Memory desc
+    // Memory descriptor.
    std::shared_ptr<memory::desc> src_md;
 
-    // Softmax primitive
+    // Softmax primitive.
    std::shared_ptr<mkldnn::softmax_forward::primitive_desc> fwd_pd;
    std::shared_ptr<mkldnn::primitive> softmax_fwd;
 
    std::shared_ptr<stream> fwd_stream;
    std::vector<mkldnn::primitive> fwd_primitives;
+   std::vector<MemoryArgsMap> fwd_net_args;
 
    SoftmaxFwdContext()
        : src_mem(nullptr),
@@ -103,25 +115,33 @@ class MklSoftmaxPrimitive : public MklPrimitive {
 
   // Softmax forward primitive setup
   void Setup(const MklSoftmaxParams& fwdParams) {
-    // Create memory descriptors for softmax data with specified format
-    context_.src_md.reset(new memory::desc({fwdParams.src_dims},
-                                           MklDnnType<T>(), fwdParams.src_fmt));
+    // Create memory descriptors for softmax data with specified format.
+    auto src_format = GET_TENSOR_FORMAT(fwdParams.src_fmt);
+    context_.src_md.reset(
+        new memory::desc({fwdParams.src_dims}, MklDnnType<T>(), src_format));
 
-    // Create a softmax
+    // Create softmax descriptor and primitive descriptor.
     context_.fwd_desc.reset(new mkldnn::softmax_forward::desc(
         prop_kind::forward_scoring, *context_.src_md, fwdParams.axis));
     context_.fwd_pd.reset(new mkldnn::softmax_forward::primitive_desc(
         *context_.fwd_desc, cpu_engine_));
 
-    // Create memory primitive based on dummy data
-    context_.src_mem.reset(
-        new memory({*context_.src_md, cpu_engine_}, DummyData));
-    context_.dst_mem.reset(
-        new memory(context_.fwd_pd.get()->dst_primitive_desc(), DummyData));
+    // Create memory primitive based on dummy data.
+    context_.src_mem.reset(new MEMORY_CONSTRUCTOR_USING_MD(
+        *context_.src_md, cpu_engine_, DummyData));
+    context_.dst_mem.reset(new MEMORY_CONSTRUCTOR_PD(
+        context_.fwd_pd.get()->PRIMITIVE_DESC_DST, cpu_engine_, DummyData));
 
+#ifdef ENABLE_MKLDNN_V1
     // Create softmax primitive and add it to net
-    context_.softmax_fwd.reset(new mkldnn::softmax_forward(
-        *context_.fwd_pd, *context_.src_mem, *context_.dst_mem));
+    context_.softmax_fwd.reset(new mkldnn::softmax_forward(*context_.fwd_pd));
+    context_.fwd_net_args.push_back({{MKLDNN_ARG_SRC, *context_.src_mem},
+                                     {MKLDNN_ARG_DST, *context_.dst_mem}});
+#else
+    context_.softmax_fwd.reset(new mkldnn::softmax_forward(
+        *context_.fwd_pd, *context_.src_mem, *context_.dst_mem));
+#endif  // ENABLE_MKLDNN_V1
 
     context_.fwd_primitives.push_back(*context_.softmax_fwd);
   }
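As with ENGINE_CPU/CPU_STREAM, the MEMORY_CONSTRUCTOR_* names are compatibility macros from mkl_types.h. A hypothetical expansion (assumed from usage here, not copied from the header): 1.x constructs a memory from a desc plus an engine, while 0.x goes through a {desc, engine} primitive_desc and ignores the extra engine argument.

// Sketch only -- assumed definitions.
#ifdef ENABLE_MKLDNN_V1
#define MEMORY_CONSTRUCTOR_USING_MD(md, engine, data) memory(md, engine, data)
#define MEMORY_CONSTRUCTOR_PD(pd, engine, data) memory(pd, engine, data)
#else
#define MEMORY_CONSTRUCTOR_USING_MD(md, engine, data) memory({md, engine}, data)
#define MEMORY_CONSTRUCTOR_PD(pd, engine, data) memory(pd, data)
#endif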
@@ -134,7 +154,7 @@ template <typename T>
 class MklSoftmaxPrimitiveFactory : public MklPrimitiveFactory<T> {
  public:
   static MklSoftmaxPrimitive<T>* Get(const MklSoftmaxParams& fwdParams) {
-    // Get a softmax fwd primitive from the cached pool
+    // Get a softmax fwd primitive from the cached pool.
     MklSoftmaxPrimitive<T>* softmax_forward =
         static_cast<MklSoftmaxPrimitive<T>*>(
             MklSoftmaxPrimitiveFactory<T>::GetInstance().GetSoftmaxFwd(
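Usage sketch for the factory above: callers build a params key, fetch (or lazily create) the cached primitive, and execute on raw buffers. The dims, format, and axis below are made-up example values; src and dst are caller-provided buffers.

void RunCachedSoftmax(const float* src, float* dst) {
  MklSoftmaxParams params({8, 128}, MKL_TENSOR_FORMAT_NC, /*axis=*/1);
  MklSoftmaxPrimitive<float>* fwd =
      MklSoftmaxPrimitiveFactory<float>::Get(params);  // cached per key
  fwd->Execute(src, dst);  // rebinds data handles, runs, resets handles
}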
@@ -189,15 +209,15 @@ class MklSoftmaxOp : public OpKernel {
 
   void Compute(OpKernelContext* context) override {
     try {
-      // src_tensor now points to the 0-th input of global data struct "context"
+      auto cpu_engine = engine(ENGINE_CPU, 0);
+      // src_tensor points to the 0-th input of global data struct "context".
       size_t src_idx = 0;
       const Tensor& src_tensor = MklGetInput(context, src_idx);
       // Add: get MklShape
       MklDnnShape src_mkl_shape;
       GetMklShape(context, src_idx, &src_mkl_shape);
 
-      // src_dims is the dimension of src_tensor
-      // dim of the dst will also be same as src_dims
+      // src_dims is the dimension of src_tensor.
+      // Dim of the dst will also be same as src_dims.
       auto src_tf_shape = src_mkl_shape.IsMklTensor()
                               ? src_mkl_shape.GetTfShape()
                               : src_tensor.shape();
@@ -211,7 +231,7 @@ class MklSoftmaxOp : public OpKernel {
         src_dims = TFShapeToMklDnnDims(src_tf_shape);
         axis = input_dims - 1;
       }
-      memory::format layout_type;
+      MKL_TENSOR_FORMAT layout_type;
       // In MKL, data format passed to mkl softmax op depends on dimension of
       // the input tensor. Here "x" data format in MKL is used for 1 dim tensor,
       // "nc" for 2 dim tensor, "tnc" for 3 dim tensor, "nchw" for 4 dim tensor,
@@ -223,26 +243,26 @@ class MklSoftmaxOp : public OpKernel {
       // dimension to do softmax.
       switch (input_dims) {
         case 1:
-          layout_type = memory::format::x;
+          layout_type = MKL_TENSOR_FORMAT_X;
           break;
         case 2:
-          layout_type = memory::format::nc;
+          layout_type = MKL_TENSOR_FORMAT_NC;
           break;
         case 3:
-          layout_type = memory::format::tnc;
+          layout_type = MKL_TENSOR_FORMAT_TNC;
           break;
         case 4:
           if (src_mkl_shape.IsMklTensor()) {
-            layout_type = memory::format::nhwc;
+            layout_type = MKL_TENSOR_FORMAT_NHWC;
           } else {
-            layout_type = memory::format::nchw;
+            layout_type = MKL_TENSOR_FORMAT_NCHW;
           }
           break;
         case 5:
           if (src_mkl_shape.IsMklTensor()) {
-            layout_type = memory::format::ndhwc;
+            layout_type = MKL_TENSOR_FORMAT_NDHWC;
           } else {
-            layout_type = memory::format::ncdhw;
+            layout_type = MKL_TENSOR_FORMAT_NCDHW;
           }
           break;
         default:
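A worked example of the dispatch above, with assumed example values: a rank-3 input such as {16, 8, 32} is treated as "tnc", and softmax runs over the last (channel) dimension.

void ExampleLayoutChoice() {
  const int input_dims = 3;                               // rank-3 input
  MKL_TENSOR_FORMAT layout_type = MKL_TENSOR_FORMAT_TNC;  // case 3 above
  const int axis = input_dims - 1;                        // softmax over dim 2
  (void)layout_type;
  (void)axis;
}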
@@ -254,21 +274,20 @@ class MklSoftmaxOp : public OpKernel {
       // If input is in MKL layout, then simply get the format from input;
       // otherwise, use TF layout defined before.
       auto src_fmt = src_mkl_shape.IsMklTensor()
-                         ? static_cast<mkldnn::memory::format>(
-                               src_mkl_shape.GetMklLayout().data.format)
+                         ? GET_FORMAT_FROM_SHAPE(src_mkl_shape)
                          : layout_type;
 
-      // Get a softmax fwd from primitive pool
+      // Get a softmax fwd primitive from primitive pool.
       MklSoftmaxParams fwdParams(src_dims, src_fmt, axis);
       MklSoftmaxPrimitive<T>* softmax_fwd =
           MklSoftmaxPrimitiveFactory<T>::Get(fwdParams);
 
-      // Add output
+      // Prepare for creating output tensor.
       Tensor* output_tensor = nullptr;
       MklDnnShape output_mkl_shape;
       TensorShape output_tf_shape;  // shape of output TF tensor.
 
-      auto dst_pd = softmax_fwd->GetSoftmaxFwdPd()->dst_primitive_desc();
+      auto dst_pd = softmax_fwd->GetSoftmaxFwdPd()->PRIMITIVE_DESC_DST;
 
       // If input is MKL shape, output is also MKL shape.
       // If input is TF shape, output is also TF shape.
@@ -278,23 +297,23 @@ class MklSoftmaxOp : public OpKernel {
         output_mkl_shape.SetElemType(MklDnnType<T>());
         output_mkl_shape.SetTfLayout(src_dims.size(), src_dims, layout_type);
         output_tf_shape.AddDim((dst_pd.get_size() / sizeof(T)));
-      } else {  // then output is also TF shape
+      } else {
         output_mkl_shape.SetMklTensor(false);
         output_tf_shape = MklDnnDimsToTFShape(src_dims);
       }
-      // Allocate output shape (MKL or TF based on the above)
+      // Allocate output tensor.
       AllocateOutputSetMklShape(context, 0, &output_tensor, output_tf_shape,
                                 output_mkl_shape);
 
       const T* src_data = src_tensor.flat<T>().data();
       T* dst_data = reinterpret_cast<T*>(output_tensor->flat<T>().data());
 
-      // Execute softmax
+      // Execute softmax primitive.
       softmax_fwd->Execute(src_data, dst_data);
     } catch (mkldnn::error& e) {
-      string error_msg = "Status: " + std::to_string(e.status) +
-                         ", message: " + string(e.message) + ", in file " +
-                         string(__FILE__) + ":" + std::to_string(__LINE__);
+      string error_msg = "Status: " + std::to_string(e.status) + ", message: " +
+                         string(e.message) + ", in file " + string(__FILE__) +
+                         ":" + std::to_string(__LINE__);
       OP_REQUIRES_OK(
           context,
           errors::Aborted("Operation received an exception:", error_msg));
@@ -311,6 +330,7 @@ class MklSoftmaxOp : public OpKernel {
           .TypeConstraint<type>("T")                                  \
           .Label(mkl_op_registry::kMklLayoutDependentOpLabel),        \
       MklSoftmaxOp<CPUDevice, type>);
 
 TF_CALL_float(REGISTER_SOFTMAX_MKL_SUPPORTED_KERNELS_TYPES);
+TF_CALL_bfloat16(REGISTER_SOFTMAX_MKL_SUPPORTED_KERNELS_TYPES);