diff --git a/tensorflow/core/kernels/mkl_identity_op.cc b/tensorflow/core/kernels/mkl_identity_op.cc
index a2b6617ca61..7f6c255ac88 100644
--- a/tensorflow/core/kernels/mkl_identity_op.cc
+++ b/tensorflow/core/kernels/mkl_identity_op.cc
@@ -16,6 +16,7 @@ limitations under the License.
 // See docs in ../ops/array_ops.cc.
 #ifdef INTEL_MKL
 
+#include "mkldnn.hpp"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
@@ -23,8 +24,6 @@ limitations under the License.
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/logging.h"
-
-#include "mkldnn.hpp"
 #include "tensorflow/core/util/mkl_util.h"
 
 namespace tensorflow {
@@ -64,4 +63,5 @@ TF_CALL_float(REGISTER_MKL_CPU);
 TF_CALL_bfloat16(REGISTER_MKL_CPU);
 #undef REGISTER_MKL_CPU
 }  // namespace tensorflow
+
 #endif  // INTEL_MKL
diff --git a/tensorflow/core/kernels/mkl_lrn_op.cc b/tensorflow/core/kernels/mkl_lrn_op.cc
index 93df6e1ae99..2b7323d12af 100644
--- a/tensorflow/core/kernels/mkl_lrn_op.cc
+++ b/tensorflow/core/kernels/mkl_lrn_op.cc
@@ -21,24 +21,26 @@ limitations under the License.
 #ifdef INTEL_MKL
 
 #define EIGEN_USE_THREADS
+
+#include <unordered_map>
 #include <vector>
 
 #include "mkldnn.hpp"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/bounds_check.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/kernels/ops_util.h"
 #include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/util/mkl_types.h"
 #include "tensorflow/core/util/mkl_util.h"
 #include "tensorflow/core/util/tensor_format.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 
 #if !defined(IS_MOBILE_PLATFORM)
 #include "tensorflow/core/util/work_sharder.h"
 #endif
 
-using mkldnn::lrn_across_channels;
 using mkldnn::lrn_backward;
 using mkldnn::lrn_forward;
 using mkldnn::prop_kind;
@@ -69,14 +71,14 @@ class MklLRNOp : public OpKernel {
  public:
   ~MklLRNOp() {}
 
-  explicit MklLRNOp(OpKernelConstruction* context) : OpKernel(context) {
+  explicit MklLRNOp(OpKernelConstruction* context)
+      : OpKernel(context), cpu_engine_(ENGINE_CPU, 0) {
     int64 depth_radius64;
     OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius64));
-    OP_REQUIRES(
-        context,
-        FastBoundsCheck(depth_radius64, std::numeric_limits<int>::max()),
-        errors::InvalidArgument("depth_radius = ", depth_radius64,
-                                " larger than int max"));
+    OP_REQUIRES(context, FastBoundsCheck(depth_radius64,
+                                         std::numeric_limits<int>::max()),
+                errors::InvalidArgument("depth_radius = ", depth_radius64,
+                                        " larger than int max"));
     depth_radius_ = static_cast<int>(depth_radius64);
 
     OP_REQUIRES_OK(context, context->GetAttr("bias", &bias_));
@@ -85,6 +87,7 @@ class MklLRNOp : public OpKernel {
     workspace_enabled_ = false;
     OP_REQUIRES_OK(context,
                    context->GetAttr("workspace_enabled", &workspace_enabled_));
+    fwd_stream_.reset(new CPU_STREAM(cpu_engine_));
   }
 
   void Compute(OpKernelContext* context) override {
@@ -92,7 +95,6 @@ class MklLRNOp : public OpKernel {
       SanityCheckInputs(context);
       if (!context->status().ok()) return;
 
-      auto cpu_engine = engine(engine::cpu, 0);
       const Tensor& src_tensor = MklGetInput(context, kIdxInput);
       MklDnnShape src_dnn_shape;
       GetMklShape(context, kIdxInput, &src_dnn_shape);
@@ -120,9 +122,9 @@ class MklLRNOp : public OpKernel {
       // and we can enable the workspace
       workspace_enabled_ = true;
 
-      MklDnnData<T> src_dnn_data(&cpu_engine);
-      MklDnnData<T> dst_dnn_data(&cpu_engine);
-      MklDnnData<uint8> workspace_dnn_data(&cpu_engine);
+      MklDnnData<T> src_dnn_data(&cpu_engine_);
+      MklDnnData<T> dst_dnn_data(&cpu_engine_);
+      MklDnnData<uint8> workspace_dnn_data(&cpu_engine_);
 
       TensorShape tf_output_shape = src_tensor.shape();
 
@@ -134,39 +136,57 @@ class MklLRNOp : public OpKernel {
       // and MKL-DNN performs normalization over Channel, we tell MKL-DNN
       // that input is in NHWC layout with Channel being the last dimension.
       src_dnn_data.SetUsrMem(src_md, &src_tensor);
-      src_dnn_data.SetOpMemDesc(input_dims, memory::format::nhwc);
+      src_dnn_data.SetOpMemDesc(input_dims, MEMORY_FORMAT::nhwc);
 
-      // output_dnn_data and workspace both have the same shape as input
+      // dst_dnn_data has the same shape as input.
       dst_dnn_data.SetUsrMem(src_md);
-      dst_dnn_data.SetOpMemDesc(input_dims, memory::format::nhwc);
+      dst_dnn_data.SetOpMemDesc(input_dims, MEMORY_FORMAT::nhwc);
 
       // Create LRN primitive descriptor.
       // Tensorflow's normalization semantics is across channels.
       // MKL-DNN also supports normalization within channel.
-      auto lrn_desc = lrn_forward::desc(prop_kind::forward, lrn_across_channels,
-                                        src_dnn_data.GetUsrMemDesc(),
-                                        kernel_size, new_alpha, beta_, bias_);
-      auto lrn_prim_desc = lrn_forward::primitive_desc(lrn_desc, cpu_engine);
+      auto lrn_desc = lrn_forward::desc(
+          prop_kind::forward, ALGORITHM::lrn_across_channels,
+          src_dnn_data.GetUsrMemDesc(), kernel_size, new_alpha, beta_, bias_);
+      auto lrn_prim_desc = lrn_forward::primitive_desc(lrn_desc, cpu_engine_);
 
       // Allocate output_dnn_data tensor.
       Tensor* output_tensor = nullptr;
-      memory::format input_format = src_dnn_shape.GetTfDataFormat();
+      auto input_format = src_dnn_shape.GetTfDataFormat();
       AllocateOutputTensor(context, lrn_prim_desc, input_dims, input_format,
                            &output_tensor);
       OP_REQUIRES_OK(context, context->status());
-      CHECK_NOTNULL(output_tensor);
+      DCHECK(output_tensor != nullptr);
       dst_dnn_data.SetUsrMemDataHandle(output_tensor);
 
       // Handle workspace required for MKL-DNN.
       AllocateWorkspaceTensor(context, lrn_prim_desc, &workspace_dnn_data);
       OP_REQUIRES_OK(context, context->status());
 
-      PrepareAndExecuteNet(lrn_prim_desc, &src_dnn_data, &dst_dnn_data,
-                           &workspace_dnn_data);
+      // Check for input reorder
+      src_dnn_data.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
+          lrn_prim_desc.PRIMITIVE_DESC_SRC, cpu_engine_));
+
+      std::vector<primitive> net;
+#ifdef ENABLE_MKLDNN_V1
+      net.push_back(lrn_forward(lrn_prim_desc));
+      std::vector<std::unordered_map<int, memory>> net_args;
+      net_args.push_back({{MKLDNN_ARG_SRC, src_dnn_data.GetOpMem()},
+                          {MKLDNN_ARG_WORKSPACE, workspace_dnn_data.GetOpMem()},
+                          {MKLDNN_ARG_DST, dst_dnn_data.GetOpMem()}});
+      net.at(0).execute(*fwd_stream_, net_args.at(0));
+#else
+      net.push_back(lrn_forward(lrn_prim_desc, src_dnn_data.GetOpMem(),
+                                workspace_dnn_data.GetOpMem(),
+                                dst_dnn_data.GetOpMem()));
+      fwd_stream_->submit(net).wait();
+#endif
     } catch (mkldnn::error& e) {
-      string error_msg = "Status: " + std::to_string(e.status) +
-                         ", message: " + string(e.message) + ", in file " +
-                         string(__FILE__) + ":" + std::to_string(__LINE__);
+      string error_msg = "Status: " + std::to_string(e.status) + ", message: " +
+                         string(e.message) + ", in file " + string(__FILE__) +
+                         ":" + std::to_string(__LINE__);
       OP_REQUIRES_OK(
           context,
           errors::Aborted("Operation received an exception:", error_msg));
@@ -174,33 +194,13 @@ class MklLRNOp : public OpKernel {
   }
 
  private:
-  void PrepareAndExecuteNet(const lrn_forward::primitive_desc& lrn_fwd_desc,
-                            MklDnnData<T>* src_dnn_data,
-                            MklDnnData<T>* dst_dnn_data,
-                            MklDnnData<uint8>* wksp_dnn_data = nullptr) {
-    // Check for input reorder
-    src_dnn_data->CheckReorderToOpMem(lrn_fwd_desc.src_primitive_desc());
-
-    // Create pooling primitive and add it to net
-    std::vector<primitive> net;
-    if (wksp_dnn_data != nullptr) {
-      net.push_back(lrn_forward(lrn_fwd_desc, src_dnn_data->GetOpMem(),
-                                wksp_dnn_data->GetOpMem(),
-                                dst_dnn_data->GetOpMem()));
-    } else {
-      net.push_back(lrn_forward(lrn_fwd_desc, src_dnn_data->GetOpMem(),
-                                dst_dnn_data->GetOpMem()));
-    }
-    stream(stream::kind::eager).submit(net).wait();
-  }
-
   void AllocateOutputTensor(
       OpKernelContext* context,
       const lrn_forward::primitive_desc& lrn_fwd_prim_desc,
       const memory::dims output_dims_mkl_order,
-      const memory::format& output_tf_format, Tensor** output_tensor) {
-    CHECK_NOTNULL(output_tensor);
-    memory::primitive_desc dst_pd = lrn_fwd_prim_desc.dst_primitive_desc();
+      const MKL_TENSOR_FORMAT& output_tf_format, Tensor** output_tensor) {
+    DCHECK(output_tensor != nullptr);
+    MEMORY_PRIMITIVE_DESC dst_pd = lrn_fwd_prim_desc.PRIMITIVE_DESC_DST;
     MklDnnShape output_mkl_shape;
 
     // We only handle the case when the inputs and output are in Mkl format
@@ -231,8 +231,7 @@ class MklLRNOp : public OpKernel {
     auto in_shaped = input.shaped<T, 2>({nodes * batch, depth});
 
     // Multiplying the input with the band matrix has the effect of reducing
-    // the
-    // correct patch along the depth.
+    // the correct patch along the depth.
     Eigen::Tensor<T, 2> multiplier(depth, depth);
     GetBandMatrix<T>(depth, depth_radius_, &multiplier);
 
@@ -242,7 +241,7 @@ class MklLRNOp : public OpKernel {
     mkl_output_mkl_shape.SetDimensions(4);
     AllocateOutputSetMklShape(context, kIdxOutput, &output_dnn_data,
                               input.shape(), mkl_output_mkl_shape);
-    CHECK_NOTNULL(output_dnn_data);
+    DCHECK(output_dnn_data != nullptr);
 
     Tensor* workspace_tensor = nullptr;
     MklDnnShape workspace_mkl_shape;
@@ -251,7 +250,7 @@ class MklLRNOp : public OpKernel {
     workspace_tf_shape.AddDim(0);
     AllocateOutputSetMklShape(context, kIdxWorkspace, &workspace_tensor,
                               workspace_tf_shape, workspace_mkl_shape);
-    CHECK_NOTNULL(workspace_tensor);
+    DCHECK(workspace_tensor);
 
     auto out_shaped = output_dnn_data->shaped<T, 2>({nodes * batch, depth});
     Eigen::array<DimPair, 1> dims = {{DimPair(1, 0)}};
@@ -271,10 +270,10 @@ class MklLRNOp : public OpKernel {
       OpKernelContext* context,
       const lrn_forward::primitive_desc& lrn_fwd_prim_desc,
       MklDnnData<uint8>* dnn_data_wksp) {
-    CHECK_NOTNULL(dnn_data_wksp);
+    DCHECK(dnn_data_wksp != nullptr);
     Tensor* workspace_tensor = nullptr;
-    memory::primitive_desc workspace_pd =
-        lrn_fwd_prim_desc.workspace_primitive_desc();
+    MEMORY_PRIMITIVE_DESC workspace_pd =
+        lrn_fwd_prim_desc.PRIMITIVE_DESC_WORKSPACE;
     size_t workspace_bytes = workspace_pd.get_size();
     MklDnnShape workspace_mkl_shape;
     // the workspace tensor is a uint8 tensor that has
@@ -284,7 +283,7 @@ class MklLRNOp : public OpKernel {
     workspace_tf_shape.AddDim(workspace_bytes);
     AllocateOutputSetMklShape(context, kIdxWorkspace, &workspace_tensor,
                               workspace_tf_shape, workspace_mkl_shape);
-    CHECK_NOTNULL(workspace_tensor);
+    DCHECK(workspace_tensor != nullptr);
     dnn_data_wksp->SetUsrMem(workspace_pd, workspace_tensor);
   }
 
@@ -295,16 +294,14 @@ class MklLRNOp : public OpKernel {
     if (src_dnn_shape.IsMklTensor()) {
       OP_REQUIRES(context, src_dnn_shape.GetDimension() == 4,
                   errors::InvalidArgument("input must be 4-dimensional"));
-      OP_REQUIRES(context,
-                  FastBoundsCheck(src_tensor.NumElements(),
-                                  std::numeric_limits<int>::max()),
+      OP_REQUIRES(context, FastBoundsCheck(src_tensor.NumElements(),
+                                           std::numeric_limits<int>::max()),
                   errors::InvalidArgument("argument to LRN too large"));
     } else {
       OP_REQUIRES(context, src_tensor.dims() == 4,
                   errors::InvalidArgument("input must be 4-dimensional"));
-      OP_REQUIRES(context,
-                  FastBoundsCheck(src_tensor.NumElements(),
-                                  std::numeric_limits<int>::max()),
+      OP_REQUIRES(context, FastBoundsCheck(src_tensor.NumElements(),
+                                           std::numeric_limits<int>::max()),
                   errors::InvalidArgument("argument to LRN too large"));
     }
   }
@@ -316,19 +313,21 @@ class MklLRNOp : public OpKernel {
   float bias_;
   float alpha_;
   float beta_;
+  engine cpu_engine_;
+  std::shared_ptr<stream> fwd_stream_;
 };
 
 template <typename T>
 class MklLRNGradOp : public OpKernel {
  public:
-  explicit MklLRNGradOp(OpKernelConstruction* context) : OpKernel(context) {
+  explicit MklLRNGradOp(OpKernelConstruction* context)
+      : OpKernel(context), cpu_engine_(ENGINE_CPU, 0) {
     int64 depth_radius64;
     OP_REQUIRES_OK(context, context->GetAttr("depth_radius", &depth_radius64));
-    OP_REQUIRES(
-        context,
-        FastBoundsCheck(depth_radius64, std::numeric_limits<int>::max()),
-        errors::InvalidArgument("depth_radius = ", depth_radius64,
-                                " larger than int max"));
+    OP_REQUIRES(context, FastBoundsCheck(depth_radius64,
+                                         std::numeric_limits<int>::max()),
+                errors::InvalidArgument("depth_radius = ", depth_radius64,
+                                        " larger than int max"));
     depth_radius_ = static_cast<int>(depth_radius64);
     OP_REQUIRES_OK(context, context->GetAttr("bias", &bias_));
     OP_REQUIRES_OK(context, context->GetAttr("alpha", &alpha_));
@@ -336,6 +335,7 @@ class MklLRNGradOp : public OpKernel {
     workspace_enabled_ = false;
     OP_REQUIRES_OK(context,
                    context->GetAttr("workspace_enabled", &workspace_enabled_));
+    bwd_stream_.reset(new CPU_STREAM(cpu_engine_));
   }
 
   void Compute(OpKernelContext* context) override {
@@ -343,11 +343,10 @@ class MklLRNGradOp : public OpKernel {
       SanityCheckInputs(context);
      if (!context->status().ok()) return;
 
-      auto cpu_engine = engine(engine::cpu, 0);
-      MklDnnData<T> input_grad_dnn_data(&cpu_engine);
-      MklDnnData<T> orig_input_dnn_data(&cpu_engine);
-      MklDnnData<T> orig_output_dnn_data(&cpu_engine);
-      MklDnnData<T> output_dnn_data(&cpu_engine);
+      MklDnnData<T> input_grad_dnn_data(&cpu_engine_);
+      MklDnnData<T> orig_input_dnn_data(&cpu_engine_);
+      MklDnnData<T> orig_output_dnn_data(&cpu_engine_);
+      MklDnnData<T> output_dnn_data(&cpu_engine_);
 
       MklDnnShape input_grad_dnn_shape, orig_input_dnn_shape,
           orig_output_dnn_shape;
@@ -389,11 +388,11 @@ class MklLRNGradOp : public OpKernel {
       memory::dims orig_input_dims =
          orig_input_dnn_shape.GetSizesAsMklDnnDims();
       orig_input_dnn_data.SetUsrMem(orig_input_md, &orig_input_tensor);
-      orig_input_dnn_data.SetOpMemDesc(orig_input_dims, memory::format::nhwc);
+      orig_input_dnn_data.SetOpMemDesc(orig_input_dims, MEMORY_FORMAT::nhwc);
 
       // output_dnn_data has the same shape as original input
       output_dnn_data.SetUsrMem(orig_input_md);
-      output_dnn_data.SetOpMemDesc(orig_input_dims, memory::format::nhwc);
+      output_dnn_data.SetOpMemDesc(orig_input_dims, MEMORY_FORMAT::nhwc);
 
       // MKL-DNN has a notion of kernel_size and not depth_radius.
       int kernel_size = 2 * depth_radius_ + 1;
@@ -402,42 +401,61 @@ class MklLRNGradOp : public OpKernel {
       // Create LRN backward primitive descriptor. It requires LRN forward
       // primitive descriptor also.
       auto lrn_fwd_desc = lrn_forward::desc(
-          prop_kind::forward, lrn_across_channels, orig_input_md, kernel_size,
-          new_alpha, beta_, bias_);
-      auto lrn_fwd_prim_desc =
-          lrn_forward::primitive_desc(lrn_fwd_desc, cpu_engine);
-      auto lrn_bwd_desc = lrn_backward::desc(
-          lrn_across_channels, original_output_md, target_diff_dst_md,
+          prop_kind::forward, ALGORITHM::lrn_across_channels, orig_input_md,
           kernel_size, new_alpha, beta_, bias_);
+      auto lrn_fwd_prim_desc =
+          lrn_forward::primitive_desc(lrn_fwd_desc, cpu_engine_);
+      auto lrn_bwd_desc = lrn_backward::desc(
+          ALGORITHM::lrn_across_channels, original_output_md,
+          target_diff_dst_md, kernel_size, new_alpha, beta_, bias_);
       auto lrn_bwd_prim_desc = lrn_backward::primitive_desc(
-          lrn_bwd_desc, cpu_engine, lrn_fwd_prim_desc);
+          lrn_bwd_desc, cpu_engine_, lrn_fwd_prim_desc);
 
       Tensor* output_tensor = nullptr;
-      memory::format orig_input_format = orig_input_dnn_shape.GetTfDataFormat();
+      auto orig_input_format = orig_input_dnn_shape.GetTfDataFormat();
       AllocateOutputTensor(context, lrn_bwd_prim_desc, orig_input_dims,
                            orig_input_format, &output_tensor);
       OP_REQUIRES_OK(context, context->status());
-      CHECK_NOTNULL(output_tensor);
+      DCHECK(output_tensor != nullptr);
       output_dnn_data.SetUsrMemDataHandle(output_tensor);
 
       // Create LRN primitive and add it to the net
       // At this point, workspace is enabled, so we don't need
       // to check. Pass input workspace to LRN backward primitive.
       const Tensor& workspace_tensor = MklGetInput(context, kIdxWorkspace);
-      MklDnnData<uint8> workspace_dnn_data(&cpu_engine);
+      MklDnnData<uint8> workspace_dnn_data(&cpu_engine_);
       ConfigureWorkspace(workspace_tensor,
-                         lrn_fwd_prim_desc.workspace_primitive_desc(),
+                         lrn_fwd_prim_desc.PRIMITIVE_DESC_WORKSPACE,
                          &workspace_dnn_data);
 
-      PrepareAndExecuteNet(
-          lrn_bwd_prim_desc, lrn_fwd_prim_desc, &orig_input_dnn_data,
-          &input_grad_dnn_data, &output_dnn_data,
-          memory::primitive_desc(target_diff_dst_md, cpu_engine),
-          &workspace_dnn_data);
+      // Check for input reordering on the diff dst input
+      input_grad_dnn_data.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
+          lrn_bwd_prim_desc.PRIMITIVE_DESC_DIFF_DST, cpu_engine_));
+
+      // Check for input reordering on the original input
+      orig_input_dnn_data.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA(
+          lrn_fwd_prim_desc.PRIMITIVE_DESC_SRC, cpu_engine_));
+
+      std::vector<primitive> net;
+#ifdef ENABLE_MKLDNN_V1
+      std::vector<std::unordered_map<int, memory>> net_args;
+      net.push_back(lrn_backward(lrn_bwd_prim_desc));
+      // LRN backward writes diff_src and consumes the forward workspace.
+      net_args.push_back({{MKLDNN_ARG_SRC, orig_input_dnn_data.GetOpMem()},
+                          {MKLDNN_ARG_DIFF_DST, input_grad_dnn_data.GetOpMem()},
+                          {MKLDNN_ARG_WORKSPACE, workspace_dnn_data.GetOpMem()},
+                          {MKLDNN_ARG_DIFF_SRC, output_dnn_data.GetOpMem()}});
+      net.at(0).execute(*bwd_stream_, net_args.at(0));
+#else
+      net.push_back(lrn_backward(lrn_bwd_prim_desc,
+          orig_input_dnn_data.GetOpMem(), input_grad_dnn_data.GetOpMem(),
+          workspace_dnn_data.GetOpMem(), output_dnn_data.GetOpMem()));
+      bwd_stream_->submit(net).wait();
+#endif
     } catch (mkldnn::error& e) {
-      string error_msg = "Status: " + std::to_string(e.status) +
-                         ", message: " + string(e.message) + ", in file " +
-                         string(__FILE__) + ":" + std::to_string(__LINE__);
+      string error_msg = "Status: " + std::to_string(e.status) + ", message: " +
+                         string(e.message) + ", in file " + string(__FILE__) +
+                         ":" + std::to_string(__LINE__);
       OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception:", error_msg));
@@ -448,10 +466,9 @@ class MklLRNGradOp : public OpKernel {
       OpKernelContext* context,
       const lrn_backward::primitive_desc& lrn_bkwd_prim_desc,
       const memory::dims output_dims_mkl_order,
-      const memory::format& output_tf_format, Tensor** output_tensor) {
-    CHECK_NOTNULL(output_tensor);
-    memory::primitive_desc dst_pd =
-        lrn_bkwd_prim_desc.diff_src_primitive_desc();
+      const MKL_TENSOR_FORMAT& output_tf_format, Tensor** output_tensor) {
+    DCHECK(output_tensor != nullptr);
+    MEMORY_PRIMITIVE_DESC dst_pd = lrn_bkwd_prim_desc.PRIMITIVE_DESC_DIFF_SRC;
     MklDnnShape output_mkl_shape;
 
     // We assume that all outputs at this point are MKL Tensors
@@ -472,56 +489,28 @@ class MklLRNGradOp : public OpKernel {
   memory::desc ConfigureInputGradient(const Tensor& input_grad_tensor,
                                       const MklDnnShape& input_grad_dnn_shape,
                                       MklDnnData<T>* input_grad_dnn_data) {
-    CHECK_NOTNULL(input_grad_dnn_data);
+    DCHECK(input_grad_dnn_data != nullptr);
     // This shouldn't be necessary at this point, but just in case
-    CHECK_EQ(input_grad_dnn_shape.IsMklTensor(), true);
+    DCHECK(input_grad_dnn_shape.IsMklTensor() == true);
 
     memory::desc input_grad_md = input_grad_dnn_shape.GetCurLayout();
     memory::dims orig_input_dims = input_grad_dnn_shape.GetSizesAsMklDnnDims();
     input_grad_dnn_data->SetUsrMem(input_grad_md, &input_grad_tensor);
-    input_grad_dnn_data->SetOpMemDesc(orig_input_dims, memory::format::nhwc);
+    input_grad_dnn_data->SetOpMemDesc(orig_input_dims, MEMORY_FORMAT::nhwc);
 
     return input_grad_md;
   }
 
-  void PrepareAndExecuteNet(
-      const lrn_backward::primitive_desc& lrn_bkwd_desc,
-      const lrn_forward::primitive_desc& lrn_fwd_desc,
-      MklDnnData<T>* src_dnn_data, MklDnnData<T>* input_gradient_diff_dst,
-      MklDnnData<T>* output_diff_src,
-      const memory::primitive_desc& target_diff_dst_pd,
-      const MklDnnData<uint8>* workspace_dnn_data = nullptr) {
-    // Check for input reordering on the diff dst input
-    input_gradient_diff_dst->CheckReorderToOpMem(
-        lrn_bkwd_desc.diff_dst_primitive_desc());
-
-    // Check for input reordering on the original input
-    src_dnn_data->CheckReorderToOpMem(lrn_fwd_desc.src_primitive_desc());
-
-    // Create pooling primitive and add it to net
-    std::vector<primitive> net;
-    if (nullptr == workspace_dnn_data) {
-      net.push_back(lrn_backward(lrn_bkwd_desc, src_dnn_data->GetOpMem(),
-                                 input_gradient_diff_dst->GetOpMem(),
-                                 output_diff_src->GetOpMem()));
-    } else {
-      net.push_back(lrn_backward(lrn_bkwd_desc, src_dnn_data->GetOpMem(),
-                                 input_gradient_diff_dst->GetOpMem(),
-                                 workspace_dnn_data->GetOpMem(),
-                                 output_diff_src->GetOpMem()));
-    }
-    stream(stream::kind::eager).submit(net).wait();
-  }
-
   void ConfigureWorkspace(const Tensor& workspace_tensor,
-                          memory::primitive_desc workspace_pd,
+                          MEMORY_PRIMITIVE_DESC workspace_pd,
                           MklDnnData<uint8>* workspace_dnn_data) {
-    CHECK_NOTNULL(workspace_dnn_data);
+    DCHECK(workspace_dnn_data);
     workspace_dnn_data->SetUsrMem(workspace_pd, &workspace_tensor);
   }
 
   // Fallback implementation - Taken from lrn_op.cc
-  // TODO(intelft) Check if we can use EigenLRNOp directly instead of making a
-  // copy.
+  // TODO(intel-tf) Check if we can use EigenLRNOp directly
+  // instead of making a copy.
   void MklDefaultToEigen(OpKernelContext* context) {
     Tensor input_gradient_tensor;
     Tensor orig_input_tensor;
@@ -676,6 +665,8 @@ class MklLRNGradOp : public OpKernel {
   float bias_;
   float alpha_;
   float beta_;
+  engine cpu_engine_;
+  std::shared_ptr<stream> bwd_stream_;
 };
 
 #define REGISTER_MKL_LRN_CPU(T) \
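Note on the LRN changes above: they bridge two execution models. MKL-DNN v0.x primitives are constructed together with their memory objects and run by submitting a "net" to an eager stream; MKL-DNN v1.x primitives are stateless and receive their memory through an argument map at execution time, on a stream bound to an explicit engine. The sketch below shows the v1.x pattern in isolation. It is not TensorFlow code and only assumes MKL-DNN v1.x headers; the function name and LRN parameter values are illustrative.

#include <unordered_map>
#include "mkldnn.hpp"

// Minimal MKL-DNN v1.x LRN execution sketch (illustrative, not TF code).
void RunLrnForwardV1(mkldnn::engine& eng, mkldnn::memory& src,
                     mkldnn::memory& dst) {
  using namespace mkldnn;
  // Descriptors are built from memory descriptors, not from live memory.
  // forward_scoring needs no workspace; the kernels above use prop_kind::
  // forward (training) and therefore also pass MKLDNN_ARG_WORKSPACE.
  lrn_forward::desc fwd_desc(prop_kind::forward_scoring,
                             algorithm::lrn_across_channels, src.get_desc(),
                             /*local_size=*/5, /*alpha=*/1e-4f,
                             /*beta=*/0.75f, /*k=*/1.0f);
  lrn_forward::primitive_desc fwd_pd(fwd_desc, eng);
  lrn_forward lrn(fwd_pd);  // stateless primitive, reusable across calls
  // Memory is bound per execution via an argument map; this replaces the
  // v0.x pattern stream(stream::kind::eager).submit(net).wait().
  stream s(eng);
  lrn.execute(s, {{MKLDNN_ARG_SRC, src}, {MKLDNN_ARG_DST, dst}});
  s.wait();
}

This is why the patch threads an engine/stream pair (cpu_engine_, fwd_stream_/bwd_stream_) through each kernel and builds net_args alongside net under ENABLE_MKLDNN_V1.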
diff --git a/tensorflow/core/kernels/mkl_softmax_op.cc b/tensorflow/core/kernels/mkl_softmax_op.cc
index d3645b948dc..b9f8e590d0e 100644
--- a/tensorflow/core/kernels/mkl_softmax_op.cc
+++ b/tensorflow/core/kernels/mkl_softmax_op.cc
@@ -14,17 +14,19 @@ limitations under the License.
 ==============================================================================*/
 
 // See docs in ../ops/nn_ops.cc.
+
 #ifdef INTEL_MKL
 
 #include "mkldnn.hpp"
-#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/numeric_op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/util/mkl_types.h"
 #include "tensorflow/core/util/mkl_util.h"
 #include "tensorflow/core/util/tensor_format.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 
 using mkldnn::prop_kind;
 using mkldnn::softmax_forward;
@@ -35,10 +37,10 @@ namespace tensorflow {
 class MklSoftmaxParams {
  public:
   memory::dims src_dims;
-  memory::format src_fmt;
+  MKL_TENSOR_FORMAT src_fmt;
   int axis;
 
-  MklSoftmaxParams(memory::dims src_dims, memory::format src_fmt, int axis)
+  MklSoftmaxParams(memory::dims src_dims, MKL_TENSOR_FORMAT src_fmt, int axis)
       : src_dims(src_dims), src_fmt(src_fmt), axis(axis) {}
 };
 
@@ -46,8 +48,8 @@ template <typename T>
 class MklSoftmaxPrimitive : public MklPrimitive {
  public:
   explicit MklSoftmaxPrimitive(const MklSoftmaxParams& fwdParams)
-      : cpu_engine_(engine::cpu, 0) {
-    context_.fwd_stream.reset(new stream(stream::kind::eager));
+      : cpu_engine_(ENGINE_CPU, 0) {
+    context_.fwd_stream.reset(new CPU_STREAM(cpu_engine_));
     Setup(fwdParams);
   }
 
@@ -61,9 +63,18 @@ class MklSoftmaxPrimitive : public MklPrimitive {
         static_cast<void*>(const_cast<T*>(src_data)));
     context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
 
+#ifdef ENABLE_MKLDNN_V1
+    DCHECK_EQ(context_.fwd_primitives.size(),
+              context_.fwd_net_args.size());
+    for (size_t i = 0; i < context_.fwd_primitives.size(); ++i) {
+      context_.fwd_primitives.at(i).execute(*context_.fwd_stream,
+                                            context_.fwd_net_args.at(i));
+    }
+#else
     context_.fwd_stream->submit(context_.fwd_primitives);
+#endif
 
-    // After execution, set data handle back
+    // After execution, set data handle back.
     context_.src_mem->set_data_handle(DummyData);
     context_.dst_mem->set_data_handle(DummyData);
   }
 
@@ -74,22 +85,23 @@ class MklSoftmaxPrimitive : public MklPrimitive {
 
  private:
   struct SoftmaxFwdContext {
-    // MKL-DNN memory
+    // MKL-DNN memory.
    std::shared_ptr<memory> src_mem;
    std::shared_ptr<memory> dst_mem;
 
-    // Primitive desc
+    // Primitive descriptor.
    std::shared_ptr<mkldnn::softmax_forward::desc> fwd_desc;
 
-    // Memory desc
+    // Memory descriptor.
    std::shared_ptr<memory::desc> src_md;
 
-    // Softmax primitive
+    // Softmax primitive.
    std::shared_ptr<mkldnn::softmax_forward::primitive_desc> fwd_pd;
    std::shared_ptr<mkldnn::primitive> softmax_fwd;
 
    std::shared_ptr<stream> fwd_stream;
    std::vector<primitive> fwd_primitives;
+    std::vector<std::unordered_map<int, memory>> fwd_net_args;
 
     SoftmaxFwdContext()
         : src_mem(nullptr),
@@ -103,25 +115,33 @@ class MklSoftmaxPrimitive : public MklPrimitive {
 
   // Softmax forward primitive setup
   void Setup(const MklSoftmaxParams& fwdParams) {
-    // Create memory descriptors for softmax data with specified format
-    context_.src_md.reset(new memory::desc({fwdParams.src_dims},
-                                           MklDnnType<T>(), fwdParams.src_fmt));
+    // Create memory descriptors for softmax data with specified format.
+    auto src_format = GET_TENSOR_FORMAT(fwdParams.src_fmt);
+    context_.src_md.reset(
+        new memory::desc({fwdParams.src_dims}, MklDnnType<T>(), src_format));
 
-    // Create a softmax
+    // Create softmax descriptor and primitive descriptor.
     context_.fwd_desc.reset(new mkldnn::softmax_forward::desc(
         prop_kind::forward_scoring, *context_.src_md, fwdParams.axis));
     context_.fwd_pd.reset(new mkldnn::softmax_forward::primitive_desc(
         *context_.fwd_desc, cpu_engine_));
 
-    // Create memory primitive based on dummy data
-    context_.src_mem.reset(
-        new memory({*context_.src_md, cpu_engine_}, DummyData));
-    context_.dst_mem.reset(
-        new memory(context_.fwd_pd.get()->dst_primitive_desc(), DummyData));
+    // Create memory primitive based on dummy data.
+    context_.src_mem.reset(new MEMORY_CONSTRUCTOR_USING_MD(
+        *context_.src_md, cpu_engine_, DummyData));
+    context_.dst_mem.reset(new MEMORY_CONSTRUCTOR_PD(
+        context_.fwd_pd.get()->PRIMITIVE_DESC_DST, cpu_engine_, DummyData));
 
+#ifdef ENABLE_MKLDNN_V1
     // Create softmax primitive and add it to net
+    context_.softmax_fwd.reset(new mkldnn::softmax_forward(*context_.fwd_pd));
+    context_.fwd_net_args.push_back({{MKLDNN_ARG_SRC, *context_.src_mem},
+                                     {MKLDNN_ARG_DST, *context_.dst_mem}});
+#else
     context_.softmax_fwd.reset(new mkldnn::softmax_forward(
         *context_.fwd_pd, *context_.src_mem, *context_.dst_mem));
+#endif  // ENABLE_MKLDNN_V1
 
     context_.fwd_primitives.push_back(*context_.softmax_fwd);
   }
 
@@ -134,7 +154,7 @@ template <typename T>
 class MklSoftmaxPrimitiveFactory : public MklPrimitiveFactory<T> {
  public:
   static MklSoftmaxPrimitive<T>* Get(const MklSoftmaxParams& fwdParams) {
-    // Get a softmax fwd primitive from the cached pool
+    // Get a softmax fwd primitive from the cached pool.
     MklSoftmaxPrimitive<T>* softmax_forward =
         static_cast<MklSoftmaxPrimitive<T>*>(
             MklSoftmaxPrimitiveFactory<T>::GetInstance().GetSoftmaxFwd(
@@ -189,15 +209,15 @@ class MklSoftmaxOp : public OpKernel {
 
   void Compute(OpKernelContext* context) override {
     try {
-      // src_tensor now points to the 0-th input of global data struct "context"
+      auto cpu_engine = engine(ENGINE_CPU, 0);
+      // src_tensor points to the 0-th input of global data struct "context".
       size_t src_idx = 0;
       const Tensor& src_tensor = MklGetInput(context, src_idx);
-      // Add: get MklShape
       MklDnnShape src_mkl_shape;
       GetMklShape(context, src_idx, &src_mkl_shape);
 
-      // src_dims is the dimension of src_tensor
-      // dim of the dst will also be same as src_dims
+      // src_dims is the dimension of src_tensor.
+      // Dim of the dst will also be same as src_dims.
       auto src_tf_shape = src_mkl_shape.IsMklTensor()
                               ? src_mkl_shape.GetTfShape()
                               : src_tensor.shape();
@@ -211,7 +231,7 @@ class MklSoftmaxOp : public OpKernel {
         src_dims = TFShapeToMklDnnDims(src_tf_shape);
         axis = input_dims - 1;
       }
-      memory::format layout_type;
+      MKL_TENSOR_FORMAT layout_type;
       // In MKL, data format passed to mkl softmax op depends on dimension of
       // the input tensor. Here "x" data format in MKL is used for 1 dim tensor,
       // "nc" for 2 dim tensor, "tnc" for 3 dim tensor, "nchw" for 4 dim tensor,
@@ -222,26 +242,26 @@ class MklSoftmaxOp : public OpKernel {
       // dimension to do softmax.
       switch (input_dims) {
         case 1:
-          layout_type = memory::format::x;
+          layout_type = MKL_TENSOR_FORMAT_X;
           break;
         case 2:
-          layout_type = memory::format::nc;
+          layout_type = MKL_TENSOR_FORMAT_NC;
           break;
         case 3:
-          layout_type = memory::format::tnc;
+          layout_type = MKL_TENSOR_FORMAT_TNC;
           break;
         case 4:
           if (src_mkl_shape.IsMklTensor()) {
-            layout_type = memory::format::nhwc;
+            layout_type = MKL_TENSOR_FORMAT_NHWC;
           } else {
-            layout_type = memory::format::nchw;
+            layout_type = MKL_TENSOR_FORMAT_NCHW;
           }
           break;
         case 5:
           if (src_mkl_shape.IsMklTensor()) {
-            layout_type = memory::format::ndhwc;
+            layout_type = MKL_TENSOR_FORMAT_NDHWC;
           } else {
-            layout_type = memory::format::ncdhw;
+            layout_type = MKL_TENSOR_FORMAT_NCDHW;
           }
           break;
         default:
@@ -254,21 +274,20 @@ class MklSoftmaxOp : public OpKernel {
       // If input is in MKL layout, then simply get the format from input;
       // otherwise, use TF layout defined before.
       auto src_fmt = src_mkl_shape.IsMklTensor()
-                         ? static_cast<mkldnn::memory::format>(
-                               src_mkl_shape.GetMklLayout().data.format)
+                         ? GET_FORMAT_FROM_SHAPE(src_mkl_shape)
                          : layout_type;
 
-      // Get a softmax fwd from primitive pool
+      // Get a softmax fwd primitive from primitive pool.
       MklSoftmaxParams fwdParams(src_dims, src_fmt, axis);
       MklSoftmaxPrimitive<T>* softmax_fwd =
          MklSoftmaxPrimitiveFactory<T>::Get(fwdParams);
 
-      // Add output
+      // Prepare for creating output tensor.
       Tensor* output_tensor = nullptr;
       MklDnnShape output_mkl_shape;
       TensorShape output_tf_shape;  // shape of output TF tensor.
 
-      auto dst_pd = softmax_fwd->GetSoftmaxFwdPd()->dst_primitive_desc();
+      auto dst_pd = softmax_fwd->GetSoftmaxFwdPd()->PRIMITIVE_DESC_DST;
 
       // If input is MKL shape, output is also MKL shape.
       // If input is TF shape, output is also TF shape.
@@ -278,23 +297,23 @@ class MklSoftmaxOp : public OpKernel {
         output_mkl_shape.SetElemType(MklDnnType<T>());
         output_mkl_shape.SetTfLayout(src_dims.size(), src_dims, layout_type);
         output_tf_shape.AddDim((dst_pd.get_size() / sizeof(T)));
-      } else {  // then output is also TF shape
+      } else {
         output_mkl_shape.SetMklTensor(false);
         output_tf_shape = MklDnnDimsToTFShape(src_dims);
       }
-      // Allocate output shape (MKL or TF based on the above)
+      // Allocate output tensor.
       AllocateOutputSetMklShape(context, 0, &output_tensor, output_tf_shape,
                                 output_mkl_shape);
 
       const T* src_data = src_tensor.flat<T>().data();
       T* dst_data = reinterpret_cast<T*>(output_tensor->flat<T>().data());
 
-      // Execute softmax
+      // Execute softmax primitive.
       softmax_fwd->Execute(src_data, dst_data);
     } catch (mkldnn::error& e) {
-      string error_msg = "Status: " + std::to_string(e.status) +
-                         ", message: " + string(e.message) + ", in file " +
-                         string(__FILE__) + ":" + std::to_string(__LINE__);
+      string error_msg = "Status: " + std::to_string(e.status) + ", message: " +
+                         string(e.message) + ", in file " + string(__FILE__) +
+                         ":" + std::to_string(__LINE__);
       OP_REQUIRES_OK(
          context,
          errors::Aborted("Operation received an exception:", error_msg));
@@ -311,6 +330,7 @@ class MklSoftmaxOp : public OpKernel {
           .TypeConstraint<T>("T")                                    \
           .Label(mkl_op_registry::kMklLayoutDependentOpLabel),       \
       MklSoftmaxOp<CPUDevice, T>);
+
 TF_CALL_float(REGISTER_SOFTMAX_MKL_SUPPORTED_KERNELS_TYPES);
 TF_CALL_bfloat16(REGISTER_SOFTMAX_MKL_SUPPORTED_KERNELS_TYPES);
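A closing note on the layout selection in MklSoftmaxOp::Compute: under ENABLE_MKLDNN_V1, the MKL_TENSOR_FORMAT_* values chosen in the switch resolve to MKL-DNN v1.x memory::format_tag entries. A standalone sketch of the same rank-to-tag mapping follows; the helper name and the channels_last flag are illustrative and not TensorFlow code.

#include "mkldnn.hpp"

// Illustrative rank-to-format mapping, mirroring the switch above
// (MKL-DNN v1.x format_tag names).
inline mkldnn::memory::format_tag SoftmaxFormatTag(int dims,
                                                   bool channels_last) {
  using tag = mkldnn::memory::format_tag;
  switch (dims) {
    case 1: return tag::x;                                   // vector
    case 2: return tag::nc;                                  // batch, channel
    case 3: return tag::tnc;                                 // seq, batch, ch
    case 4: return channels_last ? tag::nhwc : tag::nchw;    // 2-D spatial
    case 5: return channels_last ? tag::ndhwc : tag::ncdhw;  // 3-D spatial
    default: return tag::undef;  // unsupported rank; caller raises an error
  }
}

Because TensorFlow always applies softmax along the last dimension, each tag places the softmax axis last for non-MKL inputs, while MKL-layout inputs keep the format recorded in their MklDnnShape (GET_FORMAT_FROM_SHAPE above).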