Added changes for DNN 0.9 to softmax, identity_op, and lrn ops.

Lakshay Tokas 2020-02-10 15:40:01 -08:00
parent 04f2870814
commit 66832a3986
3 changed files with 190 additions and 179 deletions
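The changes in all three kernels lean on compatibility macros (ENGINE_CPU, CPU_STREAM, MEMORY_FORMAT, ALGORITHM, MKL_TENSOR_FORMAT, the PRIMITIVE_DESC_* accessors, and friends) pulled in through the newly included "tensorflow/core/util/mkl_types.h", so the same kernel source builds against both the old MKL-DNN 0.x API and the new API selected by ENABLE_MKLDNN_V1. The macro definitions themselves are not part of this commit; the snippet below is only a rough sketch of the idea, with names and expansions assumed rather than copied from the real header.

#ifdef ENABLE_MKLDNN_V1
// New API: engine kinds are scoped enums, streams are bound to an engine,
// and memory::format was renamed to memory::format_tag.
#define ENGINE_CPU engine::kind::cpu
#define CPU_STREAM(engine) stream(engine)
#define MEMORY_FORMAT memory::format_tag
#define ALGORITHM mkldnn::algorithm
#else
// Old API: the engine kind lives directly on engine, streams are eager and
// engine-independent, and algorithms sit at namespace scope.
#define ENGINE_CPU engine::cpu
#define CPU_STREAM(engine) stream(stream::kind::eager)
#define MEMORY_FORMAT memory::format
#define ALGORITHM mkldnn
#endif  // ENABLE_MKLDNN_V1

With a shim of this shape, each kernel constructor can build cpu_engine_(ENGINE_CPU, 0) and new CPU_STREAM(cpu_engine_) once, and Compute() stays mostly branch-free — the pattern repeated in all three files below.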


@ -16,6 +16,7 @@ limitations under the License.
// See docs in ../ops/array_ops.cc.
#ifdef INTEL_MKL
#include "mkldnn.hpp"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
@ -23,8 +24,6 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "mkldnn.hpp"
#include "tensorflow/core/util/mkl_util.h"
namespace tensorflow {
@ -64,4 +63,5 @@ TF_CALL_float(REGISTER_MKL_CPU);
TF_CALL_bfloat16(REGISTER_MKL_CPU);
#undef REGISTER_MKL_CPU
} // namespace tensorflow
#endif // INTEL_MKL


@ -21,24 +21,26 @@ limitations under the License.
#ifdef INTEL_MKL
#define EIGEN_USE_THREADS
#include <unordered_map>
#include <vector>
#include "mkldnn.hpp"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/mkl_types.h"
#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/util/tensor_format.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#if !defined(IS_MOBILE_PLATFORM)
#include "tensorflow/core/util/work_sharder.h"
#endif
using mkldnn::lrn_across_channels;
using mkldnn::lrn_backward;
using mkldnn::lrn_forward;
using mkldnn::prop_kind;
@ -69,14 +71,14 @@ class MklLRNOp : public OpKernel {
public:
~MklLRNOp() {}
explicit MklLRNOp(OpKernelConstruction* context) : OpKernel(context) {
explicit MklLRNOp(OpKernelConstruction* context)
: OpKernel(context), cpu_engine_(ENGINE_CPU, 0) {
int64 depth_radius64;
OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius64));
OP_REQUIRES(
context,
FastBoundsCheck(depth_radius64, std::numeric_limits<int>::max()),
errors::InvalidArgument("depth_radius = ", depth_radius64,
" larger than int max"));
OP_REQUIRES(context, FastBoundsCheck(depth_radius64,
std::numeric_limits<int>::max()),
errors::InvalidArgument("depth_radius = ", depth_radius64,
" larger than int max"));
depth_radius_ = static_cast<size_t>(depth_radius64);
OP_REQUIRES_OK(context, context->GetAttr("bias", &bias_));
@ -85,6 +87,7 @@ class MklLRNOp : public OpKernel {
workspace_enabled_ = false;
OP_REQUIRES_OK(context,
context->GetAttr("workspace_enabled", &workspace_enabled_));
fwd_stream_.reset(new CPU_STREAM(cpu_engine_));
}
void Compute(OpKernelContext* context) override {
@ -92,7 +95,6 @@ class MklLRNOp : public OpKernel {
SanityCheckInputs(context);
if (!context->status().ok()) return;
auto cpu_engine = engine(engine::cpu, 0);
const Tensor& src_tensor = MklGetInput(context, kIdxInput);
MklDnnShape src_dnn_shape;
GetMklShape(context, kIdxInput, &src_dnn_shape);
@ -120,9 +122,9 @@ class MklLRNOp : public OpKernel {
// and we can enable the workspace
workspace_enabled_ = true;
MklDnnData<T> src_dnn_data(&cpu_engine);
MklDnnData<T> dst_dnn_data(&cpu_engine);
MklDnnData<uint8> workspace_dnn_data(&cpu_engine);
MklDnnData<T> src_dnn_data(&cpu_engine_);
MklDnnData<T> dst_dnn_data(&cpu_engine_);
MklDnnData<uint8> workspace_dnn_data(&cpu_engine_);
TensorShape tf_output_shape = src_tensor.shape();
@ -134,39 +136,57 @@ class MklLRNOp : public OpKernel {
// and MKL-DNN performs normalization over Channel, we tell MKL-DNN
// that input is in NHWC layout with Channel being the last dimension.
src_dnn_data.SetUsrMem(src_md, &src_tensor);
src_dnn_data.SetOpMemDesc(input_dims, memory::format::nhwc);
src_dnn_data.SetOpMemDesc(input_dims, MEMORY_FORMAT::nhwc);
// output_dnn_data and workspace both have the same shape as input
// dst_dnn_data has the same shape as input.
dst_dnn_data.SetUsrMem(src_md);
dst_dnn_data.SetOpMemDesc(input_dims, memory::format::nhwc);
dst_dnn_data.SetOpMemDesc(input_dims, MEMORY_FORMAT::nhwc);
// Create LRN primitive descriptor.
// Tensorflow's normalization semantics is across channels.
// MKL-DNN also supports normalization within channel.
auto lrn_desc = lrn_forward::desc(prop_kind::forward, lrn_across_channels,
src_dnn_data.GetUsrMemDesc(),
kernel_size, new_alpha, beta_, bias_);
auto lrn_prim_desc = lrn_forward::primitive_desc(lrn_desc, cpu_engine);
auto lrn_desc = lrn_forward::desc(
prop_kind::forward, ALGORITHM::lrn_across_channels,
src_dnn_data.GetUsrMemDesc(), kernel_size, new_alpha, beta_, bias_);
auto lrn_prim_desc = lrn_forward::primitive_desc(lrn_desc, cpu_engine_);
// Allocate output_dnn_data tensor.
Tensor* output_tensor = nullptr;
memory::format input_format = src_dnn_shape.GetTfDataFormat();
auto input_format = src_dnn_shape.GetTfDataFormat();
AllocateOutputTensor(context, lrn_prim_desc, input_dims, input_format,
&output_tensor);
OP_REQUIRES_OK(context, context->status());
CHECK_NOTNULL(output_tensor);
DCHECK(output_tensor != nullptr);
dst_dnn_data.SetUsrMemDataHandle(output_tensor);
// Handle workspace required for MKL-DNN.
AllocateWorkspaceTensor(context, lrn_prim_desc, &workspace_dnn_data);
OP_REQUIRES_OK(context, context->status());
PrepareAndExecuteNet(lrn_prim_desc, &src_dnn_data, &dst_dnn_data,
&workspace_dnn_data);
// Check for input reorder
src_dnn_data.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
lrn_prim_desc.PRIMITIVE_DESC_SRC, cpu_engine_));
std::vector<primitive> net;
#ifdef ENABLE_MKLDNN_V1
net.push_back(lrn_forward(lrn_prim_desc));
std::vector<std::unordered_map<int, memory>> net_args;
net_args.push_back({{MKLDNN_ARG_SRC, src_dnn_data.GetOpMem()},
{MKLDNN_ARG_WORKSPACE, workspace_dnn_data.GetOpMem()},
{ MKLDNN_ARG_DST,
dst_dnn_data.GetOpMem() }});
net.push_back(lrn_forward(lrn_prim_desc));
net.at(0).execute(*fwd_stream_, net_args.at(0));
#else
net.push_back(lrn_forward(lrn_prim_desc, src_dnn_data.GetOpMem(),
workspace_dnn_data.GetOpMem(),
dst_dnn_data.GetOpMem()));
fwd_stream_->submit(net).wait();
#endif
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +
string(__FILE__) + ":" + std::to_string(__LINE__);
string error_msg = "Status: " + std::to_string(e.status) + ", message: " +
string(e.message) + ", in file " + string(__FILE__) +
":" + std::to_string(__LINE__);
OP_REQUIRES_OK(
context,
errors::Aborted("Operation received an exception:", error_msg));
@ -174,33 +194,13 @@ class MklLRNOp : public OpKernel {
}
private:
void PrepareAndExecuteNet(const lrn_forward::primitive_desc& lrn_fwd_desc,
MklDnnData<T>* src_dnn_data,
MklDnnData<T>* dst_dnn_data,
MklDnnData<uint8>* wksp_dnn_data = nullptr) {
// Check for input reorder
src_dnn_data->CheckReorderToOpMem(lrn_fwd_desc.src_primitive_desc());
// Create pooling primitive and add it to net
std::vector<primitive> net;
if (wksp_dnn_data != nullptr) {
net.push_back(lrn_forward(lrn_fwd_desc, src_dnn_data->GetOpMem(),
wksp_dnn_data->GetOpMem(),
dst_dnn_data->GetOpMem()));
} else {
net.push_back(lrn_forward(lrn_fwd_desc, src_dnn_data->GetOpMem(),
dst_dnn_data->GetOpMem()));
}
stream(stream::kind::eager).submit(net).wait();
}
void AllocateOutputTensor(
OpKernelContext* context,
const lrn_forward::primitive_desc& lrn_fwd_prim_desc,
const memory::dims output_dims_mkl_order,
const memory::format& output_tf_format, Tensor** output_tensor) {
CHECK_NOTNULL(output_tensor);
memory::primitive_desc dst_pd = lrn_fwd_prim_desc.dst_primitive_desc();
const MKL_TENSOR_FORMAT& output_tf_format, Tensor** output_tensor) {
DCHECK(output_tensor != nullptr);
MEMORY_PRIMITIVE_DESC dst_pd = lrn_fwd_prim_desc.PRIMITIVE_DESC_DST;
MklDnnShape output_mkl_shape;
// We only handle the case when the inputs and output are in Mkl format
@ -231,8 +231,7 @@ class MklLRNOp : public OpKernel {
auto in_shaped = input.shaped<T, 2>({nodes * batch, depth});
// Multiplying the input with the band matrix has the effect of reducing
// the
// correct patch along the depth.
// the correct patch along the depth.
Eigen::Tensor<T, 2, Eigen::RowMajor> multiplier(depth, depth);
GetBandMatrix<T>(depth, depth_radius_, &multiplier);
@ -242,7 +241,7 @@ class MklLRNOp : public OpKernel {
mkl_output_mkl_shape.SetDimensions(4);
AllocateOutputSetMklShape(context, kIdxOutput, &output_dnn_data,
input.shape(), mkl_output_mkl_shape);
CHECK_NOTNULL(output_dnn_data);
DCHECK(output_dnn_data != nullptr);
Tensor* workspace_tensor = nullptr;
MklDnnShape workspace_mkl_shape;
@ -251,7 +250,7 @@ class MklLRNOp : public OpKernel {
workspace_tf_shape.AddDim(0);
AllocateOutputSetMklShape(context, kIdxWorkspace, &workspace_tensor,
workspace_tf_shape, workspace_mkl_shape);
CHECK_NOTNULL(workspace_tensor);
DCHECK(workspace_tensor);
auto out_shaped = output_dnn_data->shaped<T, 2>({nodes * batch, depth});
Eigen::array<DimPair, 1> dims = {{DimPair(1, 0)}};
@ -271,10 +270,10 @@ class MklLRNOp : public OpKernel {
OpKernelContext* context,
const lrn_forward::primitive_desc& lrn_fwd_prim_desc,
MklDnnData<uint8>* dnn_data_wksp) {
CHECK_NOTNULL(dnn_data_wksp);
DCHECK(dnn_data_wksp != nullptr);
Tensor* workspace_tensor = nullptr;
memory::primitive_desc workspace_pd =
lrn_fwd_prim_desc.workspace_primitive_desc();
MEMORY_PRIMITIVE_DESC workspace_pd =
lrn_fwd_prim_desc.PRIMITIVE_DESC_WORKSPACE;
size_t workspace_bytes = workspace_pd.get_size();
MklDnnShape workspace_mkl_shape;
// the workspace tensor is a uint8 tensor that has
@ -284,7 +283,7 @@ class MklLRNOp : public OpKernel {
workspace_tf_shape.AddDim(workspace_bytes);
AllocateOutputSetMklShape(context, kIdxWorkspace, &workspace_tensor,
workspace_tf_shape, workspace_mkl_shape);
CHECK_NOTNULL(workspace_tensor);
DCHECK(workspace_tensor != nullptr);
dnn_data_wksp->SetUsrMem(workspace_pd, workspace_tensor);
}
@ -295,16 +294,14 @@ class MklLRNOp : public OpKernel {
if (src_dnn_shape.IsMklTensor()) {
OP_REQUIRES(context, src_dnn_shape.GetDimension() == 4,
errors::InvalidArgument("input must be 4-dimensional"));
OP_REQUIRES(context,
FastBoundsCheck(src_tensor.NumElements(),
std::numeric_limits<int>::max()),
OP_REQUIRES(context, FastBoundsCheck(src_tensor.NumElements(),
std::numeric_limits<int>::max()),
errors::InvalidArgument("argument to LRN too large"));
} else {
OP_REQUIRES(context, src_tensor.dims() == 4,
errors::InvalidArgument("input must be 4-dimensional"));
OP_REQUIRES(context,
FastBoundsCheck(src_tensor.NumElements(),
std::numeric_limits<int>::max()),
OP_REQUIRES(context, FastBoundsCheck(src_tensor.NumElements(),
std::numeric_limits<int>::max()),
errors::InvalidArgument("argument to LRN too large"));
}
}
@ -316,19 +313,21 @@ class MklLRNOp : public OpKernel {
float bias_;
float alpha_;
float beta_;
engine cpu_engine_;
std::shared_ptr<stream> fwd_stream_;
};
template <typename T>
class MklLRNGradOp : public OpKernel {
public:
explicit MklLRNGradOp(OpKernelConstruction* context) : OpKernel(context) {
explicit MklLRNGradOp(OpKernelConstruction* context)
: OpKernel(context), cpu_engine_(ENGINE_CPU, 0) {
int64 depth_radius64;
OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius64));
OP_REQUIRES(
context,
FastBoundsCheck(depth_radius64, std::numeric_limits<int>::max()),
errors::InvalidArgument("depth_radius = ", depth_radius64,
" larger than int max"));
OP_REQUIRES(context, FastBoundsCheck(depth_radius64,
std::numeric_limits<int>::max()),
errors::InvalidArgument("depth_radius = ", depth_radius64,
" larger than int max"));
depth_radius_ = static_cast<int>(depth_radius64);
OP_REQUIRES_OK(context, context->GetAttr("bias", &bias_));
OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_));
@ -336,6 +335,7 @@ class MklLRNGradOp : public OpKernel {
workspace_enabled_ = false;
OP_REQUIRES_OK(context,
context->GetAttr("workspace_enabled", &workspace_enabled_));
bwd_stream_.reset(new CPU_STREAM(cpu_engine_));
}
void Compute(OpKernelContext* context) override {
@ -343,11 +343,10 @@ class MklLRNGradOp : public OpKernel {
SanityCheckInputs(context);
if (!context->status().ok()) return;
auto cpu_engine = engine(engine::cpu, 0);
MklDnnData<T> input_grad_dnn_data(&cpu_engine);
MklDnnData<T> orig_input_dnn_data(&cpu_engine);
MklDnnData<T> orig_output_dnn_data(&cpu_engine);
MklDnnData<T> output_dnn_data(&cpu_engine);
MklDnnData<T> input_grad_dnn_data(&cpu_engine_);
MklDnnData<T> orig_input_dnn_data(&cpu_engine_);
MklDnnData<T> orig_output_dnn_data(&cpu_engine_);
MklDnnData<T> output_dnn_data(&cpu_engine_);
MklDnnShape input_grad_dnn_shape, orig_input_dnn_shape,
orig_output_dnn_shape;
@ -389,11 +388,11 @@ class MklLRNGradOp : public OpKernel {
memory::dims orig_input_dims =
orig_input_dnn_shape.GetSizesAsMklDnnDims();
orig_input_dnn_data.SetUsrMem(orig_input_md, &orig_input_tensor);
orig_input_dnn_data.SetOpMemDesc(orig_input_dims, memory::format::nhwc);
orig_input_dnn_data.SetOpMemDesc(orig_input_dims, MEMORY_FORMAT::nhwc);
// output_dnn_data has the same shape as original input
output_dnn_data.SetUsrMem(orig_input_md);
output_dnn_data.SetOpMemDesc(orig_input_dims, memory::format::nhwc);
output_dnn_data.SetOpMemDesc(orig_input_dims, MEMORY_FORMAT::nhwc);
// MKL-DNN has a notion of kernel_size and not depth_radius.
int kernel_size = 2 * depth_radius_ + 1;
@ -402,42 +401,61 @@ class MklLRNGradOp : public OpKernel {
// Create LRN backward primitive descriptor. It requires LRN forward
// primitive descriptor also.
auto lrn_fwd_desc = lrn_forward::desc(
prop_kind::forward, lrn_across_channels, orig_input_md, kernel_size,
new_alpha, beta_, bias_);
auto lrn_fwd_prim_desc =
lrn_forward::primitive_desc(lrn_fwd_desc, cpu_engine);
auto lrn_bwd_desc = lrn_backward::desc(
lrn_across_channels, original_output_md, target_diff_dst_md,
prop_kind::forward, ALGORITHM::lrn_across_channels, orig_input_md,
kernel_size, new_alpha, beta_, bias_);
auto lrn_fwd_prim_desc =
lrn_forward::primitive_desc(lrn_fwd_desc, cpu_engine_);
auto lrn_bwd_desc = lrn_backward::desc(
ALGORITHM::lrn_across_channels, original_output_md,
target_diff_dst_md, kernel_size, new_alpha, beta_, bias_);
auto lrn_bwd_prim_desc = lrn_backward::primitive_desc(
lrn_bwd_desc, cpu_engine, lrn_fwd_prim_desc);
lrn_bwd_desc, cpu_engine_, lrn_fwd_prim_desc);
Tensor* output_tensor = nullptr;
memory::format orig_input_format = orig_input_dnn_shape.GetTfDataFormat();
auto orig_input_format = orig_input_dnn_shape.GetTfDataFormat();
AllocateOutputTensor(context, lrn_bwd_prim_desc, orig_input_dims,
orig_input_format, &output_tensor);
OP_REQUIRES_OK(context, context->status());
CHECK_NOTNULL(output_tensor);
DCHECK(output_tensor != nullptr);
output_dnn_data.SetUsrMemDataHandle(output_tensor);
// Create LRN primitive and add it to the net
// At this point, workspace is enabled, so we don't need
// to check. Pass input workspace to LRN backward primitive.
const Tensor& workspace_tensor = MklGetInput(context, kIdxWorkspace);
MklDnnData<uint8> workspace_dnn_data(&cpu_engine);
MklDnnData<uint8> workspace_dnn_data(&cpu_engine_);
ConfigureWorkspace(workspace_tensor,
lrn_fwd_prim_desc.workspace_primitive_desc(),
lrn_fwd_prim_desc.PRIMITIVE_DESC_WORKSPACE,
&workspace_dnn_data);
PrepareAndExecuteNet(
lrn_bwd_prim_desc, lrn_fwd_prim_desc, &orig_input_dnn_data,
&input_grad_dnn_data, &output_dnn_data,
memory::primitive_desc(target_diff_dst_md, cpu_engine),
&workspace_dnn_data);
// Check for input reordering on the diff dst input
input_grad_dnn_data.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
lrn_bwd_prim_desc.PRIMITIVE_DESC_DIFF_DST, cpu_engine_));
// Check for input reordering on the original input
orig_input_dnn_data.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
lrn_fwd_prim_desc.PRIMITIVE_DESC_SRC, cpu_engine_));
std::vector<primitive> net;
#ifdef ENABLE_MKLDNN_V1
std::vector<std::unordered_map<int, memory>> net_args;
net.push_back(lrn_backward(lrn_bwd_prim_desc));
net_args.push_back({{MKLDNN_ARG_SRC, orig_input_dnn_data.GetOpMem()},
{MKLDNN_ARG_DIFF_DST, input_grad_dnn_data.GetOpMem()},
{ MKLDNN_ARG_DST,
output_dnn_data.GetOpMem() }});
net.push_back(lrn_backward(lrn_bwd_prim_desc));
net.at(0).execute(*bwd_stream_, net_args.at(0));
#else
net.push_back(lrn_backward(
lrn_bwd_prim_desc, orig_input_dnn_data.GetOpMem(),
input_grad_dnn_data.GetOpMem(), output_dnn_data.GetOpMem()));
bwd_stream_->submit(net).wait();
#endif
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +
string(__FILE__) + ":" + std::to_string(__LINE__);
string error_msg = "Status: " + std::to_string(e.status) + ", message: " +
string(e.message) + ", in file " + string(__FILE__) +
":" + std::to_string(__LINE__);
OP_REQUIRES_OK(
context,
errors::Aborted("Operation received an exception:", error_msg));
@ -448,10 +466,9 @@ class MklLRNGradOp : public OpKernel {
OpKernelContext* context,
const lrn_backward::primitive_desc& lrn_bkwd_prim_desc,
const memory::dims output_dims_mkl_order,
const memory::format& output_tf_format, Tensor** output_tensor) {
CHECK_NOTNULL(output_tensor);
memory::primitive_desc dst_pd =
lrn_bkwd_prim_desc.diff_src_primitive_desc();
const MKL_TENSOR_FORMAT& output_tf_format, Tensor** output_tensor) {
DCHECK(output_tensor != nullptr);
MEMORY_PRIMITIVE_DESC dst_pd = lrn_bkwd_prim_desc.PRIMITIVE_DESC_DIFF_SRC;
MklDnnShape output_mkl_shape;
// We assume that all outputs at this point are MKL Tensors
@ -472,56 +489,28 @@ class MklLRNGradOp : public OpKernel {
memory::desc ConfigureInputGradient(const Tensor& input_grad_tensor,
const MklDnnShape& input_grad_dnn_shape,
MklDnnData<T>* input_grad_dnn_data) {
CHECK_NOTNULL(input_grad_dnn_data);
DCHECK(input_grad_dnn_data != nullptr);
// This shouldn't be necessary at this point, but just in case
CHECK_EQ(input_grad_dnn_shape.IsMklTensor(), true);
DCHECK(input_grad_dnn_shape.IsMklTensor() == true);
memory::desc input_grad_md = input_grad_dnn_shape.GetCurLayout();
memory::dims orig_input_dims = input_grad_dnn_shape.GetSizesAsMklDnnDims();
input_grad_dnn_data->SetUsrMem(input_grad_md, &input_grad_tensor);
input_grad_dnn_data->SetOpMemDesc(orig_input_dims, memory::format::nhwc);
input_grad_dnn_data->SetOpMemDesc(orig_input_dims, MEMORY_FORMAT::nhwc);
return input_grad_md;
}
void PrepareAndExecuteNet(
const lrn_backward::primitive_desc& lrn_bkwd_desc,
const lrn_forward::primitive_desc& lrn_fwd_desc,
MklDnnData<T>* src_dnn_data, MklDnnData<T>* input_gradient_diff_dst,
MklDnnData<T>* output_diff_src,
const memory::primitive_desc& target_diff_dst_pd,
const MklDnnData<uint8>* workspace_dnn_data = nullptr) {
// Check for input reordering on the diff dst input
input_gradient_diff_dst->CheckReorderToOpMem(
lrn_bkwd_desc.diff_dst_primitive_desc());
// Check for input reordering on the original input
src_dnn_data->CheckReorderToOpMem(lrn_fwd_desc.src_primitive_desc());
// Create pooling primitive and add it to net
std::vector<primitive> net;
if (nullptr == workspace_dnn_data) {
net.push_back(lrn_backward(lrn_bkwd_desc, src_dnn_data->GetOpMem(),
input_gradient_diff_dst->GetOpMem(),
output_diff_src->GetOpMem()));
} else {
net.push_back(lrn_backward(lrn_bkwd_desc, src_dnn_data->GetOpMem(),
input_gradient_diff_dst->GetOpMem(),
workspace_dnn_data->GetOpMem(),
output_diff_src->GetOpMem()));
}
stream(stream::kind::eager).submit(net).wait();
}
void ConfigureWorkspace(const Tensor& workspace_tensor,
memory::primitive_desc workspace_pd,
MEMORY_PRIMITIVE_DESC workspace_pd,
MklDnnData<uint8>* workspace_dnn_data) {
CHECK_NOTNULL(workspace_dnn_data);
DCHECK(workspace_dnn_data);
workspace_dnn_data->SetUsrMem(workspace_pd, &workspace_tensor);
}
// Fallback implementation - Taken from lrn_op.cc
// TODO(intelft) Check if we can use EigenLRNOp directly instead of making a
// copy.
// TODO(intel-tf) Check if we can use EigenLRNOp directly
// instead of making a copy.
void MklDefaultToEigen(OpKernelContext* context) {
Tensor input_gradient_tensor;
Tensor orig_input_tensor;
@ -676,6 +665,8 @@ class MklLRNGradOp : public OpKernel {
float bias_;
float alpha_;
float beta_;
engine cpu_engine_;
std::shared_ptr<stream> bwd_stream_;
};
#define REGISTER_MKL_LRN_CPU(T) \

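The LRN rewrite above folds away PrepareAndExecuteNet(): under ENABLE_MKLDNN_V1 the lrn_forward/lrn_backward primitives are constructed from the primitive descriptor alone and receive their memories as an argument map at execution time, instead of being handed memories in the constructor and submitted to an eager stream. A minimal standalone sketch of that v1.x pattern — not part of this commit, with shapes and LRN hyperparameters made up for illustration:

#include <mkldnn.hpp>
using namespace mkldnn;

int main() {
  engine eng(engine::kind::cpu, 0);
  stream s(eng);

  // Logical dims in NCHW order (N=1, C=4, H=8, W=8); physical layout NHWC,
  // matching how the kernel describes TF tensors to MKL-DNN.
  memory::desc src_md({1, 4, 8, 8}, memory::data_type::f32, memory::format_tag::nhwc);

  auto lrn_d = lrn_forward::desc(prop_kind::forward_training,
                                 algorithm::lrn_across_channels, src_md,
                                 /*local_size=*/5, /*alpha=*/1e-4f,
                                 /*beta=*/0.75f, /*k=*/1.0f);
  auto lrn_pd = lrn_forward::primitive_desc(lrn_d, eng);

  memory src(src_md, eng), dst(lrn_pd.dst_desc(), eng);
  memory ws(lrn_pd.workspace_desc(), eng);  // forward_training exposes a workspace

  // v1.x style: primitive built from the pd, memories passed as an arg map.
  lrn_forward(lrn_pd).execute(
      s, {{MKLDNN_ARG_SRC, src}, {MKLDNN_ARG_DST, dst}, {MKLDNN_ARG_WORKSPACE, ws}});
  s.wait();
  return 0;
}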

@ -14,17 +14,19 @@ limitations under the License.
==============================================================================*/
// See docs in ../ops/nn_ops.cc.
#ifdef INTEL_MKL
#include "mkldnn.hpp"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/util/mkl_types.h"
#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/util/tensor_format.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
using mkldnn::prop_kind;
using mkldnn::softmax_forward;
@ -35,10 +37,10 @@ namespace tensorflow {
class MklSoftmaxParams {
public:
memory::dims src_dims;
memory::format src_fmt;
MKL_TENSOR_FORMAT src_fmt;
int axis;
MklSoftmaxParams(memory::dims src_dims, memory::format src_fmt, int axis)
MklSoftmaxParams(memory::dims src_dims, MKL_TENSOR_FORMAT src_fmt, int axis)
: src_dims(src_dims), src_fmt(src_fmt), axis(axis) {}
};
@ -46,8 +48,8 @@ template <typename T>
class MklSoftmaxPrimitive : public MklPrimitive {
public:
explicit MklSoftmaxPrimitive(const MklSoftmaxParams& fwdParams)
: cpu_engine_(engine::cpu, 0) {
context_.fwd_stream.reset(new stream(stream::kind::eager));
: cpu_engine_(ENGINE_CPU, 0) {
context_.fwd_stream.reset(new CPU_STREAM(cpu_engine_));
Setup(fwdParams);
}
@ -61,9 +63,18 @@ class MklSoftmaxPrimitive : public MklPrimitive {
static_cast<void*>(const_cast<T*>(src_data)));
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
#ifdef ENABLE_MKLDNN_V1
DCHECK_EQ(context_.fwd_primitives.size(),
context_.fwd_net_args.size());
for (size_t i = 0; i < context_.fwd_primitives.size(); ++i) {
context_.fwd_primitives.at(i).execute(*context_.fwd_stream,
context_.fwd_net_args.at(i));
}
#else
context_.fwd_stream->submit(context_.fwd_primitives);
#endif
// After execution, set data handle back
// After execution, set data handle back.
context_.src_mem->set_data_handle(DummyData);
context_.dst_mem->set_data_handle(DummyData);
}
@ -74,22 +85,23 @@ class MklSoftmaxPrimitive : public MklPrimitive {
private:
struct SoftmaxFwdContext {
// MKL-DNN memory
// MKL-DNN memory.
std::shared_ptr<memory> src_mem;
std::shared_ptr<memory> dst_mem;
// Primitive desc
// Primitive descriptor.
std::shared_ptr<mkldnn::softmax_forward::desc> fwd_desc;
// Memory desc
// Memory descriptor.
std::shared_ptr<memory::desc> src_md;
// Softmax primitive
// Softmax primitive.
std::shared_ptr<mkldnn::softmax_forward::primitive_desc> fwd_pd;
std::shared_ptr<mkldnn::primitive> softmax_fwd;
std::shared_ptr<stream> fwd_stream;
std::vector<mkldnn::primitive> fwd_primitives;
std::vector<MemoryArgsMap> fwd_net_args;
SoftmaxFwdContext()
: src_mem(nullptr),
@ -103,25 +115,33 @@ class MklSoftmaxPrimitive : public MklPrimitive {
// Softmax forward primitive setup
void Setup(const MklSoftmaxParams& fwdParams) {
// Create memory descriptors for softmax data with specified format
context_.src_md.reset(new memory::desc({fwdParams.src_dims},
MklDnnType<T>(), fwdParams.src_fmt));
// Create memory descriptors for softmax data with specified format.
auto src_format = GET_TENSOR_FORMAT(fwdParams.src_fmt);
context_.src_md.reset(
new memory::desc({fwdParams.src_dims}, MklDnnType<T>(), src_format));
// Create a softmax
// Create softmax descriptor and primitive descriptor.
context_.fwd_desc.reset(new mkldnn::softmax_forward::desc(
prop_kind::forward_scoring, *context_.src_md, fwdParams.axis));
context_.fwd_pd.reset(new mkldnn::softmax_forward::primitive_desc(
*context_.fwd_desc, cpu_engine_));
// Create memory primitive based on dummy data
context_.src_mem.reset(
new memory({*context_.src_md, cpu_engine_}, DummyData));
context_.dst_mem.reset(
new memory(context_.fwd_pd.get()->dst_primitive_desc(), DummyData));
// Create memory primitive based on dummy data.
context_.src_mem.reset(new MEMORY_CONSTRUCTOR_USING_MD(
*context_.src_md, cpu_engine_, DummyData));
context_.dst_mem.reset(new MEMORY_CONSTRUCTOR_PD(
context_.fwd_pd.get()->PRIMITIVE_DESC_DST, cpu_engine_, DummyData));
#ifdef ENABLE_MKLDNN_V1
// Create softmax primitive and add it to net
context_.softmax_fwd.reset(new mkldnn::softmax_forward(*context_.fwd_pd));
context_.fwd_net_args.push_back({{MKLDNN_ARG_SRC, *context_.src_mem},
{ MKLDNN_ARG_DST,
*context_.dst_mem }});
#else
context_.softmax_fwd.reset(new mkldnn::softmax_forward(
*context_.fwd_pd, *context_.src_mem, *context_.dst_mem));
#endif // ENABLE_MKLDNN_V1
context_.fwd_primitives.push_back(*context_.softmax_fwd);
}
@ -134,7 +154,7 @@ template <typename T>
class MklSoftmaxPrimitiveFactory : public MklPrimitiveFactory<T> {
public:
static MklSoftmaxPrimitive<T>* Get(const MklSoftmaxParams& fwdParams) {
// Get a softmax fwd primitive from the cached pool
// Get a softmax fwd primitive from the cached pool.
MklSoftmaxPrimitive<T>* softmax_forward =
static_cast<MklSoftmaxPrimitive<T>*>(
MklSoftmaxPrimitiveFactory<T>::GetInstance().GetSoftmaxFwd(
@ -189,15 +209,15 @@ class MklSoftmaxOp : public OpKernel {
void Compute(OpKernelContext* context) override {
try {
// src_tensor now points to the 0-th input of global data struct "context"
auto cpu_engine = engine(ENGINE_CPU, 0);
// src_tensor points to the 0-th input of global data struct "context".
size_t src_idx = 0;
const Tensor& src_tensor = MklGetInput(context, src_idx);
// Add: get MklShape
MklDnnShape src_mkl_shape;
GetMklShape(context, src_idx, &src_mkl_shape);
// src_dims is the dimension of src_tensor
// dim of the dst will also be same as src_dims
// src_dims is the dimension of src_tensor.
// Dim of the dst will also be same as src_dims.
auto src_tf_shape = src_mkl_shape.IsMklTensor()
? src_mkl_shape.GetTfShape()
: src_tensor.shape();
@ -211,7 +231,7 @@ class MklSoftmaxOp : public OpKernel {
src_dims = TFShapeToMklDnnDims(src_tf_shape);
axis = input_dims - 1;
}
memory::format layout_type;
MKL_TENSOR_FORMAT layout_type;
// In MKL, data format passed to mkl softmax op depends on dimension of
// the input tensor. Here "x" data format in MKL is used for 1 dim tensor,
// "nc" for 2 dim tensor, "tnc" for 3 dim tensor, "nchw" for 4 dim tensor,
@ -223,26 +243,26 @@ class MklSoftmaxOp : public OpKernel {
// dimension to do softmax.
switch (input_dims) {
case 1:
layout_type = memory::format::x;
layout_type = MKL_TENSOR_FORMAT_X;
break;
case 2:
layout_type = memory::format::nc;
layout_type = MKL_TENSOR_FORMAT_NC;
break;
case 3:
layout_type = memory::format::tnc;
layout_type = MKL_TENSOR_FORMAT_TNC;
break;
case 4:
if (src_mkl_shape.IsMklTensor()) {
layout_type = memory::format::nhwc;
layout_type = MKL_TENSOR_FORMAT_NHWC;
} else {
layout_type = memory::format::nchw;
layout_type = MKL_TENSOR_FORMAT_NCHW;
}
break;
case 5:
if (src_mkl_shape.IsMklTensor()) {
layout_type = memory::format::ndhwc;
layout_type = MKL_TENSOR_FORMAT_NDHWC;
} else {
layout_type = memory::format::ncdhw;
layout_type = MKL_TENSOR_FORMAT_NCDHW;
}
break;
default:
@ -254,21 +274,20 @@ class MklSoftmaxOp : public OpKernel {
// If input is in MKL layout, then simply get the format from input;
// otherwise, use TF layout defined before.
auto src_fmt = src_mkl_shape.IsMklTensor()
? static_cast<mkldnn::memory::format>(
src_mkl_shape.GetMklLayout().data.format)
? GET_FORMAT_FROM_SHAPE(src_mkl_shape)
: layout_type;
// Get a softmax fwd from primitive pool
// Get a softmax fwd primitive from primitive pool.
MklSoftmaxParams fwdParams(src_dims, src_fmt, axis);
MklSoftmaxPrimitive<T>* softmax_fwd =
MklSoftmaxPrimitiveFactory<T>::Get(fwdParams);
// Add output
// Prepare for creating output tensor.
Tensor* output_tensor = nullptr;
MklDnnShape output_mkl_shape;
TensorShape output_tf_shape; // shape of output TF tensor.
auto dst_pd = softmax_fwd->GetSoftmaxFwdPd()->dst_primitive_desc();
auto dst_pd = softmax_fwd->GetSoftmaxFwdPd()->PRIMITIVE_DESC_DST;
// If input is MKL shape, output is also MKL shape.
// If input is TF shape, output is also TF shape.
@ -278,23 +297,23 @@ class MklSoftmaxOp : public OpKernel {
output_mkl_shape.SetElemType(MklDnnType<T>());
output_mkl_shape.SetTfLayout(src_dims.size(), src_dims, layout_type);
output_tf_shape.AddDim((dst_pd.get_size() / sizeof(T)));
} else { // then output is also TF shape
} else {
output_mkl_shape.SetMklTensor(false);
output_tf_shape = MklDnnDimsToTFShape(src_dims);
}
// Allocate output shape (MKL or TF based on the above)
// Allocate output tensor.
AllocateOutputSetMklShape(context, 0, &output_tensor, output_tf_shape,
output_mkl_shape);
const T* src_data = src_tensor.flat<T>().data();
T* dst_data = reinterpret_cast<T*>(output_tensor->flat<T>().data());
// Execute softmax
// Execute softmax primitive.
softmax_fwd->Execute(src_data, dst_data);
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +
string(__FILE__) + ":" + std::to_string(__LINE__);
string error_msg = "Status: " + std::to_string(e.status) + ", message: " +
string(e.message) + ", in file " + string(__FILE__) +
":" + std::to_string(__LINE__);
OP_REQUIRES_OK(
context,
errors::Aborted("Operation received an exception:", error_msg));
@ -311,6 +330,7 @@ class MklSoftmaxOp : public OpKernel {
.TypeConstraint<type>("T") \
.Label(mkl_op_registry::kMklLayoutDependentOpLabel), \
MklSoftmaxOp<CPUDevice, type>);
TF_CALL_float(REGISTER_SOFTMAX_MKL_SUPPORTED_KERNELS_TYPES);
TF_CALL_bfloat16(REGISTER_SOFTMAX_MKL_SUPPORTED_KERNELS_TYPES);
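The softmax kernel follows the same conversion: descriptor creation is unchanged, but memory objects are now built from plain memory::descs via the MEMORY_CONSTRUCTOR_* macros, and the cached primitive is driven through execute() with an argument map instead of stream::submit(). Underneath MklSoftmaxPrimitive, the raw v1.x calls amount to roughly the following standalone sketch (illustrative only, not the kernel's wrapper classes; the 2x10 "nc" case mirrors the 2-D branch of the layout switch above):

#include <mkldnn.hpp>
using namespace mkldnn;

int main() {
  engine eng(engine::kind::cpu, 0);
  stream s(eng);

  // Two rows of ten logits, softmax over axis 1 (the last dimension).
  memory::desc src_md({2, 10}, memory::data_type::f32, memory::format_tag::nc);
  auto sm_d = softmax_forward::desc(prop_kind::forward_scoring, src_md, /*softmax_axis=*/1);
  auto sm_pd = softmax_forward::primitive_desc(sm_d, eng);

  memory src(src_md, eng), dst(sm_pd.dst_desc(), eng);

  softmax_forward(sm_pd).execute(s, {{MKLDNN_ARG_SRC, src}, {MKLDNN_ARG_DST, dst}});
  s.wait();
  return 0;
}

In the kernel these calls stay hidden behind MklSoftmaxPrimitive, and MklSoftmaxPrimitiveFactory caches the primitive keyed on MklSoftmaxParams (dims, format, axis), so repeated Compute() calls with the same shape reuse it.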