From d7dae381702471276dd4d8bdc00f96f87ab493c7 Mon Sep 17 00:00:00 2001 From: mbhuiyan Date: Mon, 10 Feb 2020 10:05:37 -0800 Subject: [PATCH 1/4] Adding Conv backward support for MKL-DNN 1.0 --- .../core/kernels/mkl_conv_grad_filter_ops.cc | 402 ++++++++++-------- .../core/kernels/mkl_conv_grad_input_ops.cc | 315 ++++++++------ tensorflow/core/kernels/mkl_conv_ops.h | 1 + 3 files changed, 406 insertions(+), 312 deletions(-) diff --git a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc index fa3264d825f..8c13730d64d 100644 --- a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/util/mkl_types.h" #include "tensorflow/core/util/mkl_util.h" #include "tensorflow/core/util/padding.h" #include "tensorflow/core/util/tensor_format.h" @@ -46,8 +47,12 @@ using mkldnn::prop_kind; using mkldnn::stream; namespace tensorflow { + typedef Eigen::ThreadPoolDevice CPUDevice; +using ConvBwdFilterDesc = mkldnn::convolution_backward_weights::desc; +using ConvBwdFilterPd = mkldnn::convolution_backward_weights::primitive_desc; + struct MklConvBwdFilterParams { memory::dims src_dims; memory::dims diff_filter_dims; @@ -80,9 +85,10 @@ class MklConvBwdFilterPrimitive : public MklPrimitive { public: explicit MklConvBwdFilterPrimitive( const MklConvBwdFilterParams& convBwdFilterDims) - : cpu_engine_(engine::cpu, 0) { - context_.bwd_filter_stream.reset(new stream(stream::kind::eager)); - // create conv primitive + : cpu_engine_(ENGINE_CPU, 0) { + context_.bwd_filter_stream.reset(new CPU_STREAM(cpu_engine_)); + + // Create convolution backward filter primitive. 
if (context_.conv_bwd_filter == nullptr) { Setup(convBwdFilterDims); } @@ -90,106 +96,115 @@ class MklConvBwdFilterPrimitive : public MklPrimitive { ~MklConvBwdFilterPrimitive() {} - // Convolution backward weights with bias - // src_data: input data buffer of src - // diff_filter_data: output data buffer of diff_filter - // diff_bias_data: output data buffer of diff_bias - // diff_dst_data: input data buffer of diff_dst + // Convolution backward weights execution with bias + // src_data: input data buffer for src + // diff_filter_data: output data buffer for diff_filter + // diff_bias_data: output data buffer for diff_bias + // diff_dst_data: input data buffer for diff_dst void Execute(const T* src_data, const T* diff_filter_data, const T* diff_bias_data, const T* diff_dst_data) { context_.src_mem->set_data_handle( static_cast(const_cast(src_data))); context_.diff_filter_mem->set_data_handle( static_cast(const_cast(diff_filter_data))); - context_.diff_bias_mem->set_data_handle( - static_cast(const_cast(diff_bias_data))); + if (diff_bias_data != nullptr) { + context_.diff_bias_mem->set_data_handle( + static_cast(const_cast(diff_bias_data))); + } context_.diff_dst_mem->set_data_handle( static_cast(const_cast(diff_dst_data))); +#ifdef ENABLE_MKLDNN_V1 + DCHECK_EQ(context_.bwd_filter_primitives.size(), + context_.bwd_filter_primitives_args.size()); + for (size_t i = 0; i < context_.bwd_filter_primitives.size(); ++i) { + context_.bwd_filter_primitives.at(i).execute( + *context_.bwd_filter_stream, + context_.bwd_filter_primitives_args.at(i)); + } +#else context_.bwd_filter_stream->submit(context_.bwd_filter_primitives); +#endif context_.src_mem->set_data_handle(DummyData); context_.diff_filter_mem->set_data_handle(DummyData); - context_.diff_bias_mem->set_data_handle(DummyData); + if (diff_bias_data != nullptr) { + context_.diff_bias_mem->set_data_handle(DummyData); + } context_.diff_dst_mem->set_data_handle(DummyData); - return; } - // Convolution backward weights without bias + // Convolution backward weights without bias. // src_data: input data buffer of src // diff_filter_data: output data buffer of diff_filter // diff_dst_data: input data buffer of diff_dst void Execute(const T* src_data, const T* diff_filter_data, const T* diff_dst_data) { - context_.src_mem->set_data_handle( - static_cast(const_cast(src_data))); - context_.diff_filter_mem->set_data_handle( - static_cast(const_cast(diff_filter_data))); - context_.diff_dst_mem->set_data_handle( - static_cast(const_cast(diff_dst_data))); - - context_.bwd_filter_stream->submit(context_.bwd_filter_primitives); - - context_.src_mem->set_data_handle(DummyData); - context_.diff_filter_mem->set_data_handle(DummyData); - context_.diff_dst_mem->set_data_handle(DummyData); - return; + Execute(src_data, diff_filter_data, nullptr, diff_dst_data); } +#ifndef ENABLE_MKLDNN_V1 memory::format GetSrcMemoryFormat() const { return context_.src_fmt; } - memory::format GetDiffDstMemoryFormat() const { return context_.diff_dst_fmt; } - memory::format GetDiffFilterMemoryFormat() const { return context_.diff_filter_fmt; } +#endif - // convolution primitive - std::shared_ptr - GetPrimitiveDesc() const { + std::shared_ptr GetPrimitiveDesc() const { return context_.bwd_filter_pd; } private: - // Primitive reuse context for Conv2D bwd filter op + // Primitive reuse context for Conv2D backward filter op. 
struct ConvBwdFilterContext { - // expected memory format for this primitive instance +#ifndef ENABLE_MKLDNN_V1 + // Expected memory format for this primitive instance memory::format src_fmt; memory::format diff_dst_fmt; memory::format diff_filter_fmt; +#endif // !ENABLE_MKLDNN_V1 - // convolution bwd input primitive - std::shared_ptr - bwd_filter_pd; - std::shared_ptr conv_bwd_filter; - - // MKLDNN memory + // MKL-DNN memory for inputs and outputs. std::shared_ptr src_mem; std::shared_ptr diff_filter_mem; std::shared_ptr diff_bias_mem; std::shared_ptr diff_dst_mem; - // desc & prmitive desc - std::shared_ptr bwd_filter_desc; - std::shared_ptr fwd_desc; - std::shared_ptr fwd_pd; + // Primitive descriptor and descriptor for convolution backward filter. + std::shared_ptr bwd_filter_pd; + std::shared_ptr bwd_filter_desc; - // memory desc: forward & backward can share same memory desc + // Primitive descriptor and descriptor for convolution forward. + std::shared_ptr fwd_pd; + std::shared_ptr fwd_desc; + + // Convolution backward filter primitive. + std::shared_ptr conv_bwd_filter; + + // Memory descriptors: forward & backward share the same memory descriptors std::shared_ptr src_md; std::shared_ptr diff_filter_md; std::shared_ptr diff_bias_md; std::shared_ptr diff_dst_md; - // MKL pipeline + // MKL-DNN pipeline for executing primitives. std::shared_ptr bwd_filter_stream; std::vector bwd_filter_primitives; +#ifdef ENABLE_MKLDNN_V1 + std::vector bwd_filter_primitives_args; +#endif + ConvBwdFilterContext() - : src_fmt(memory::format::any), + : +#ifndef ENABLE_MKLDNN_V1 + src_fmt(memory::format::any), diff_dst_fmt(memory::format::any), diff_filter_fmt(memory::format::any), +#endif src_mem(nullptr), diff_filter_mem(nullptr), diff_bias_mem(nullptr), @@ -201,84 +216,102 @@ class MklConvBwdFilterPrimitive : public MklPrimitive { diff_filter_md(nullptr), diff_bias_md(nullptr), diff_dst_md(nullptr), - bwd_filter_stream(nullptr) {} + bwd_filter_stream(nullptr) { + } }; - // Setup Conv2d backward filter (weights) primitives. void Setup(const MklConvBwdFilterParams& convBwdFilterDims) { - // create memory descriptors for convolution data w/ no specified format + // Create memory descriptors for convolution backward filter without any + // specific format so that MKL-DNN can pick an appropriate one depending + // on the input parameters. 
context_.src_md.reset(new memory::desc( - {convBwdFilterDims.src_dims}, MklDnnType(), memory::format::any)); + {convBwdFilterDims.src_dims}, MklDnnType(), MEMORY_FORMAT::any)); context_.diff_dst_md.reset( new memory::desc({convBwdFilterDims.diff_dst_dims}, MklDnnType(), - memory::format::any)); + MEMORY_FORMAT::any)); context_.diff_filter_md.reset( new memory::desc({convBwdFilterDims.diff_filter_dims}, MklDnnType(), - memory::format::any)); + MEMORY_FORMAT::any)); if (!convBwdFilterDims.diff_bias_dims.empty()) context_.diff_bias_md.reset( new memory::desc({convBwdFilterDims.diff_bias_dims}, MklDnnType(), - memory::format::x)); + MEMORY_FORMAT::x)); - // create a convolution - if (!convBwdFilterDims.diff_bias_dims.empty()) { - context_.bwd_filter_desc.reset(new convolution_backward_weights::desc( - convolution_direct, *context_.src_md, *context_.diff_filter_md, - *context_.diff_bias_md, *context_.diff_dst_md, - convBwdFilterDims.strides, convBwdFilterDims.dilations, - convBwdFilterDims.padding_left, convBwdFilterDims.padding_right, - convBwdFilterDims.padding)); - } else { - context_.bwd_filter_desc.reset(new convolution_backward_weights::desc( - convolution_direct, *context_.src_md, *context_.diff_filter_md, - *context_.diff_dst_md, convBwdFilterDims.strides, - convBwdFilterDims.dilations, convBwdFilterDims.padding_left, - convBwdFilterDims.padding_right, convBwdFilterDims.padding)); - } - - // create fwd primitive_desc - context_.fwd_desc.reset(new convolution_forward::desc( - prop_kind::forward, convolution_direct, *context_.src_md, + // Create descriptor and primitive descriptor for convolution forward. + context_.fwd_desc.reset(new ConvFwdDesc( + prop_kind::forward, ALGORITHM::convolution_direct, *context_.src_md, *context_.diff_filter_md, *context_.diff_dst_md, convBwdFilterDims.strides, convBwdFilterDims.dilations, convBwdFilterDims.padding_left, convBwdFilterDims.padding_right, convBwdFilterDims.padding)); - context_.fwd_pd.reset(new convolution_forward::primitive_desc( - *context_.fwd_desc, cpu_engine_)); + context_.fwd_pd.reset(new ConvFwdPd(*context_.fwd_desc, cpu_engine_)); - // create backward conv primitive_desc - context_.bwd_filter_pd.reset( - new convolution_backward_weights::primitive_desc( - *context_.bwd_filter_desc, cpu_engine_, *context_.fwd_pd)); + // Create descriptor and primitive descriptor for convolution bwd filter. + if (!convBwdFilterDims.diff_bias_dims.empty()) { + context_.bwd_filter_desc.reset(new ConvBwdFilterDesc( + ALGORITHM::convolution_direct, *context_.src_md, + *context_.diff_filter_md, *context_.diff_bias_md, + *context_.diff_dst_md, convBwdFilterDims.strides, + convBwdFilterDims.dilations, convBwdFilterDims.padding_left, + convBwdFilterDims.padding_right, convBwdFilterDims.padding)); + } else { + context_.bwd_filter_desc.reset(new ConvBwdFilterDesc( + ALGORITHM::convolution_direct, *context_.src_md, + *context_.diff_filter_md, *context_.diff_dst_md, + convBwdFilterDims.strides, convBwdFilterDims.dilations, + convBwdFilterDims.padding_left, convBwdFilterDims.padding_right, + convBwdFilterDims.padding)); + } + context_.bwd_filter_pd.reset(new ConvBwdFilterPd( + *context_.bwd_filter_desc, cpu_engine_, *context_.fwd_pd)); - // store the expected memory format auto bwd_filter_pd = context_.bwd_filter_pd.get(); + +#ifndef ENABLE_MKLDNN_V1 + // Store the expected memory format. 
context_.src_fmt = static_cast( bwd_filter_pd->src_primitive_desc().desc().data.format); context_.diff_filter_fmt = static_cast( bwd_filter_pd->diff_weights_primitive_desc().desc().data.format); context_.diff_dst_fmt = static_cast( bwd_filter_pd->diff_dst_primitive_desc().desc().data.format); +#endif // !ENABLE_MKLDNN_V1 - // create memory primitive based on dummy data - context_.src_mem.reset( - new memory(bwd_filter_pd->src_primitive_desc(), DummyData)); - context_.diff_filter_mem.reset( - new memory(bwd_filter_pd->diff_weights_primitive_desc(), DummyData)); - context_.diff_dst_mem.reset( - new memory(bwd_filter_pd->diff_dst_primitive_desc(), DummyData)); + // Create memory using dummy data. + context_.src_mem.reset(new MEMORY_CONSTRUCTOR( + bwd_filter_pd->PRIMITIVE_DESC_SRC, cpu_engine_, DummyData)); + context_.diff_filter_mem.reset(new MEMORY_CONSTRUCTOR( + bwd_filter_pd->PRIMITIVE_DESC_DIFF_WEIGHTS, cpu_engine_, DummyData)); + context_.diff_dst_mem.reset(new MEMORY_CONSTRUCTOR( + bwd_filter_pd->PRIMITIVE_DESC_DIFF_DST, cpu_engine_, DummyData)); - // create convolution primitive and add it to net + // Create convolution backward filter primitive and add it to the net. if (!convBwdFilterDims.diff_bias_dims.empty()) { - context_.diff_bias_mem.reset( - new memory({{{convBwdFilterDims.diff_bias_dims}, - MklDnnType(), - memory::format::x}, - cpu_engine_}, - DummyData)); + context_.diff_bias_mem.reset(new MEMORY_CONSTRUCTOR_USING_MEM_PD( + convBwdFilterDims.diff_bias_dims, T, MEMORY_FORMAT::x, cpu_engine_, + DummyData)); +#ifdef ENABLE_MKLDNN_V1 + context_.conv_bwd_filter.reset( + new convolution_backward_weights(*context_.bwd_filter_pd)); + context_.bwd_filter_primitives_args.push_back( + {{MKLDNN_ARG_SRC, *context_.src_mem}, + {MKLDNN_ARG_DIFF_WEIGHTS, *context_.diff_filter_mem}, + {MKLDNN_ARG_DIFF_BIAS, *context_.diff_bias_mem}, + { MKLDNN_ARG_DIFF_DST, + *context_.diff_dst_mem }}); + } else { + context_.conv_bwd_filter.reset( + new convolution_backward_weights(*context_.bwd_filter_pd)); + context_.bwd_filter_primitives_args.push_back( + {{MKLDNN_ARG_SRC, *context_.src_mem}, + {MKLDNN_ARG_DIFF_WEIGHTS, *context_.diff_filter_mem}, + { MKLDNN_ARG_DIFF_DST, + *context_.diff_dst_mem }}); + } +#else context_.conv_bwd_filter.reset(new convolution_backward_weights( *context_.bwd_filter_pd, *context_.src_mem, *context_.diff_dst_mem, *context_.diff_filter_mem, *context_.diff_bias_mem)); @@ -287,7 +320,7 @@ class MklConvBwdFilterPrimitive : public MklPrimitive { *context_.bwd_filter_pd, *context_.src_mem, *context_.diff_dst_mem, *context_.diff_filter_mem)); } - +#endif // ENABLE_MKLDNN_V1 context_.bwd_filter_primitives.push_back(*context_.conv_bwd_filter); } @@ -305,7 +338,7 @@ class MklConvBwdFilterPrimitiveFactory : public MklPrimitiveFactory { if (do_not_cache) { /* Create new primitive always */ conv_bwd_filter = new MklConvBwdFilterPrimitive(convBwdFilterDims); } else { - // look into the pool for reusable primitive + // Look into the pool for reusable primitive. conv_bwd_filter = dynamic_cast*>( MklConvBwdFilterPrimitiveFactory::GetInstance().GetConvBwdFilter( convBwdFilterDims)); @@ -369,23 +402,15 @@ class MklConvCustomBackpropFilterOp void Compute(OpKernelContext* context) { try { - MklDnnData src(&cpu_engine_); - MklDnnData diff_dst(&cpu_engine_); - MklDnnData diff_filter(&cpu_engine_); // output - - // This flag indicates Conv2D or Conv3D - bool is_conv2d = (this->strides_.size() == 4); - - // Input tensors - const int kInputIdx = 0, kFilterIdx = 1, kOutbpropIdx = 2; + // Input tensors. 
const Tensor& src_tensor = MklGetInput(context, kInputIdx); const Tensor& filter_tensor = MklGetInput(context, kFilterIdx); - const Tensor& diff_dst_tensor = MklGetInput(context, kOutbpropIdx); + const Tensor& diff_dst_tensor = MklGetInput(context, kDiffDstIdx); MklDnnShape src_mkl_shape, filter_mkl_shape, diff_dst_mkl_shape; GetMklShape(context, kInputIdx, &src_mkl_shape, eager_mode); GetMklShape(context, kFilterIdx, &filter_mkl_shape, eager_mode); - GetMklShape(context, kOutbpropIdx, &diff_dst_mkl_shape, eager_mode); + GetMklShape(context, kDiffDstIdx, &diff_dst_mkl_shape, eager_mode); // Allow operator-specific sanity checking of shapes. ValidateMklShapes(src_mkl_shape, filter_mkl_shape, diff_dst_mkl_shape); @@ -397,7 +422,7 @@ class MklConvCustomBackpropFilterOp TensorShape src_tf_shape = MakeInputTfShape(context, src_tensor); TensorShape filter_tf_shape = MakeFilterTfShape(context, filter_tensor); TensorShape diff_dst_tf_shape = - GetTfShape(context, kOutbpropIdx, eager_mode); + GetTfShape(context, kDiffDstIdx, eager_mode); // Corner cases: output with 0 elements and 0 batch size. Tensor* diff_filter_tensor = nullptr; @@ -412,9 +437,9 @@ class MklConvCustomBackpropFilterOp AllocateOutputSetMklShape(context, kOutputIdx, &diff_filter_tensor, diff_filter_tf_shape, diff_filter_mkl_shape, eager_mode); - CHECK_NOTNULL(diff_filter_tensor); + DCHECK(diff_filter_tensor != nullptr); - // if output tensor has more than 0 elements, we need to 0 them out. + // If output tensor has more than 0 elements, we need to 0 them out. auto diff_filter_data = diff_filter_tensor->flat().data(); for (size_t i = 0; i < diff_filter_tf_shape.num_elements(); ++i) { diff_filter_data[i] = static_cast(0); @@ -422,38 +447,44 @@ class MklConvCustomBackpropFilterOp return; } - // By default, all dims are in MKL order. Only dims in TF order - // are those with prefix tf_order. + // By default, all dims are in MKL order except those that are suffixed + // with `tf_order` memory::dims diff_dst_dims, fwd_src_dims, fwd_filter_dims; - memory::dims padding_left, padding_right, dilations, strides, - fwd_dst_dims; - memory::dims fwd_dst_dims_tf_order; + memory::dims padding_left, padding_right, dilations, strides; + memory::dims fwd_dst_dims, fwd_dst_dims_tf_order; // Get forward convolution parameters. - MklDnnConvUtil conv_utl(context, this->strides_, this->padding_, - this->data_format_, this->dilations_); - conv_utl.GetConvFwdSizesInMklOrder( + MklDnnConvUtil conv_util(context, this->strides_, this->padding_, + this->data_format_, this->dilations_); + conv_util.GetConvFwdSizesInMklOrder( src_tf_shape, filter_tf_shape, &fwd_src_dims, &fwd_filter_dims, &strides, &dilations, &fwd_dst_dims_tf_order, &fwd_dst_dims, &padding_left, &padding_right, false, is_depthwise); if (!context->status().ok()) return; + bool is_conv2d = (this->strides_.size() == 4); + auto tf_fmt = is_conv2d ? TFDataFormatToMklDnnDataFormat(this->data_format_) : TFDataFormatToMklDnn3DDataFormat(this->data_format_); +#ifdef ENABLE_MKLDNN_V1 + auto mkl_fmt_tag = MklTensorFormatToMklDnnDataFormat(tf_fmt); + OP_REQUIRES(context, mkl_fmt_tag != memory::format_tag::undef, + errors::InvalidArgument("Invalid data format")); +#endif auto fwd_src_md = src_mkl_shape.IsMklTensor() ? 
src_mkl_shape.GetMklLayout() - : memory::desc(fwd_src_dims, MklDnnType(), tf_fmt); + : memory::desc(fwd_src_dims, MklDnnType(), MKL_FMT_TAG); - conv_utl.GetInputSizeInMklOrder(diff_dst_tf_shape, &diff_dst_dims); + conv_util.GetInputSizeInMklOrder(diff_dst_tf_shape, &diff_dst_dims); if (!context->status().ok()) return; auto diff_dst_md = diff_dst_mkl_shape.IsMklTensor() ? diff_dst_mkl_shape.GetMklLayout() - : memory::desc(diff_dst_dims, MklDnnType(), tf_fmt); + : memory::desc(diff_dst_dims, MklDnnType(), MKL_FMT_TAG); memory::dims diff_bias_dims = {}; int64 depth = 0; @@ -464,26 +495,28 @@ class MklConvCustomBackpropFilterOp : obp_tf_shape.dim_size(is_conv2d ? 3 : 4); diff_bias_dims = {static_cast(depth)}; } - for (int i = 0; i < dilations.size(); i++) dilations[i] -= 1; - MklConvBwdFilterPrimitive* conv_bwd_filter = nullptr; + // The default dilation factor for each dimension is 1 in TF and + // 0 in MKL-DNN. + for (int i = 0; i < dilations.size(); ++i) --dilations[i]; + MklConvBwdFilterParams convBwdFilterDims( fwd_src_dims, fwd_filter_dims, diff_bias_dims, diff_dst_dims, strides, dilations, padding_left, padding_right, TFPaddingToMklDnnPadding(this->padding_)); - // MKL DNN allocates large buffers when a conv gradient filter primtive is + // MKL-DNN allocates large buffers when a conv gradient filter primitive is // created. So we don't cache conv backward primitives when the env // variable TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE is set to true. bool do_not_cache = MklPrimitiveFactory::IsPrimitiveMemOptEnabled(); - conv_bwd_filter = MklConvBwdFilterPrimitiveFactory::Get( - convBwdFilterDims, do_not_cache); - auto bwd_filter_pd = conv_bwd_filter->GetPrimitiveDesc(); - // allocate output tensors: diff_fitler and diff_bias (w bias) - auto bwd_output_dims = GetOutputDims(fwd_src_dims, fwd_filter_dims); + MklConvBwdFilterPrimitive* conv_bwd_filter = + MklConvBwdFilterPrimitiveFactory::Get(convBwdFilterDims, + do_not_cache); + + // Allocate output tensors: diff_filter and diff_bias (with bias). + auto diff_filter_dims = GetOutputDims(fwd_src_dims, fwd_filter_dims); - // diff_filter MklDnnShape diff_filter_mkl_shape; diff_filter_mkl_shape.SetMklTensor(false); @@ -491,15 +524,15 @@ class MklConvCustomBackpropFilterOp if (!is_depthwise) { // Conv2D: output_dims_mkl_order is in OIHW format. TensorShape diff_filter_tf_shape( - {bwd_output_dims[MklDnnDims::Dim_H], - bwd_output_dims[MklDnnDims::Dim_W], - bwd_output_dims[MklDnnDims::Dim_I], - bwd_output_dims[MklDnnDims::Dim_O]}); + {diff_filter_dims[MklDnnDims::Dim_H], + diff_filter_dims[MklDnnDims::Dim_W], + diff_filter_dims[MklDnnDims::Dim_I], + diff_filter_dims[MklDnnDims::Dim_O]}); AllocateOutputSetMklShape(context, 0, &diff_filter_tensor, diff_filter_tf_shape, diff_filter_mkl_shape, eager_mode); } else { - // Depthwise Conv2d: bwd_output_dims is GOIHW format + // Depthwise Conv2d: diff_filter_dims is GOIHW format. // | TensorFlow | MKLDNN // ---------------------------------------------------------------- // filter_out_depth | depth_multiplier | depth_multiplier * // | | filter_in_depth // filter_in_depth | in_depth | in_depth // And the GOIHW is mkldnn format, here we try to extract the TF // format, TF format is HWIO, as G = original I, so here is HWGO. 
TensorShape diff_filter_tf_shape( - {bwd_output_dims[MklDnnFilterGroupDims::MKL_GROUP_FILTER_DIM_H], - bwd_output_dims[MklDnnFilterGroupDims::MKL_GROUP_FILTER_DIM_W], - bwd_output_dims[MklDnnFilterGroupDims::MKL_GROUP_FILTER_DIM_G], - bwd_output_dims[MklDnnFilterGroupDims::MKL_GROUP_FILTER_DIM_O]}); + {diff_filter_dims[MklDnnFilterGroupDims::MKL_GROUP_FILTER_DIM_H], + diff_filter_dims[MklDnnFilterGroupDims::MKL_GROUP_FILTER_DIM_W], + diff_filter_dims[MklDnnFilterGroupDims::MKL_GROUP_FILTER_DIM_G], + diff_filter_dims + [MklDnnFilterGroupDims::MKL_GROUP_FILTER_DIM_O]}); AllocateOutputSetMklShape(context, 0, &diff_filter_tensor, diff_filter_tf_shape, diff_filter_mkl_shape); @@ -522,11 +556,11 @@ class MklConvCustomBackpropFilterOp } else { // Conv3D: output_dims_mkl_order is in OIDHW format. TensorShape diff_filter_tf_shape( - {bwd_output_dims[MklDnnDims3D::Dim3d_D], - bwd_output_dims[MklDnnDims3D::Dim3d_H], - bwd_output_dims[MklDnnDims3D::Dim3d_W], - bwd_output_dims[MklDnnDims3D::Dim3d_I], - bwd_output_dims[MklDnnDims3D::Dim3d_O]}); + {diff_filter_dims[MklDnnDims3D::Dim3d_D], + diff_filter_dims[MklDnnDims3D::Dim3d_H], + diff_filter_dims[MklDnnDims3D::Dim3d_W], + diff_filter_dims[MklDnnDims3D::Dim3d_I], + diff_filter_dims[MklDnnDims3D::Dim3d_O]}); AllocateOutputSetMklShape(context, 0, &diff_filter_tensor, diff_filter_tf_shape, diff_filter_mkl_shape); } @@ -537,39 +571,50 @@ class MklConvCustomBackpropFilterOp AllocateBiasGradTensor(context, diff_bias_shape, &diff_bias_tensor); } - // check if src and diff_dst need reorder + // Check if src and diff_dst need to be reordered. T* src_data = nullptr; - if (fwd_src_md.data.format != conv_bwd_filter->GetSrcMemoryFormat()) { + MklDnnData src(&cpu_engine_); + auto bwd_filter_pd = conv_bwd_filter->GetPrimitiveDesc(); + if (IS_SRC_REORDER_NEEDED(fwd_src_md, bwd_filter_pd, conv_bwd_filter)) { src.SetUsrMem(fwd_src_md, &src_tensor); - src.CheckReorderToOpMem(bwd_filter_pd->src_primitive_desc()); + src.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA( + bwd_filter_pd->PRIMITIVE_DESC_SRC, cpu_engine_)); src_data = static_cast(src.GetOpMem().get_data_handle()); } else { src_data = static_cast(const_cast(src_tensor.flat().data())); } T* diff_dst_data = nullptr; - if (diff_dst_md.data.format != - conv_bwd_filter->GetDiffDstMemoryFormat()) { + MklDnnData diff_dst(&cpu_engine_); + if (IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, bwd_filter_pd, + conv_bwd_filter)) { diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor); - diff_dst.CheckReorderToOpMem(bwd_filter_pd->diff_dst_primitive_desc()); + diff_dst.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA( + bwd_filter_pd->PRIMITIVE_DESC_DIFF_DST, cpu_engine_)); diff_dst_data = static_cast(diff_dst.GetOpMem().get_data_handle()); } else { diff_dst_data = static_cast(const_cast(diff_dst_tensor.flat().data())); } - // For backward filter, convert diff_filter back to Tensorflow layout - // Here we prepare to reorder op memory back to user memory + DCHECK(!diff_filter_mkl_shape.IsMklTensor()); + auto diff_filter_format = GetOutputFormat(MKL_FMT_TAG); + auto diff_filter_md = + memory::desc(diff_filter_dims, MklDnnType(), diff_filter_format); + + // Convert diff_filter (output) back to TF layout if needed + // (i.e. 
reorder op memory back to user memory) + MklDnnData diff_filter(&cpu_engine_); bool diff_filter_reorder_required = false; T* diff_filter_data = nullptr; - if (GetOutputFormat(tf_fmt) != - conv_bwd_filter->GetDiffFilterMemoryFormat()) { - // Allocate diff filter tensor as Tensorflow layout - diff_filter.SetUsrMem(bwd_output_dims, GetOutputFormat(tf_fmt), + if (IS_DIFF_FILTER_REORDER_NEEDED(diff_filter_md, diff_filter_format, + bwd_filter_pd, conv_bwd_filter)) { + // Allocate diff_filter tensor as Tensorflow layout. + diff_filter.SetUsrMem(diff_filter_dims, diff_filter_format, diff_filter_tensor); diff_filter_reorder_required = true; diff_filter.PrepareReorderToUserMemIfReq( - bwd_filter_pd->diff_weights_primitive_desc()); + bwd_filter_pd->PRIMITIVE_DESC_DIFF_WEIGHTS); diff_filter_data = static_cast(diff_filter.GetOpMem().get_data_handle()); } else { @@ -577,7 +622,7 @@ class MklConvCustomBackpropFilterOp const_cast(diff_filter_tensor->flat().data())); } - // Execute convolution filter bwd + // Execute convolution backward filter. if (bias_enabled) { T* diff_bias_data = static_cast(const_cast(diff_bias_tensor->flat().data())); @@ -587,17 +632,17 @@ class MklConvCustomBackpropFilterOp conv_bwd_filter->Execute(src_data, diff_filter_data, diff_dst_data); } - // Reorder diff_filter back to Tensorflow layout if necessary + // Reorder diff_filter back to Tensorflow layout if necessary. if (diff_filter_reorder_required) { diff_filter.InsertReorderToUserMem(); } - // delete primitive since it is not cached. + // Delete primitive since it is not cached. if (do_not_cache) delete conv_bwd_filter; } catch (mkldnn::error& e) { - string error_msg = "Status: " + std::to_string(e.status) + - ", message: " + string(e.message) + ", in file " + - string(__FILE__) + ":" + std::to_string(__LINE__); + string error_msg = "Status: " + std::to_string(e.status) + ", message: " + + string(e.message) + ", in file " + string(__FILE__) + + ":" + std::to_string(__LINE__); OP_REQUIRES_OK( context, errors::Aborted("Operation received an exception:", error_msg)); @@ -605,13 +650,12 @@ class MklConvCustomBackpropFilterOp } private: - const int kInputIndex_Filter = 1; - const int kInputIndex_InputSizes = 0; + const int kInputIdx = 0, kFilterIdx = 1, kDiffDstIdx = 2; const int kDilationH = 0, kDilationW = 1; - engine cpu_engine_ = engine(engine::cpu, 0); - // Validate input shapes. - // Function asserts that input shapes are valid. + engine cpu_engine_ = engine(ENGINE_CPU, 0); + + // Assert that input shapes are valid. void ValidateMklShapes(const MklDnnShape& input_mkl_shape, const MklDnnShape& filter_mkl_shape, const MklDnnShape& obp_mkl_shape) { @@ -656,20 +700,16 @@ class MklConvCustomBackpropFilterOp // Output layout is Tensorflow's filter layout // Conv2D: HWIO; Conv3D: DHWIO; Depthwise Conv: HWIGO - memory::format GetOutputFormat(const memory::format data_format) { - return is_depthwise - ? memory::format::hwigo - : ((this->strides_.size() == 4) ? memory::format::hwio - : memory::format::dhwio); + MEMORY_FORMAT GetOutputFormat(const MEMORY_FORMAT data_format) { + return is_depthwise ? MEMORY_FORMAT::hwigo + : ((this->strides_.size() == 4) ? MEMORY_FORMAT::hwio + : MEMORY_FORMAT::dhwio); } - // Allocate output tensor. 
- void AllocateOutputTensor( - OpKernelContext* context, - const convolution_backward_weights::primitive_desc& conv_pd, - const memory::dims& output_dims_mkl_order, - memory::format output_tf_format, Tensor** output_tensor) { - CHECK_NOTNULL(output_tensor); + void AllocateOutputTensor(OpKernelContext* context, + const memory::dims& output_dims_mkl_order, + Tensor** output_tensor) { + DCHECK(output_tensor != nullptr); // For BackpropFilter, we convert the output tensor back in Tensorflow // layout. Because typically, BackpropFilter is the last operator in the @@ -689,11 +729,10 @@ class MklConvCustomBackpropFilterOp output_mkl_shape); } - // Allocate tensor for bias grad void AllocateBiasGradTensor(OpKernelContext* context, const TensorShape& bias_grad_shape, Tensor** bias_grad_tensor) { - CHECK_NOTNULL(bias_grad_tensor); + DCHECK(bias_grad_tensor); MklDnnShape bias_grad_mkl_shape; bias_grad_mkl_shape.SetMklTensor(false); @@ -742,6 +781,7 @@ class MklConvCustomBackpropFilterOp TF_CALL_float(REGISTER_MKL_FILTER_KERNELS); TF_CALL_bfloat16(REGISTER_MKL_FILTER_KERNELS); + #undef REGISTER_MKL_FILTER_KERNELS } // namespace tensorflow diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc index 02456685341..fc2ec3356be 100644 --- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc @@ -21,6 +21,7 @@ limitations under the License. #define USE_EIGEN_TENSOR #define EIGEN_USE_THREADS + #include #include @@ -39,6 +40,7 @@ limitations under the License. #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/util/mkl_types.h" #include "tensorflow/core/util/mkl_util.h" #include "tensorflow/core/util/padding.h" #include "tensorflow/core/util/tensor_format.h" @@ -50,9 +52,11 @@ using mkldnn::prop_kind; using mkldnn::stream; namespace tensorflow { -typedef Eigen::ThreadPoolDevice CPUDevice; -/// utility classes enabling primitive reuse for backward conv ops. +using ConvBwdDataDesc = mkldnn::convolution_backward_data::desc; +using ConvBwdDataPd = mkldnn::convolution_backward_data::primitive_desc; + +// Utility classes for enabling primitive reuse for conv bwd input. struct MklConvBwdInputParams { memory::dims diff_src_dims; memory::dims filter_dims; @@ -82,20 +86,21 @@ class MklConvBwdInputPrimitive : public MklPrimitive { public: explicit MklConvBwdInputPrimitive( const MklConvBwdInputParams& convBwdInputDims) - : cpu_engine_(engine::cpu, 0) { - context_.bwd_input_stream.reset(new stream(stream::kind::eager)); + : cpu_engine_(ENGINE_CPU, 0) { + context_.bwd_input_stream.reset(new CPU_STREAM(cpu_engine_)); - // create conv primitive + // Create conv bwd input primitive if (context_.conv_bwd_input == nullptr) { Setup(convBwdInputDims); } } + ~MklConvBwdInputPrimitive() {} - // Convolution backward filter (weights) - // diff_src_data: output data buffer of diff_src - // filter_data: input data buffer of filter (weights) - // diff_dst_data: input data buffer of dst + // Convolution backward input (data) execution. 
+ // diff_src_data: output data buffer for diff_src + // filter_data: input data buffer for filter (weights) + // diff_dst_data: input data buffer for diff_dst // Bias does not matter here void Execute(const T* diff_src_data, const T* filter_data, const T* diff_dst_data) { @@ -106,60 +111,79 @@ class MklConvBwdInputPrimitive : public MklPrimitive { context_.diff_dst_mem->set_data_handle( static_cast(const_cast(diff_dst_data))); +#ifdef ENABLE_MKLDNN_V1 + DCHECK_EQ(context_.bwd_input_primitives.size(), + context_.bwd_input_primitives_args.size()); + for (size_t i = 0; i < context_.bwd_input_primitives.size(); ++i) { + context_.bwd_input_primitives.at(i).execute( + *context_.bwd_input_stream, context_.bwd_input_primitives_args.at(i)); + } +#else context_.bwd_input_stream->submit(context_.bwd_input_primitives); +#endif // ENABLE_MKLDNN_V1 - // set back data handle + // Set data handle back to DummyData. context_.diff_src_mem->set_data_handle(DummyData); context_.filter_mem->set_data_handle(DummyData); context_.diff_dst_mem->set_data_handle(DummyData); return; } +#ifndef ENABLE_MKLDNN_V1 memory::format GetFilterMemoryFormat() const { return context_.filter_fmt; } - memory::format GetDiffDstMemoryFormat() const { return context_.diff_dst_fmt; } +#endif // !ENABLE_MKLDNN_V1 - std::shared_ptr - GetPrimitiveDesc() const { + std::shared_ptr GetPrimitiveDesc() const { return context_.bwd_input_pd; } private: - // Primitive reuse context for Conv Bwd Input op + // Primitive reuse context for conv bwd input. struct ConvBwdInputContext { - // expected memory format for this primitive instance +#ifndef ENABLE_MKLDNN_V1 + // Expected memory format for this primitive instance. memory::format filter_fmt; memory::format diff_dst_fmt; +#endif - // MKLDNN memory + // MKL-DNN memory. std::shared_ptr diff_src_mem; std::shared_ptr filter_mem; std::shared_ptr diff_dst_mem; - // convolution primitive - std::shared_ptr - bwd_input_pd; + // Conv backward input primitive descriptor and descriptor. + std::shared_ptr bwd_input_pd; + std::shared_ptr bwd_input_desc; + + // Primitive descriptor and descriptor for conv fwd + std::shared_ptr fwd_pd; + std::shared_ptr fwd_desc; + + // Conv bwd input primitive. std::shared_ptr conv_bwd_input; - // desc & prmitive desc - std::shared_ptr bwd_input_desc; - std::shared_ptr fwd_desc; - std::shared_ptr fwd_pd; - - // memory desc: forward & backward can share same memory::desc + // Memory descriptors: forward & backward share the same descriptors. std::shared_ptr diff_src_md; std::shared_ptr filter_md; std::shared_ptr diff_dst_md; - // MKL pipeline + // MKL-DNN pipeline for executing primitives. 
std::shared_ptr<mkldnn::stream> bwd_input_stream; std::vector<mkldnn::primitive> bwd_input_primitives; +#ifdef ENABLE_MKLDNN_V1 + std::vector<std::unordered_map<int, memory>> bwd_input_primitives_args; +#endif // ENABLE_MKLDNN_V1 + ConvBwdInputContext() - : filter_fmt(memory::format::any), + : +#ifndef ENABLE_MKLDNN_V1 + filter_fmt(memory::format::any), diff_dst_fmt(memory::format::any), +#endif diff_src_mem(nullptr), filter_mem(nullptr), diff_dst_mem(nullptr), @@ -171,49 +195,53 @@ class MklConvBwdInputPrimitive : public MklPrimitive { diff_src_md(nullptr), filter_md(nullptr), diff_dst_md(nullptr), - bwd_input_stream(nullptr) {} + bwd_input_stream(nullptr) { + } }; void Setup(const MklConvBwdInputParams& convBwdInputDims) { - // create memory descriptors for convolution data w/ no specified format - context_.diff_src_md.reset( - new memory::desc({convBwdInputDims.diff_src_dims}, MklDnnType(), - memory::format::any)); + // Create memory descriptors for conv bwd input without any specified + // format so that MKL-DNN can pick an appropriate one depending on the + // input parameters. + context_.diff_src_md.reset(new memory::desc( + {convBwdInputDims.diff_src_dims}, MklDnnType(), MEMORY_FORMAT::any)); context_.filter_md.reset(new memory::desc( - {convBwdInputDims.filter_dims}, MklDnnType(), memory::format::any)); + {convBwdInputDims.filter_dims}, MklDnnType(), MEMORY_FORMAT::any)); context_.diff_dst_md.reset( - new memory::desc({convBwdInputDims.diff_dst_dims}, MklDnnType(), - memory::format::any)); + context_.diff_dst_md.reset(new memory::desc( + {convBwdInputDims.diff_dst_dims}, MklDnnType(), MEMORY_FORMAT::any)); - // create convolution primitives - context_.bwd_input_desc.reset(new convolution_backward_data::desc( - convolution_direct, *context_.diff_src_md, *context_.filter_md, - *context_.diff_dst_md, convBwdInputDims.strides, - convBwdInputDims.dilations, convBwdInputDims.padding_left, - convBwdInputDims.padding_right, convBwdInputDims.padding)); - - context_.fwd_desc.reset(new convolution_forward::desc( - prop_kind::forward, convolution_direct, *context_.diff_src_md, + // Create descriptors for both conv fwd and conv bwd input. + context_.bwd_input_desc.reset(new ConvBwdDataDesc( + ALGORITHM::convolution_direct, *context_.diff_src_md, *context_.filter_md, *context_.diff_dst_md, convBwdInputDims.strides, convBwdInputDims.dilations, convBwdInputDims.padding_left, convBwdInputDims.padding_right, convBwdInputDims.padding)); - context_.fwd_pd.reset(new convolution_forward::primitive_desc( - *context_.fwd_desc, cpu_engine_)); + context_.fwd_desc.reset(new ConvFwdDesc( + prop_kind::forward, ALGORITHM::convolution_direct, + *context_.diff_src_md, *context_.filter_md, *context_.diff_dst_md, + convBwdInputDims.strides, convBwdInputDims.dilations, + convBwdInputDims.padding_left, convBwdInputDims.padding_right, + convBwdInputDims.padding)); - // create backward conv prim desc - context_.bwd_input_pd.reset(new convolution_backward_data::primitive_desc( + // Create primitive descriptors for conv fwd and conv bwd input. 
+ context_.fwd_pd.reset(new ConvFwdPd(*context_.fwd_desc, cpu_engine_)); + context_.bwd_input_pd.reset(new ConvBwdDataPd( *context_.bwd_input_desc, cpu_engine_, *context_.fwd_pd)); - // create memory primitive based on dummy data - context_.diff_src_mem.reset(new memory( - context_.bwd_input_pd.get()->diff_src_primitive_desc(), DummyData)); - context_.filter_mem.reset(new memory( - context_.bwd_input_pd.get()->weights_primitive_desc(), DummyData)); - context_.diff_dst_mem.reset(new memory( - context_.bwd_input_pd.get()->diff_dst_primitive_desc(), DummyData)); + // Create memory using dummy data. + context_.diff_src_mem.reset(new MEMORY_CONSTRUCTOR( + context_.bwd_input_pd.get()->PRIMITIVE_DESC_DIFF_SRC, cpu_engine_, + DummyData)); + context_.filter_mem.reset(new MEMORY_CONSTRUCTOR( + context_.bwd_input_pd.get()->PRIMITIVE_DESC_WEIGHTS, cpu_engine_, + DummyData)); + context_.diff_dst_mem.reset(new MEMORY_CONSTRUCTOR( + context_.bwd_input_pd.get()->PRIMITIVE_DESC_DIFF_DST, cpu_engine_, + DummyData)); - // store the expected memory format +#ifndef ENABLE_MKLDNN_V1 + // Store the expected memory format. context_.filter_fmt = static_cast(context_.bwd_input_pd.get() ->weights_primitive_desc() @@ -224,11 +252,22 @@ class MklConvBwdInputPrimitive : public MklPrimitive { ->diff_dst_primitive_desc() .desc() .data.format); +#endif // !ENABLE_MKLDNN_V1 - // create convolution primitive and add it to net +// Create conv bwd input primitive and add it to the net +#ifdef ENABLE_MKLDNN_V1 + context_.conv_bwd_input.reset( + new convolution_backward_data(*context_.bwd_input_pd)); + context_.bwd_input_primitives_args.push_back( + {{MKLDNN_ARG_DIFF_DST, *context_.diff_dst_mem}, + {MKLDNN_ARG_WEIGHTS, *context_.filter_mem}, + { MKLDNN_ARG_DIFF_SRC, + *context_.diff_src_mem }}); +#else context_.conv_bwd_input.reset(new convolution_backward_data( *context_.bwd_input_pd, *context_.diff_dst_mem, *context_.filter_mem, *context_.diff_src_mem)); +#endif // ENABLE_MKLDNN_V1 context_.bwd_input_primitives.push_back(*context_.conv_bwd_input); } @@ -248,10 +287,10 @@ class MklConvBwdInputPrimitiveFactory : public MklPrimitiveFactory { const MklConvBwdInputParams& convBwdInputDims, bool do_not_cache) { MklConvBwdInputPrimitive* conv_bwd_input = nullptr; - if (do_not_cache) { /* Always allocate primitive */ + if (do_not_cache) { // Always allocate primitive. conv_bwd_input = new MklConvBwdInputPrimitive(convBwdInputDims); } else { - // look into the pool for reusable primitive + // look into the pool for reusable primitive. conv_bwd_input = dynamic_cast*>( MklConvBwdInputPrimitiveFactory::GetInstance().GetConvBwdInput( convBwdInputDims)); @@ -308,14 +347,7 @@ class MklConvCustomBackpropInputOp void Compute(OpKernelContext* context) { try { - MklDnnData filter(&cpu_engine); - MklDnnData diff_dst(&cpu_engine); - - // This flag indicate Conv2D or Conv3D - bool is_conv2d = (this->strides_.size() == 4); - - // Input tensors - const int kInputIdx = 0, kFilterIdx = 1, kOutbpropIdx = 2; + // Input tensors. const Tensor& src_tensor = MklGetInput(context, kInputIdx); const Tensor& filter_tensor = MklGetInput(context, kFilterIdx); const Tensor& diff_dst_tensor = MklGetInput(context, kOutbpropIdx); @@ -350,9 +382,9 @@ class MklConvCustomBackpropInputOp AllocateOutputSetMklShape(context, kOutputIdx, &diff_src_tensor, diff_src_tf_shape, diff_src_mkl_shape, eager_mode); - CHECK_NOTNULL(diff_src_tensor); + DCHECK(diff_src_tensor != nullptr); - // if output tensor has more than 0 elements, we need to 0 them out. 
+ // If output tensor has more than 0 elements, we need to 0 them out. auto diff_src_data = diff_src_tensor->flat().data(); for (size_t i = 0; i < diff_src_tf_shape.num_elements(); ++i) { diff_src_data[i] = static_cast(0); @@ -360,28 +392,36 @@ class MklConvCustomBackpropInputOp return; } - // By default, all dims are in MKL order. Only dims in TF order - // are those with postfix tf_order. + // By default, all dims are in MKL order except those that are suffixed + // with `tf_order`. memory::dims diff_dst_dims, fwd_src_dims, fwd_filter_dims; memory::dims padding_left, padding_right, dilations, strides; memory::dims fwd_output_dims, fwd_output_dims_tf_order; - // Get forward convolution parameters. - MklDnnConvUtil conv_utl(context, this->strides_, this->padding_, - this->data_format_, this->dilations_); - conv_utl.GetConvFwdSizesInMklOrder( + // Get conv fwd parameters. + MklDnnConvUtil conv_util(context, this->strides_, this->padding_, + this->data_format_, this->dilations_); + conv_util.GetConvFwdSizesInMklOrder( src_tf_shape, filter_tf_shape, &fwd_src_dims, &fwd_filter_dims, &strides, &dilations, &fwd_output_dims_tf_order, &fwd_output_dims, &padding_left, &padding_right, false, is_depthwise); if (!context->status().ok()) return; - // Create Convolution forward descriptor since Convolution backward - // API needs it. For that, we first need to create input, filter - // and output memory descriptors. + bool is_conv2d = (this->strides_.size() == 4); + + // Create conv fwd descriptor since conv bwd input API needs it. + // For that, we first need to create input, filter and output memory + // descriptors. auto tf_fmt = is_conv2d ? TFDataFormatToMklDnnDataFormat(this->data_format_) : TFDataFormatToMklDnn3DDataFormat(this->data_format_); +#ifdef ENABLE_MKLDNN_V1 + auto mkl_fmt_tag = MklTensorFormatToMklDnnDataFormat(tf_fmt); + OP_REQUIRES(context, mkl_fmt_tag != memory::format_tag::undef, + errors::InvalidArgument("Invalid data format")); +#endif // ENABLE_MKLDNN_V1 + // If filter is in MKL layout, then simply grab filter layout; // otherwise, construct filter in TF layout. // For TF layout, filter is in HWIO format. @@ -389,42 +429,47 @@ class MklConvCustomBackpropInputOp filter_mkl_shape.IsMklTensor() ? filter_mkl_shape.GetMklLayout() : memory::desc(fwd_filter_dims, MklDnnType(), - is_depthwise - ? memory::hwigo - : (is_conv2d ? memory::format::hwio - : memory::format::dhwio)); + is_depthwise ? MEMORY_FORMAT::hwigo + : (is_conv2d ? MEMORY_FORMAT::hwio + : MEMORY_FORMAT::dhwio)); - conv_utl.GetInputSizeInMklOrder(diff_dst_tf_shape, &diff_dst_dims); + conv_util.GetInputSizeInMklOrder(diff_dst_tf_shape, &diff_dst_dims); if (!context->status().ok()) return; + auto diff_dst_md = diff_dst_mkl_shape.IsMklTensor() ? diff_dst_mkl_shape.GetMklLayout() - : memory::desc(diff_dst_dims, MklDnnType(), tf_fmt); - for (int i = 0; i < dilations.size(); i++) dilations[i] -= 1; + : memory::desc(diff_dst_dims, MklDnnType(), MKL_FMT_TAG); + + // The default dilation factor for each dimension is 1 in TF and + // 0 in MKL-DNN. 
+ for (int i = 0; i < dilations.size(); ++i) --dilations[i]; - MklConvBwdInputPrimitive* conv_bwd_input = nullptr; MklConvBwdInputParams convBwdInputDims( fwd_src_dims, fwd_filter_dims, diff_dst_dims, strides, dilations, padding_left, padding_right, TFPaddingToMklDnnPadding(this->padding_)); - // We don't cache those primitves if the env variable + // We don't cache those primitives if the environment variable // TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE is true and if primitive descriptor - // includes potentialy large buffers. MKL DNN allocates buffers + // includes potentially large buffers. MKL-DNN allocates buffers // in the following cases // 1. Legacy CPU without AVX512/AVX2, or // 2. 1x1 convolution with stride != 1 bool do_not_cache = MklPrimitiveFactory::IsPrimitiveMemOptEnabled() && (MklPrimitiveFactory::IsLegacyPlatform() || IsConv1x1StrideNot1(fwd_filter_dims, strides)); - conv_bwd_input = MklConvBwdInputPrimitiveFactory::Get(convBwdInputDims, - do_not_cache); - auto bwd_input_pd = conv_bwd_input->GetPrimitiveDesc(); - // allocate output tensor - auto diff_src_pd = bwd_input_pd->diff_src_primitive_desc(); + MklConvBwdInputPrimitive* conv_bwd_input = + MklConvBwdInputPrimitiveFactory::Get(convBwdInputDims, + do_not_cache); + + auto bwd_input_pd = conv_bwd_input->GetPrimitiveDesc(); + auto diff_src_pd = bwd_input_pd.get()->PRIMITIVE_DESC_DIFF_SRC; auto bwd_diff_src_dims = GetOutputDims(fwd_src_dims, fwd_filter_dims); auto bwd_diff_src_format = GetOutputFormat(tf_fmt); + + // Allocate output tensor. MklDnnShape diff_src_mkl_shape; diff_src_mkl_shape.SetMklTensor(true); diff_src_mkl_shape.SetMklLayout(&diff_src_pd); @@ -443,12 +488,14 @@ class MklConvCustomBackpropInputOp T* diff_src_data = static_cast(const_cast(diff_src_tensor->flat().data())); - // check if filter and diff_dst need reorder + // Check if filter and diff_dst need to be reordered. T* filter_data = nullptr; - if (fwd_filter_md.data.format != - conv_bwd_input->GetFilterMemoryFormat()) { + MklDnnData filter(&cpu_engine_); + if (IS_FILTER_REORDER_NEEDED(fwd_filter_md, bwd_input_pd, + conv_bwd_input)) { filter.SetUsrMem(fwd_filter_md, &filter_tensor); - filter.CheckReorderToOpMem(bwd_input_pd->weights_primitive_desc()); + filter.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA( + bwd_input_pd.get()->PRIMITIVE_DESC_WEIGHTS, cpu_engine_)); filter_data = static_cast(filter.GetOpMem().get_data_handle()); } else { filter_data = static_cast(const_cast(filter_tensor.flat().data())); } T* diff_dst_data = nullptr; - if (diff_dst_md.data.format != conv_bwd_input->GetDiffDstMemoryFormat()) { + MklDnnData diff_dst(&cpu_engine_); + if (IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, bwd_input_pd, + conv_bwd_input)) { diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor); - diff_dst.CheckReorderToOpMem(bwd_input_pd->diff_dst_primitive_desc()); + diff_dst.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA( + bwd_input_pd.get()->PRIMITIVE_DESC_DIFF_DST, cpu_engine_)); diff_dst_data = static_cast(diff_dst.GetOpMem().get_data_handle()); } else { diff_dst_data = static_cast(const_cast(diff_dst_tensor.flat().data())); } - // execute convolution input bwd + // Execute conv bwd input primitive. 
if (!eager_mode) { conv_bwd_input->Execute(diff_src_data, filter_data, diff_dst_data); } else { @@ -475,25 +525,27 @@ class MklConvCustomBackpropInputOp static_cast(const_cast(tmp_tensor.flat().data())); conv_bwd_input->Execute(tmp_data, filter_data, diff_dst_data); auto output_tf_md = diff_src_mkl_shape.GetTfLayout(); - auto output_tf_pd = memory::primitive_desc(output_tf_md, cpu_engine); - mkldnn::reorder::primitive_desc reorder_pd = - mkldnn::reorder::primitive_desc(diff_src_pd, output_tf_pd); - std::vector net; - memory* tmp_data_mem = new memory(diff_src_pd, tmp_data); - memory* dst_data_mem = new memory(output_tf_pd, diff_src_data); - net.push_back( - mkldnn::reorder(reorder_pd, *tmp_data_mem, *dst_data_mem)); - stream(stream::kind::eager).submit(net).wait(); +#ifndef ENABLE_MKLDNN_V1 + auto output_tf_pd = memory::primitive_desc(output_tf_md, cpu_engine_); +#endif + ReorderPd reorder_pd = + REORDER_PD_CONSTRUCTOR(diff_src_pd, OUTPUT_TF_MD, cpu_engine_); + memory* tmp_data_mem = + new MEMORY_CONSTRUCTOR(diff_src_pd, cpu_engine_, tmp_data); + memory* dst_data_mem = + new MEMORY_CONSTRUCTOR(OUTPUT_TF_MD, cpu_engine_, diff_src_data); + CreateAndExecuteReorder(reorder_pd, *tmp_data_mem, *dst_data_mem, + cpu_engine_); } - // delete primitive since it is not cached. + // Delete primitive since it is not cached. if (do_not_cache) { delete conv_bwd_input; } } catch (mkldnn::error& e) { - string error_msg = "Status: " + std::to_string(e.status) + - ", message: " + string(e.message) + ", in file " + - string(__FILE__) + ":" + std::to_string(__LINE__); + string error_msg = "Status: " + std::to_string(e.status) + ", message: " + + string(e.message) + ", in file " + string(__FILE__) + + ":" + std::to_string(__LINE__); OP_REQUIRES_OK( context, errors::Aborted("Operation received an exception:", error_msg)); @@ -501,12 +553,12 @@ class MklConvCustomBackpropInputOp } private: - const int kInputIndex_Filter = 1, kInputIndex_InputSizes = 0; + const int kInputIdx = 0, kFilterIdx = 1, kOutbpropIdx = 2; const int kDilationH = 0, kDilationW = 1; - engine cpu_engine = engine(engine::cpu, 0); - // Validate input shapes. - // Function asserts that input shapes are valid. + engine cpu_engine_ = engine(ENGINE_CPU, 0); + + // Assert that input shapes are valid. void ValidateMklShapes(const MklDnnShape& input_mkl_shape, const MklDnnShape& filter_mkl_shape, const MklDnnShape& obp_mkl_shape) { @@ -532,7 +584,7 @@ class MklConvCustomBackpropInputOp // Get TensorFlow shape of filter tensor. TensorShape MakeFilterTfShape(OpKernelContext* context, const Tensor& filter_tensor) { - return GetTfShape(context, kInputIndex_Filter, eager_mode); + return GetTfShape(context, kFilterIdx, eager_mode); } // Get the Tensorflow shape of Output (diff_src), @@ -543,30 +595,29 @@ class MklConvCustomBackpropInputOp return input_shape; } - // Get the Tensorflow shape of Output (diff_src), - // which is same as shape of Conv 'input'. const memory::dims& GetOutputDims(const memory::dims& fwd_input_dims, const memory::dims& fwd_filter_dims) { return fwd_input_dims; } // Output layout is Tensorflow's layout in data format order. - memory::format GetOutputFormat(const memory::format data_format) { + MKL_TENSOR_FORMAT GetOutputFormat(const MKL_TENSOR_FORMAT data_format) { return data_format; } - // Allocate output tensor. 
- void AllocateOutputTensor( - OpKernelContext* context, - const convolution_backward_data::primitive_desc& conv_pd, - const memory::dims& output_dims_mkl_order, - memory::format output_tf_format, Tensor** output_tensor) { - CHECK_NOTNULL(output_tensor); + // TODO(bhavanis): Move this function to mkl_util.h since it is common to + // both the forward and backward implementations + void AllocateOutputTensor(OpKernelContext* context, + const ConvBwdDataPd& conv_pd, + const memory::dims& output_dims_mkl_order, + MKL_TENSOR_FORMAT output_tf_format, + Tensor** output_tensor) { + DCHECK(output_tensor != nullptr); // Output primitive descriptor for backward data is diff_src. - auto dst_pd = conv_pd.diff_src_primitive_desc(); + auto dst_pd = conv_pd.PRIMITIVE_DESC_DIFF_SRC; - // Allocate shape of Mkl tensor. + // Allocate shape of MKL tensor. MklDnnShape output_mkl_shape; output_mkl_shape.SetMklTensor(true); output_mkl_shape.SetMklLayout(&dst_pd); @@ -608,8 +659,10 @@ class MklConvCustomBackpropInputOp .TypeConstraint("T") \ .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ MklConvCustomBackpropInputOp); + TF_CALL_float(REGISTER_MKL_CPU_KERNELS); TF_CALL_bfloat16(REGISTER_MKL_CPU_KERNELS); + #undef REGISTER_MKL_CPU_KERNELS } // namespace tensorflow diff --git a/tensorflow/core/kernels/mkl_conv_ops.h b/tensorflow/core/kernels/mkl_conv_ops.h index 8553111672d..4e6f64c56b0 100644 --- a/tensorflow/core/kernels/mkl_conv_ops.h +++ b/tensorflow/core/kernels/mkl_conv_ops.h @@ -55,6 +55,7 @@ namespace tensorflow { #define MKLDNN_SIZE_DTYPE int #endif // ENABLE_MKLDNN_V1 +using ConvFwdDesc = mkldnn::convolution_forward::desc; using ConvFwdPd = mkldnn::convolution_forward::primitive_desc; class MklDnnConvUtil { From 9a53a987145bee6906ba7f2035e4fbe2f28c5836 Mon Sep 17 00:00:00 2001 From: mbhuiyan Date: Thu, 13 Feb 2020 10:03:38 -0800 Subject: [PATCH 2/4] Added a method for refactoring --- tensorflow/core/util/mkl_util.h | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h index e4450ee8a56..6d6794b6c33 100644 --- a/tensorflow/core/util/mkl_util.h +++ b/tensorflow/core/util/mkl_util.h @@ -730,7 +730,7 @@ inline Status ConvertMklToTF(OpKernelContext* context, } catch (mkldnn::error& e) { string error_msg = "Status: " + std::to_string(e.status) + ", message: " + string(e.message) + ", in file " + - string(__FILE__) + ":" + std::to_string(__LINE__); + string(__FILE__) + ":" + std::to_string(__LINE__); LOG(FATAL) << "Operation received an exception: " << error_msg; } } @@ -2155,6 +2155,17 @@ inline bool IsConv1x1StrideNot1(memory::dims filter_dims, ((strides[0] != 1) || (strides[1] != 1))); } +#ifdef ENABLE_MKLDNN_V1 +void execute_primitives( + std::vector<primitive>& primitives, std::shared_ptr<stream> stream, + std::vector<std::unordered_map<int, memory>>& net_args) { + DCHECK_EQ(primitives.size(), net_args.size(); + for (size_t i = 0; i < primitives.size(); ++i) { + primitives.at(i).execute(*stream, net_args.at(i)); + } +} +#endif // ENABLE_MKLDNN_V1 + } // namespace tensorflow #endif // INTEL_MKL #endif // TENSORFLOW_CORE_UTIL_MKL_UTIL_H_ From b6407acdd52c78d931caffdd646ca8631402bdc1 Mon Sep 17 00:00:00 2001 From: mbhuiyan Date: Thu, 13 Feb 2020 14:45:09 -0800 Subject: [PATCH 3/4] Fixing the clang format --- tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc | 6 +++--- tensorflow/core/kernels/mkl_conv_grad_input_ops.cc | 6 +++--- tensorflow/core/util/mkl_util.h | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git 
a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc index 8c13730d64d..09ac4011fbe 100644 --- a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc @@ -640,9 +640,9 @@ class MklConvCustomBackpropFilterOp // Delete primitive since it is not cached. if (do_not_cache) delete conv_bwd_filter; } catch (mkldnn::error& e) { - string error_msg = "Status: " + std::to_string(e.status) + ", message: " + - string(e.message) + ", in file " + string(__FILE__) + - ":" + std::to_string(__LINE__); + string error_msg = "Status: " + std::to_string(e.status) + + ", message: " + string(e.message) + ", in file " + + string(__FILE__) + ":" + std::to_string(__LINE__); OP_REQUIRES_OK( context, errors::Aborted("Operation received an exception:", error_msg)); diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc index fc2ec3356be..3115bff01bc 100644 --- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc @@ -543,9 +543,9 @@ class MklConvCustomBackpropInputOp delete conv_bwd_input; } } catch (mkldnn::error& e) { - string error_msg = "Status: " + std::to_string(e.status) + ", message: " + - string(e.message) + ", in file " + string(__FILE__) + - ":" + std::to_string(__LINE__); + string error_msg = "Status: " + std::to_string(e.status) + + ", message: " + string(e.message) + ", in file " + + string(__FILE__) + ":" + std::to_string(__LINE__); OP_REQUIRES_OK( context, errors::Aborted("Operation received an exception:", error_msg)); diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h index 6d6794b6c33..e123893ba3a 100644 --- a/tensorflow/core/util/mkl_util.h +++ b/tensorflow/core/util/mkl_util.h @@ -730,7 +730,7 @@ inline Status ConvertMklToTF(OpKernelContext* context, } catch (mkldnn::error& e) { string error_msg = "Status: " + std::to_string(e.status) + ", message: " + string(e.message) + ", in file " + - string(__FILE__) + ":" + std::to_string(__LINE__); + string(__FILE__) + ":" + std::to_string(__LINE__); LOG(FATAL) << "Operation received an exception: " << error_msg; } } @@ -2159,7 +2159,7 @@ inline bool IsConv1x1StrideNot1(memory::dims filter_dims, void execute_primitives( std::vector<primitive>& primitives, std::shared_ptr<stream> stream, std::vector<std::unordered_map<int, memory>>& net_args) { - DCHECK_EQ(primitives.size(), net_args.size(); + DCHECK_EQ(primitives.size(), net_args.size()); for (size_t i = 0; i < primitives.size(); ++i) { primitives.at(i).execute(*stream, net_args.at(i)); } From aba36d0286152da96d8cc3e4e8939b60374d7554 Mon Sep 17 00:00:00 2001 From: mbhuiyan Date: Thu, 13 Feb 2020 15:02:33 -0800 Subject: [PATCH 4/4] Calling refactored method --- tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc | 9 ++------- tensorflow/core/kernels/mkl_conv_grad_input_ops.cc | 8 ++------ 2 files changed, 4 insertions(+), 13 deletions(-) diff --git a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc index 09ac4011fbe..b886c3f8f1f 100644 --- a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc @@ -115,13 +115,8 @@ class MklConvBwdFilterPrimitive : public MklPrimitive { static_cast(const_cast(diff_dst_data))); #ifdef ENABLE_MKLDNN_V1 - DCHECK_EQ(context_.bwd_filter_primitives.size(), - context_.bwd_filter_primitives_args.size()); - for (size_t i = 0; i < context_.bwd_filter_primitives.size(); 
++i) { - context_.bwd_filter_primitives.at(i).execute( - *context_.bwd_filter_stream, - context_.bwd_filter_primitives_args.at(i)); - } + execute_primitives(context_.bwd_filter_primitives, + context_.bwd_filter_stream, context_.bwd_filter_primitives_args); #else context_.bwd_filter_stream->submit(context_.bwd_filter_primitives); #endif diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc index 3115bff01bc..7c1f4312fda 100644 --- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc @@ -112,12 +112,8 @@ class MklConvBwdInputPrimitive : public MklPrimitive { static_cast(const_cast(diff_dst_data))); #ifdef ENABLE_MKLDNN_V1 - DCHECK_EQ(context_.bwd_input_primitives.size(), - context_.bwd_input_primitives_args.size()); - for (size_t i = 0; i < context_.bwd_input_primitives.size(); ++i) { - context_.bwd_input_primitives.at(i).execute( - *context_.bwd_input_stream, context_.bwd_input_primitives_args.at(i)); - } + execute_primitives(context_.bwd_input_primitives, context_.bwd_input_stream, + context_.bwd_input_primitives_args); #else context_.bwd_input_stream->submit(context_.bwd_input_primitives); #endif // ENABLE_MKLDNN_V1
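
Background on the execution-model change these patches target: MKL-DNN 0.x submits a whole net of primitives to an eager-kind stream, while MKL-DNN 1.x executes each primitive against a stream together with an explicit argument map, which is exactly the loop the new execute_primitives helper wraps. Below is a minimal, self-contained sketch of that 1.x pattern; the names RunNet, net, and net_args are illustrative only (not from the patches), and it assumes the MKL-DNN 1.x mkldnn.hpp API.

    #include <unordered_map>
    #include <vector>
    #include "mkldnn.hpp"

    using namespace mkldnn;

    // Execute each primitive on the stream with its own argument map of
    // {MKLDNN_ARG_* id, memory} pairs. This replaces the MKL-DNN 0.x idiom
    // stream(stream::kind::eager).submit(net).wait().
    void RunNet(engine& cpu_engine, std::vector<primitive>& net,
                std::vector<std::unordered_map<int, memory>>& net_args) {
      stream cpu_stream(cpu_engine);
      for (size_t i = 0; i < net.size(); ++i)
        net.at(i).execute(cpu_stream, net_args.at(i));
      cpu_stream.wait();
    }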