From 14b14ab32dd3d07f7e0a7d375a6b6d68a6831ccd Mon Sep 17 00:00:00 2001 From: Bhavani Subramanian Date: Tue, 9 Jul 2019 15:54:04 -0700 Subject: [PATCH 1/6] Enabled Conv2D fprop for MKL-DNN v1.0. --- tensorflow/core/graph/mkl_layout_pass.cc | 34 +- tensorflow/core/kernels/mkl_conv_ops.cc | 468 ++++++++++++++++++++++- tensorflow/core/kernels/mkl_conv_ops.h | 40 +- 3 files changed, 511 insertions(+), 31 deletions(-) diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc index df3cf19e2c0..7ec8e3eea32 100644 --- a/tensorflow/core/graph/mkl_layout_pass.cc +++ b/tensorflow/core/graph/mkl_layout_pass.cc @@ -351,9 +351,10 @@ class MklLayoutRewritePass : public GraphOptimizationPass { csinfo_.mul = "Mul"; csinfo_.squared_difference = "SquaredDifference"; csinfo_.sub = "Sub"; - // End - element-wise ops. See note above. +// End - element-wise ops. See note above. - // NOTE: names are alphabetically sorted. +// NOTE: names are alphabetically sorted. +#ifndef ENABLE_MKLDNN_V1 rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn), CopyAttrsAddN, AlwaysRewrite, kRewriteForLayoutPropagation}); @@ -388,10 +389,12 @@ class MklLayoutRewritePass : public GraphOptimizationPass { {csinfo_.conjugate_transpose, mkl_op_registry::GetMklOpName(csinfo_.conjugate_transpose), CopyAttrsTranspose, AlwaysRewrite, kRewriteForOpNameChange}); +#endif // ENABLE_MKLDNN_V1 rinfo_.push_back({csinfo_.conv2d, mkl_op_registry::GetMklOpName(csinfo_.conv2d), CopyAttrsConvCheckConstFilter, AlwaysRewrite, kRewriteForLayoutPropagation}); +#ifndef ENABLE_MKLDNN_V1 rinfo_.push_back({csinfo_.conv2d_with_bias, csinfo_.mkl_conv2d_with_bias, CopyAttrsConvCheckConstFilter, AlwaysRewrite, kRewriteForLayoutPropagation}); @@ -632,18 +635,20 @@ class MklLayoutRewritePass : public GraphOptimizationPass { rinfo_.push_back( {csinfo_.requantize, mkl_op_registry::GetMklOpName(csinfo_.requantize), CopyAttrsRequantize, AlwaysRewrite, kRewriteForLayoutPropagation}); - // Disable these two MKL operators for now due to some test failures caused - // by these two ops - /* - rinfo_.push_back({csinfo_.tanh, - mkl_op_registry::GetMklOpName(csinfo_.tanh), - CopyAttrsDataType, AlwaysRewrite, - kRewriteForLayoutPropagation}); - rinfo_.push_back({csinfo_.tanh_grad, - mkl_op_registry::GetMklOpName(csinfo_.tanh_grad), - CopyAttrsDataType, AlwaysRewrite, - kRewriteForLayoutPropagation}); - */ +#endif // ENABLE_MKLDNN_V1 +// Disable these two MKL operators for now due to some test failures caused +// by these two ops +/* +rinfo_.push_back({csinfo_.tanh, + mkl_op_registry::GetMklOpName(csinfo_.tanh), + CopyAttrsDataType, AlwaysRewrite, + kRewriteForLayoutPropagation}); +rinfo_.push_back({csinfo_.tanh_grad, + mkl_op_registry::GetMklOpName(csinfo_.tanh_grad), + CopyAttrsDataType, AlwaysRewrite, + kRewriteForLayoutPropagation}); +*/ +#ifndef ENABLE_MKLDNN_V1 rinfo_.push_back( {csinfo_.reshape, mkl_op_registry::GetMklOpName(csinfo_.reshape), CopyAttrsReshape, AlwaysRewrite, kRewriteForLayoutPropagation}); @@ -744,6 +749,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // CheckForMklOp FuseConv3D, CopyAttrsConv}); +#endif // ENABLE_MKLDNN_V1 } // Standard interface to run pass diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc index 14344da0560..39cc4da3ce0 100644 --- a/tensorflow/core/kernels/mkl_conv_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_ops.cc @@ -24,8 +24,8 @@ limitations under the License. 
#include #include -#include "mkldnn.hpp" #include "absl/strings/str_join.h" +#include "mkldnn.hpp" #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" @@ -50,7 +50,9 @@ limitations under the License. using mkldnn::prop_kind; using mkldnn::stream; using mkldnn::convolution_forward; +#ifndef ENABLE_MKLDNN_V1 using mkldnn::convolution_direct; +#endif namespace tensorflow { @@ -93,6 +95,16 @@ typedef mkldnn::convolution_forward::primitive_desc ConvFwdPd; template class MklConvFwdPrimitive : public MklPrimitive { public: +#ifdef ENABLE_MKLDNN_V1 + explicit MklConvFwdPrimitive(const MklConvFwdParams& convFwdDims) + : cpu_engine_(engine::kind::cpu, 0) { + context_.fwd_stream.reset(new stream(cpu_engine_)); + // Create conv primitive + if (context_.conv_fwd == nullptr) { + Setup(convFwdDims); + } + } +#else explicit MklConvFwdPrimitive(const MklConvFwdParams& convFwdDims) : cpu_engine_(engine::cpu, 0) { context_.fwd_stream.reset(new stream(stream::kind::eager)); @@ -101,6 +113,7 @@ class MklConvFwdPrimitive : public MklPrimitive { Setup(convFwdDims); } } +#endif ~MklConvFwdPrimitive() {} @@ -119,7 +132,16 @@ class MklConvFwdPrimitive : public MklPrimitive { static_cast(const_cast(bias_data))); context_.dst_mem->set_data_handle( static_cast(const_cast(dst_data))); +#ifdef ENABLE_MKLDNN_V1 + CHECK_EQ(context_.fwd_primitives.size(), + context_.fwd_primitives_args.size()); + for (size_t i = 0; i < context_.fwd_primitives.size(); ++i) { + context_.fwd_primitives.at(i).execute(*context_.fwd_stream, + context_.fwd_primitives_args.at(i)); + } +#else context_.fwd_stream->submit(context_.fwd_primitives); +#endif // After exec, set data handle back context_.src_mem->set_data_handle(DummyData); @@ -142,7 +164,16 @@ class MklConvFwdPrimitive : public MklPrimitive { static_cast(const_cast(filter_data))); context_.dst_mem->set_data_handle( static_cast(const_cast(dst_data))); +#ifdef ENABLE_MKLDNN_V1 + CHECK_EQ(context_.fwd_primitives.size(), + context_.fwd_primitives_args.size()); + for (size_t i = 0; i < context_.fwd_primitives.size(); ++i) { + context_.fwd_primitives.at(i).execute(*context_.fwd_stream, + context_.fwd_primitives_args.at(i)); + } +#else context_.fwd_stream->submit(context_.fwd_primitives); +#endif // After execution, set data handle back context_.src_mem->set_data_handle(DummyData); @@ -150,9 +181,13 @@ class MklConvFwdPrimitive : public MklPrimitive { context_.dst_mem->set_data_handle(DummyData); } +#ifndef ENABLE_MKLDNN_V1 + // In MKL-DNN v1.0, memory format tags only provide a partial description + // of the memory layout. Hence, these functions are disabled for v1.0. 
memory::format GetSrcMemoryFormat() const { return context_.src_fmt; } memory::format GetFilterMemoryFormat() const { return context_.filter_fmt; } +#endif std::shared_ptr GetPrimitiveDesc() const { return context_.fwd_pd; @@ -161,9 +196,11 @@ class MklConvFwdPrimitive : public MklPrimitive { private: // Primitive reuse context for Conv2D Fwd op struct ConvFwdContext { +#ifndef ENABLE_MKLDNN_V1 // Expected memory format for this primitive instance memory::format src_fmt; memory::format filter_fmt; +#endif // MKLDNN memory std::shared_ptr src_mem; @@ -187,9 +224,16 @@ class MklConvFwdPrimitive : public MklPrimitive { std::shared_ptr fwd_stream; std::vector fwd_primitives; +#ifdef ENABLE_MKLDNN_V1 + std::vector> fwd_primitives_args; +#endif + ConvFwdContext() - : src_fmt(memory::format::any), + : +#ifndef ENABLE_MKLDNN_V1 + src_fmt(memory::format::any), filter_fmt(memory::format::any), +#endif src_mem(nullptr), filter_mem(nullptr), bias_mem(nullptr), @@ -200,34 +244,64 @@ class MklConvFwdPrimitive : public MklPrimitive { bias_md(nullptr), fwd_pd(nullptr), conv_fwd(nullptr), - fwd_stream(nullptr) {} + fwd_stream(nullptr) { + } }; void Setup(const MklConvFwdParams& convFwdDims) { // Create memory descriptors for convolution data w/ no specified format context_.src_md.reset(new memory::desc( +#ifdef ENABLE_MKLDNN_V1 + {convFwdDims.src_dims}, MklDnnType(), memory::format_tag::any)); +#else {convFwdDims.src_dims}, MklDnnType(), memory::format::any)); +#endif context_.filter_md.reset(new memory::desc( +#ifdef ENABLE_MKLDNN_V1 + {convFwdDims.filter_dims}, MklDnnType(), + memory::format_tag::any)); +#else {convFwdDims.filter_dims}, MklDnnType(), memory::format::any)); +#endif context_.dst_md.reset(new memory::desc( +#ifdef ENABLE_MKLDNN_V1 + {convFwdDims.dst_dims}, MklDnnType(), + memory::format_tag::any)); +#else {convFwdDims.dst_dims}, MklDnnType(), memory::format::any)); +#endif if (!convFwdDims.bias_dims.empty()) context_.bias_md.reset(new memory::desc( +#ifdef ENABLE_MKLDNN_V1 + {convFwdDims.bias_dims}, MklDnnType(), + memory::format_tag::any)); +#else {convFwdDims.bias_dims}, MklDnnType(), memory::format::any)); +#endif // Create a convolution if (!convFwdDims.bias_dims.empty()) { context_.fwd_desc.reset(new convolution_forward::desc( +#ifdef ENABLE_MKLDNN_V1 + prop_kind::forward, mkldnn::algorithm::convolution_direct, + *context_.src_md, +#else prop_kind::forward, convolution_direct, *context_.src_md, +#endif *context_.filter_md, *context_.bias_md, *context_.dst_md, convFwdDims.strides, convFwdDims.dilations, convFwdDims.padding_left, convFwdDims.padding_right, padding_kind::zero)); } else { context_.fwd_desc.reset(new convolution_forward::desc( +#ifdef ENABLE_MKLDNN_V1 + prop_kind::forward, mkldnn::algorithm::convolution_direct, + *context_.src_md, +#else prop_kind::forward, convolution_direct, *context_.src_md, +#endif *context_.filter_md, *context_.dst_md, convFwdDims.strides, convFwdDims.dilations, convFwdDims.padding_left, convFwdDims.padding_right, padding_kind::zero)); @@ -246,7 +320,12 @@ class MklConvFwdPrimitive : public MklPrimitive { float op_scale = post_op_param.param[0]; float op_alpha = post_op_param.param[1]; float op_beta = post_op_param.param[2]; +#ifdef ENABLE_MKLDNN_V1 + post_ops.append_eltwise(op_scale, mkldnn::algorithm::eltwise_relu, + op_alpha, +#else post_ops.append_eltwise(op_scale, post_op_param.alg, op_alpha, +#endif op_beta); } else if (post_op_param.name == "sum") { DCHECK_EQ(post_op_param.param.size(), 1); @@ -271,21 +350,54 @@ class MklConvFwdPrimitive : 
public MklPrimitive { context_.fwd_pd.reset(new ConvFwdPd(*context_.fwd_desc, cpu_engine_)); } +#ifndef ENABLE_MKLDNN_V1 // Store the expected memory format context_.src_fmt = static_cast( context_.fwd_pd.get()->src_primitive_desc().desc().data.format); context_.filter_fmt = static_cast( context_.fwd_pd.get()->weights_primitive_desc().desc().data.format); +#endif +#ifdef ENABLE_MKLDNN_V1 // Create memory primitive based on dummy data + context_.src_mem.reset( + new memory(context_.fwd_pd.get()->src_desc(), cpu_engine_, DummyData)); + context_.filter_mem.reset(new memory(context_.fwd_pd.get()->weights_desc(), + cpu_engine_, DummyData)); + context_.dst_mem.reset( + new memory(context_.fwd_pd.get()->dst_desc(), cpu_engine_, DummyData)); +#else context_.src_mem.reset( new memory(context_.fwd_pd.get()->src_primitive_desc(), DummyData)); context_.filter_mem.reset( new memory(context_.fwd_pd.get()->weights_primitive_desc(), DummyData)); context_.dst_mem.reset( new memory(context_.fwd_pd.get()->dst_primitive_desc(), DummyData)); +#endif +#ifdef ENABLE_MKLDNN_V1 + // Create convolution primitive and add it to net + if (!convFwdDims.bias_dims.empty()) { + context_.bias_mem.reset(new memory( + {{convFwdDims.bias_dims}, MklDnnType(), memory::format_tag::x}, + cpu_engine_, DummyData)); + context_.conv_fwd.reset(new convolution_forward(*context_.fwd_pd)); + context_.fwd_primitives_args.push_back( + {{MKLDNN_ARG_SRC, *context_.src_mem}, + {MKLDNN_ARG_WEIGHTS, *context_.filter_mem}, + {MKLDNN_ARG_BIAS, *context_.bias_mem}, + {MKLDNN_ARG_DST, *context_.dst_mem}}); + } else { + context_.conv_fwd.reset(new convolution_forward(*context_.fwd_pd)); + context_.fwd_primitives_args.push_back( + {{MKLDNN_ARG_SRC, *context_.src_mem}, + {MKLDNN_ARG_WEIGHTS, *context_.filter_mem}, + {MKLDNN_ARG_DST, *context_.dst_mem}}); + } + context_.fwd_primitives.push_back(*context_.conv_fwd); + return; +#else // Create convolution primitive and add it to net if (!convFwdDims.bias_dims.empty()) { context_.bias_mem.reset(new memory( @@ -303,6 +415,7 @@ class MklConvFwdPrimitive : public MklPrimitive { context_.fwd_primitives.push_back(*context_.conv_fwd); return; +#endif } struct ConvFwdContext context_; @@ -450,17 +563,15 @@ class MklConvOp : public OpKernel { OP_REQUIRES(context, dilations_.size() == 5, errors::InvalidArgument("Dilation rates field must " "specify 5 dimensions")); - OP_REQUIRES(context, - (GetTensorDim(dilations_, data_format_, 'N') == 1 && - GetTensorDim(dilations_, data_format_, 'C') == 1), + OP_REQUIRES(context, (GetTensorDim(dilations_, data_format_, 'N') == 1 && + GetTensorDim(dilations_, data_format_, 'C') == 1), errors::InvalidArgument( "Current implementation does not yet support " "dilations rates in the batch and depth dimensions.")); OP_REQUIRES( - context, - (GetTensorDim(dilations_, data_format_, '0') > 0 && - GetTensorDim(dilations_, data_format_, '1') > 0 && - GetTensorDim(dilations_, data_format_, '2') > 0), + context, (GetTensorDim(dilations_, data_format_, '0') > 0 && + GetTensorDim(dilations_, data_format_, '1') > 0 && + GetTensorDim(dilations_, data_format_, '2') > 0), errors::InvalidArgument("Dilated rates should be larger than 0.")); } } @@ -566,6 +677,12 @@ class MklConvOp : public OpKernel { auto tf_fmt = is_conv2d ? 
TFDataFormatToMklDnnDataFormat(data_format_) : TFDataFormatToMklDnn3DDataFormat(data_format_); +#ifdef ENABLE_MKLDNN_V1 + auto mkl_fmt_tag = MklTensorFormatToMklDnnDataFormat(tf_fmt); + // NOTE: `mkl_fmt_tag` will be `format_tag::undef` for ReLU + CHECK_NE(mkl_fmt_tag, memory::format_tag::undef); +#endif + // If input is in MKL layout, then simply grab the layout; otherwise, // construct TF layout for input. // For constructing TF layout for input, although input shape (src_dims) @@ -573,18 +690,28 @@ class MklConvOp : public OpKernel { // TF layout depending on the data format: // Conv2D: NHWC or NCHW // Conv3D: NDHWC or NCDHW - auto src_md = src_mkl_shape.IsMklTensor() - ? src_mkl_shape.GetMklLayout() - : memory::desc(src_dims, MklDnnType(), tf_fmt); + auto src_md = + src_mkl_shape.IsMklTensor() + ? src_mkl_shape.GetMklLayout() +#ifdef ENABLE_MKLDNN_V1 + : memory::desc(src_dims, MklDnnType(), mkl_fmt_tag); +#else + : memory::desc(src_dims, MklDnnType(), tf_fmt); +#endif src.SetUsrMem(src_md, &src_tensor); +#ifdef ENABLE_MKLDNN_V1 // Although filter shape (filter_dims) required is in MKL-DNN order, // the layout is Tensorflow's layout (HWIO) and (HWIGO) for // depthwise/group convolutions. - + auto filter_format = is_conv2d ? (is_depthwise ? memory::format_tag::hwigo + : memory::format_tag::hwio) + : memory::format_tag::dhwio; +#else auto filter_format = is_conv2d ? (is_depthwise ? memory::format::hwigo : memory::format::hwio) : memory::format::dhwio; +#endif DCHECK(!filter_mkl_shape.IsMklTensor()); auto filter_md = @@ -643,6 +770,51 @@ class MklConvOp : public OpKernel { // Check whether src and filter need to be reordered Tinput* src_data = nullptr; +#ifdef ENABLE_MKLDNN_V1 + if (src_md != conv_fwd_pd->src_desc()) { + // Reorder src + src.SetUsrMem(src_md, &src_tensor); + src.CheckReorderToOpMem(conv_fwd_pd->src_desc(), cpu_engine_); + src_data = static_cast(src.GetOpMem().get_data_handle()); + } else { + src_data = static_cast( + const_cast(src_tensor.flat().data())); + } + + Tfilter* filter_data = nullptr; + if (filter_md != conv_fwd_pd->weights_desc()) { + bool is_filter_cached = false; + // If filter is a constant, we can avoid the conversion of filter from + // Tensorflow format to MKL format by caching the filter when it is + // converted for the first time. This cached filter can then be reused + // in subsequent iterations. + if (is_filter_const_) { + if (IsFilterCacheEmpty(context)) { + // Cache filter if it is not already cached. 
+ CacheFilter(context, conv_fwd_pd, filter_data, filter_tensor, + filter, filter_md, filter_mkl_shape); + } + filter_data = GetCachedFilter(context, conv_fwd_pd->weights_desc()); + is_filter_cached = (filter_data != nullptr); + } + if (!is_filter_cached) { + filter.SetUsrMem(filter_md, &filter_tensor); + if (filter_out_tensor == nullptr) { + filter.CheckReorderToOpMem(conv_fwd_pd->weights_desc(), + cpu_engine_); + } else { + filter.CheckReorderToOpMem( + conv_fwd_pd->weights_desc(), + filter.GetTensorBuffer(filter_out_tensor), cpu_engine_); + } + filter_data = + static_cast(filter.GetOpMem().get_data_handle()); + } + } else { + filter_data = static_cast( + const_cast(filter_tensor.flat().data())); + } +#else if (src_md.data.format != conv_fwd->GetSrcMemoryFormat()) { // Reorder src src.SetUsrMem(src_md, &src_tensor); @@ -687,6 +859,7 @@ class MklConvOp : public OpKernel { filter_data = static_cast( const_cast(filter_tensor.flat().data())); } +#endif // Execute convolution if (fuse_biasadd_) { @@ -805,6 +978,35 @@ class MklConvOp : public OpKernel { return nullptr; } +#ifdef ENABLE_MKLDNN_V1 + virtual void AllocateOutputTensor(OpKernelContext* context, + const ConvFwdPd& conv_prim_desc, + const memory::dims& output_dims_mkl_order, + MklTensorFormat output_tf_format, + Tensor** output_tensor) { + CHECK_NOTNULL(output_tensor); + auto dst_md = conv_prim_desc.dst_desc(); + + if (!std::is_same::value) { + dst_md.data.data_type = + static_cast(MklDnnType()); + } + // Allocate shape of Mkl tensor. + MklDnnShape output_mkl_shape; + output_mkl_shape.SetMklTensor(true); + output_mkl_shape.SetMklLayout(&dst_md); + output_mkl_shape.SetElemType(MklDnnType()); + output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(), + output_dims_mkl_order, output_tf_format); + + // Allocate shape of TF tensor. 
+ TensorShape output_tf_shape; + output_tf_shape.AddDim((dst_md.get_size() / sizeof(Toutput))); + + AllocateOutputSetMklShape(context, kOutputIndex_Dst, output_tensor, + output_tf_shape, output_mkl_shape); + } +#else virtual void AllocateOutputTensor(OpKernelContext* context, const ConvFwdPd& conv_prim_desc, const memory::dims& output_dims_mkl_order, @@ -862,8 +1064,13 @@ class MklConvOp : public OpKernel { } } } +#endif +#ifdef ENABLE_MKLDNN_V1 + engine cpu_engine_ = engine(engine::kind::cpu, 0); +#else engine cpu_engine_ = engine(engine::cpu, 0); +#endif private: std::vector strides_; @@ -892,8 +1099,105 @@ class MklConvOp : public OpKernel { const int kOutputIndex_Dst = 0, kOutputIndex_Filter = 1; const int kDilationH = 0, kDilationW = 1; +#ifdef ENABLE_MKLDNN_V1 // Allocate persistent tensors for cached filter data and // cached filter memory descriptor (data format) + void AllocatePersistentTensor(OpKernelContext* context, + const ConvFwdPd& conv_prim_desc, + Tensor** filter_tensor, + const MklDnnShape& filter_mkl_shape) { + DCHECK(filter_tensor); + TensorShape filter_tf_shape; + filter_tf_shape.AddDim( + (conv_prim_desc.weights_desc().get_size() / sizeof(Tfilter))); + OP_REQUIRES_OK(context, context->allocate_persistent( + DataTypeToEnum::value, filter_tf_shape, + &cached_filter_data_ptensor_, filter_tensor)); + + Tensor* second_tensor = nullptr; + TensorShape filter_mkl_format; + filter_mkl_format.AddDim(sizeof(filter_mkl_shape.GetTfDataFormat()) / + sizeof(DT_INT32)); + OP_REQUIRES_OK(context, context->allocate_persistent( + DT_INT32, filter_mkl_format, + &cached_filter_md_ptensor_, &second_tensor)); + second_tensor->scalar()() = + static_cast(filter_mkl_shape.GetTfDataFormat()); + } + + void AllocateFilterOutputTensor(OpKernelContext* context, + const ConvFwdPd& conv_prim_desc, + const memory::dims& filter_dims_tf_order, + Tensor** filter_tensor) { + CHECK_NOTNULL(filter_tensor); + auto filter_md = conv_prim_desc.weights_desc(); + + // Allocate shape of Mkl tensor. + MklDnnShape filter_mkl_shape; + filter_mkl_shape.SetMklTensor(true); + filter_mkl_shape.SetMklLayout(&filter_md); + filter_mkl_shape.SetElemType(MklDnnType()); + + // The format of the filter is actually OIhw8i8o, but TF doesn't support + // this format. Just use format::blocked for now because the layout + // is stored in the MKL data. + filter_mkl_shape.SetTfLayout(filter_dims_tf_order.size(), + filter_dims_tf_order, + MklTensorFormat::FORMAT_UNDEF); + + // Allocate the data space for the filter to propagate as TF tensor. + TensorShape filter_tf_shape; + filter_tf_shape.AddDim((filter_md.get_size() / sizeof(Tfilter))); + + AllocateOutputSetMklShape(context, kOutputIndex_Filter, filter_tensor, + filter_tf_shape, filter_mkl_shape); + } + + // Prepare and execute net - checks for input and output reorders. + void PrepareAndExecuteNet(const ConvFwdPd& conv_prim_desc, + MklDnnData* src, + MklDnnData* filter, + MklDnnData* bias, + MklDnnData* output, + Tensor* filter_out_tensor) { + CHECK_NOTNULL(filter_out_tensor); + + // Create reorders between user layout and MKL layout if it is needed and + // add it to the net before convolution. No need to check for output + // reorder as we propagate output layout to the next layer. 
+ src->CheckReorderToOpMem(conv_prim_desc.src_desc(), cpu_engine_); + + // rather than re-order to a temp buffer, reorder directly to the + // filter output tensor + filter->CheckReorderToOpMem(conv_prim_desc.weights_desc(), + filter->GetTensorBuffer(filter_out_tensor)); + + // Create convolution primitive and add it to net. + std::vector net; + std::vector> net_args; + if (bias) { + DCHECK(fuse_biasadd_); + net.push_back(convolution_forward(conv_prim_desc)); + net_args.push_back({{MKLDNN_ARG_SRC, src->GetOpMem()}, + {MKLDNN_ARG_WEIGHTS, filter->GetOpMem()}, + {MKLDNN_ARG_BIAS, bias->GetOpMem()}, + {MKLDNN_ARG_DST, output->GetOpMem()}}); + } else { + DCHECK(!fuse_biasadd_); + net.push_back(convolution_forward(conv_prim_desc)); + net_args.push_back({{MKLDNN_ARG_SRC, src->GetOpMem()}, + {MKLDNN_ARG_WEIGHTS, filter->GetOpMem()}, + {MKLDNN_ARG_DST, output->GetOpMem()}}); + } + stream cpu_stream(cpu_engine_); + + CHECK_EQ(net.size(), net_args.size()); + for (size_t i = 0; i < net.size(); ++i) { + net.at(i).execute(cpu_stream, net_args.at(i)); + } + cpu_stream.wait(); + } +#else void AllocatePersistentTensor(OpKernelContext* context, const ConvFwdPd& conv_prim_desc, Tensor** filter_tensor) { @@ -979,6 +1283,7 @@ class MklConvOp : public OpKernel { stream(stream::kind::eager).submit(net).wait(); } +#endif // LOCKS_EXCLUDED annotation ensures that the lock (mu_) cannot // be acquired before entering the function, since it is acquired @@ -990,6 +1295,37 @@ class MklConvOp : public OpKernel { return (cached_filter_data_tensor.NumElements() == 0); } +#ifdef ENABLE_MKLDNN_V1 + // Cache the converted filter in a persistent tensor. + // Only one thread can execute this method at any given time. + void CacheFilter(OpKernelContext* context, + const std::shared_ptr& conv_fwd_pd, + Tfilter* filter_data, const Tensor& filter_tensor, + MklDnnData& filter, const memory::desc& filter_md, + const MklDnnShape& filter_mkl_shape) LOCKS_EXCLUDED(mu_) { + mutex_lock lock(mu_); + const Tensor& cached_filter_data_tensor = + *cached_filter_data_ptensor_.AccessTensor(context); + + // If filter is already cached, there's nothing to do. + if (cached_filter_data_tensor.NumElements() > 0) { + return; + } + + // Otherwise, cache filter + filter.SetUsrMem(filter_md, &filter_tensor); + filter.CheckReorderToOpMem(conv_fwd_pd.get()->weights_desc(), + this->cpu_engine_); + filter_data = static_cast(filter.GetOpMem().get_data_handle()); + + Tensor* filter_tensor_ptr = nullptr; + AllocatePersistentTensor(context, *conv_fwd_pd, &filter_tensor_ptr, + filter_mkl_shape); + void* cached_filter_data = filter.GetTensorBuffer(filter_tensor_ptr); + size_t cached_filter_data_size = filter.GetOpMem().get_desc().get_size(); + memcpy(cached_filter_data, filter_data, cached_filter_data_size); + } +#else // Cache the converted filter in a persistent tensor. // Only one thread can execute this method at any given time. 
void CacheFilter(OpKernelContext* context, @@ -1018,7 +1354,45 @@ class MklConvOp : public OpKernel { filter.GetOpMem().get_primitive_desc().get_size(); memcpy(cached_filter_data, filter_data, cached_filter_data_size); } +#endif +#ifdef ENABLE_MKLDNN_V1 + bool AreMemoryDescriptorsEqual(const memory::desc& filter_md, + const Tensor& cached_filter_md) { + auto filter_md_data = filter_md.data; + const char* filter_data = reinterpret_cast(&filter_md_data); + + auto cached_filter_md_data = cached_filter_md.scalar()(); + const char* cached_filter_data = + reinterpret_cast(&cached_filter_md_data); + + for (size_t i = 0; i < sizeof(filter_md_data); ++i) { + if (*filter_data++ != *cached_filter_data++) { + return false; + } + } + return true; + } + + Tfilter* GetCachedFilter(OpKernelContext* context, + const memory::desc& filter_md) LOCKS_EXCLUDED(mu_) { + tf_shared_lock lock(mu_); + const Tensor& cached_filter_data = + *cached_filter_data_ptensor_.AccessTensor(context); + const Tensor& cached_filter_md = + *cached_filter_md_ptensor_.AccessTensor(context); + + // Check if the memory descriptor of the cached weights is same as + // filter_mf. If so, we can used the cached weights; otherwise + // return NULL. + if (cached_filter_md.scalar().size() && + AreMemoryDescriptorsEqual(filter_md, cached_filter_md)) { + return static_cast( + const_cast(cached_filter_data.flat().data())); + } + return nullptr; + } +#else Tfilter* GetCachedFilter(OpKernelContext* context, const memory::format& filter_mf) LOCKS_EXCLUDED(mu_) { @@ -1039,6 +1413,7 @@ class MklConvOp : public OpKernel { } return nullptr; } +#endif }; // Base class for fused convolution forward operations @@ -1294,6 +1669,9 @@ class MklQuantizedConv2DOp const float* max_filter = max_filter_vector.flat().data(); std::vector net; +#ifdef ENABLE_MKLDNN_V1 + std::vector> net_args; +#endif if (bias_enabled) { if (std::is_same::value) { return static_cast( @@ -1315,6 +1693,32 @@ class MklQuantizedConv2DOp } else { bias_attr.set_output_scales(1, scales); } +#ifdef ENABLE_MKLDNN_V1 + auto bias_md = + memory::desc({static_cast(bias_tensor.NumElements())}, + MklDnnType(), memory::format_tag::x); + + void* bias_buf = static_cast( + const_cast(bias_tensor.flat().data())); + input_bias_ = new memory(bias_md, this->cpu_engine_, bias_buf); + scaled_bias_ = new memory(conv_fwd_pd->bias_desc(), this->cpu_engine_); + auto reorder_desc = mkldnn::reorder::primitive_desc( + this->cpu_engine_, input_bias_->get_desc(), this->cpu_engine_, + scaled_bias_->get_desc(), bias_attr); + net.push_back(mkldnn::reorder(reorder_desc)); + net_args.push_back({{MKLDNN_ARG_FROM, *input_bias_}, + {MKLDNN_ARG_TO, *scaled_bias_}}); + + CHECK_EQ(net.size(), net_args.size()); + + stream cpu_stream(this->cpu_engine_); + for (size_t i = 0; i < net.size(); ++i) { + net.at(i).execute(cpu_stream, net_args.at(i)); + } + cpu_stream.wait(); + + return reinterpret_cast(scaled_bias_->get_data_handle()); +#else auto bias_pd = memory::primitive_desc({{static_cast(bias_tensor.NumElements())}, MklDnnType(), @@ -1331,6 +1735,7 @@ class MklQuantizedConv2DOp net.push_back(mkldnn::reorder(reorder_desc, *input_bias_, *scaled_bias_)); stream(stream::kind::eager).submit(net).wait(); return reinterpret_cast(scaled_bias_->get_data_handle()); +#endif } else { return nullptr; } @@ -1431,7 +1836,11 @@ class MklQuantizedConv2DSumReluOp void AllocateOutputTensor(OpKernelContext* context, const ConvFwdPd& conv_prim_desc, const memory::dims& output_dims_mkl_order, +#ifdef ENABLE_MKLDNN_V1 + MklTensorFormat 
output_tf_format, +#else memory::format output_tf_format, +#endif Tensor** output_tensor) override { int summand_idx = context->num_inputs() / 2 - 1; if (std::is_same::value) { @@ -1499,6 +1908,36 @@ class MklQuantizedConv2DSumReluOp } else { reorder_attr.set_output_scales(2, scales); } +#ifdef ENABLE_MKLDNN_V1 + auto summand_md = + summand_mkl_shape.IsMklTensor() + ? summand_mkl_shape.GetMklLayout() + : memory::desc(output_dims_mkl_order, MklDnnType(), + memory::format_tag::nhwc); + void* summand_buf = + static_cast(const_cast(summand.flat().data())); + void* dst_buf = + static_cast((*output_tensor)->flat().data()); + summand_ = new memory(summand_md, this->cpu_engine_, summand_buf); + dst_ = new memory(conv_prim_desc.dst_desc(), this->cpu_engine_, dst_buf); + auto reorder_desc = mkldnn::reorder::primitive_desc( + this->cpu_engine_, summand_md, this->cpu_engine_, + conv_prim_desc.dst_desc(), reorder_attr); + + std::vector net; + std::vector> net_args; + + net.push_back(mkldnn::reorder(reorder_desc)); + net_args.push_back({{MKLDNN_ARG_FROM, *summand_}, + {MKLDNN_ARG_TO, *dst_}}); + CHECK_EQ(net.size(), net_args.size()); + + stream cpu_stream(this->cpu_engine_); + for (size_t i = 0; i < net.size(); ++i) { + net.at(i).execute(cpu_stream, net_args.at(i)); + } + cpu_stream.wait(); +#else auto summand_md = summand_mkl_shape.IsMklTensor() ? summand_mkl_shape.GetMklLayout() @@ -1517,6 +1956,7 @@ class MklQuantizedConv2DSumReluOp std::vector net; net.push_back(mkldnn::reorder(reorder_desc, *summand_, *dst_)); stream(stream::kind::eager).submit(net).wait(); +#endif } memory* summand_ = nullptr; diff --git a/tensorflow/core/kernels/mkl_conv_ops.h b/tensorflow/core/kernels/mkl_conv_ops.h index c12a4ff0f0c..2399f5213a3 100644 --- a/tensorflow/core/kernels/mkl_conv_ops.h +++ b/tensorflow/core/kernels/mkl_conv_ops.h @@ -40,7 +40,9 @@ limitations under the License. #include "tensorflow/core/util/padding.h" #include "tensorflow/core/util/tensor_format.h" +#ifndef ENABLE_MKLDNN_V1 using mkldnn::convolution_direct; +#endif using mkldnn::convolution_forward; using mkldnn::prop_kind; using mkldnn::stream; @@ -136,8 +138,13 @@ class MklDnnConvUtil { CHECK_BOUNDS(input_cols_raw, "Input cols too large"); int input_cols = static_cast(input_cols_raw); +#ifdef ENABLE_MKLDNN_V1 + // MKL-DNN always requires input in NCHW format Conv2D. + std::vector mkldnn_sizes(4, -1); +#else // MKL-DNN always requires input in NCHW format Conv2D. std::vector mkldnn_sizes(4, -1); +#endif mkldnn_sizes[MklDnnDims::Dim_N] = input_batch; mkldnn_sizes[MklDnnDims::Dim_C] = input_depth; mkldnn_sizes[MklDnnDims::Dim_H] = input_rows; @@ -160,8 +167,13 @@ class MklDnnConvUtil { CHECK_BOUNDS(input_cols_raw, "Input cols too large"); int input_cols = static_cast(input_cols_raw); +#ifdef ENABLE_MKLDNN_V1 + // MKL-DNN always requires input in NCDHW format for Conv3D. + std::vector mkldnn_sizes(5, -1); +#else // MKL-DNN always requires input in NCDHW format for Conv3D. std::vector mkldnn_sizes(5, -1); +#endif mkldnn_sizes[MklDnnDims3D::Dim3d_N] = input_batch; mkldnn_sizes[MklDnnDims3D::Dim3d_C] = input_depth; mkldnn_sizes[MklDnnDims3D::Dim3d_D] = input_planes; @@ -196,9 +208,8 @@ class MklDnnConvUtil { filter_shape.DebugString())); for (int i = 0; i < ((strides_.size() == 4) ? 
3 : 5); i++) { - OP_REQUIRES(context_, - FastBoundsCheck(filter_shape.dim_size(i), - std::numeric_limits::max()), + OP_REQUIRES(context_, FastBoundsCheck(filter_shape.dim_size(i), + std::numeric_limits::max()), errors::InvalidArgument("filter too large")); } @@ -225,7 +236,11 @@ class MklDnnConvUtil { // GOIHW = (group, out_depth, in_depth, rows, cols) // Specifically for depthwise G=filter_indepth, O=filter_outdepth, I=1 if (is_depthwise) { +#ifdef ENABLE_MKLDNN_V1 + std::vector mkldnn_sizes(5, -1); +#else std::vector mkldnn_sizes(5, -1); +#endif mkldnn_sizes[MKL_GROUP_FILTER_DIM_G] = filter_in_depth; mkldnn_sizes[MKL_GROUP_FILTER_DIM_O] = filter_out_depth; mkldnn_sizes[MKL_GROUP_FILTER_DIM_I] = 1; @@ -234,7 +249,11 @@ class MklDnnConvUtil { *filter_dims = mkldnn_sizes; } else { +#ifdef ENABLE_MKLDNN_V1 + std::vector mkldnn_sizes(4, -1); +#else std::vector mkldnn_sizes(4, -1); +#endif mkldnn_sizes[MklDnnDims::Dim_O] = filter_out_depth; mkldnn_sizes[MklDnnDims::Dim_I] = filter_in_depth; mkldnn_sizes[MklDnnDims::Dim_H] = filter_rows; @@ -260,9 +279,15 @@ class MklDnnConvUtil { int filter_out_depth = static_cast(filter_shape.dim_size(TF_3DFILTER_DIM_O)); +#ifdef ENABLE_MKLDNN_V1 + // MKL-DNN always needs filter in OIDHW format. + // OIDHW = (out_depth, in_depth, planes, rows, cols) + std::vector mkldnn_sizes(5, -1); +#else // MKL-DNN always needs filter in OIDHW format. // OIDHW = (out_depth, in_depth, planes, rows, cols) std::vector mkldnn_sizes(5, -1); +#endif mkldnn_sizes[MklDnnDims3D::Dim3d_O] = filter_out_depth; mkldnn_sizes[MklDnnDims3D::Dim3d_I] = filter_in_depth; mkldnn_sizes[MklDnnDims3D::Dim3d_D] = filter_planes; @@ -451,15 +476,24 @@ class MklDnnConvUtil { *output_dims_tf_order = TFShapeToMklDnnDims(out_shape); if (is_conv2d) { +#ifdef ENABLE_MKLDNN_V1 + // For Conv2D, MKL-DNN always needs output in NCHW format. + std::vector mkldnn_sizes(4, -1); +#else // For Conv2D, MKL-DNN always needs output in NCHW format. std::vector mkldnn_sizes(4, -1); +#endif mkldnn_sizes[MklDnnDims::Dim_N] = out_batch; mkldnn_sizes[MklDnnDims::Dim_C] = out_depth; mkldnn_sizes[MklDnnDims::Dim_H] = static_cast(out_rows); mkldnn_sizes[MklDnnDims::Dim_W] = static_cast(out_cols); *output_dims_mkl_order = mkldnn_sizes; } else { +#ifdef ENABLE_MKLDNN_V1 + std::vector mkldnn_sizes(5, -1); +#else std::vector mkldnn_sizes(5, -1); +#endif mkldnn_sizes[MklDnnDims3D::Dim3d_N] = out_batch; mkldnn_sizes[MklDnnDims3D::Dim3d_C] = out_depth; mkldnn_sizes[MklDnnDims3D::Dim3d_D] = static_cast(out_planes); From 3608a971bb3413e55494497e6b30a3e1b46aec5b Mon Sep 17 00:00:00 2001 From: Bhavani Subramanian Date: Wed, 10 Jul 2019 10:25:27 -0700 Subject: [PATCH 2/6] Changed CHECK to DCHECK. 
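CHECK_* macros abort the process on failure in every build mode, while the DCHECK_* variants are compiled out of optimized builds, so they are the better fit for internal invariants such as "every queued primitive has a matching argument map". A minimal sketch of the execution loop these checks guard, assuming the MKL-DNN v1.x API (ENABLE_MKLDNN_V1); ExecuteNet and its parameters are illustrative names, not part of this patch:

    #include <unordered_map>
    #include <vector>

    #include "mkldnn.hpp"
    #include "tensorflow/core/platform/logging.h"

    // Runs every primitive with its matching argument map. The size equality is
    // an internal invariant, so a debug-only DCHECK_EQ is sufficient and costs
    // nothing in release builds.
    void ExecuteNet(
        const std::vector<mkldnn::primitive>& net,
        const std::vector<std::unordered_map<int, mkldnn::memory>>& net_args,
        mkldnn::stream* cpu_stream) {
      DCHECK_EQ(net.size(), net_args.size());
      for (size_t i = 0; i < net.size(); ++i) {
        net.at(i).execute(*cpu_stream, net_args.at(i));
      }
      cpu_stream->wait();
    }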
--- tensorflow/core/kernels/mkl_conv_ops.cc | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc index 39cc4da3ce0..b9ef04413c9 100644 --- a/tensorflow/core/kernels/mkl_conv_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_ops.cc @@ -133,7 +133,7 @@ class MklConvFwdPrimitive : public MklPrimitive { context_.dst_mem->set_data_handle( static_cast(const_cast(dst_data))); #ifdef ENABLE_MKLDNN_V1 - CHECK_EQ(context_.fwd_primitives.size(), + DCHECK_EQ(context_.fwd_primitives.size(), context_.fwd_primitives_args.size()); for (size_t i = 0; i < context_.fwd_primitives.size(); ++i) { context_.fwd_primitives.at(i).execute(*context_.fwd_stream, @@ -165,7 +165,7 @@ class MklConvFwdPrimitive : public MklPrimitive { context_.dst_mem->set_data_handle( static_cast(const_cast(dst_data))); #ifdef ENABLE_MKLDNN_V1 - CHECK_EQ(context_.fwd_primitives.size(), + DCHECK_EQ(context_.fwd_primitives.size(), context_.fwd_primitives_args.size()); for (size_t i = 0; i < context_.fwd_primitives.size(); ++i) { context_.fwd_primitives.at(i).execute(*context_.fwd_stream, @@ -680,7 +680,7 @@ class MklConvOp : public OpKernel { #ifdef ENABLE_MKLDNN_V1 auto mkl_fmt_tag = MklTensorFormatToMklDnnDataFormat(tf_fmt); // NOTE: `mkl_fmt_tag` will be `format_tag::undef` for ReLU - CHECK_NE(mkl_fmt_tag, memory::format_tag::undef); + DCHECK_NE(mkl_fmt_tag, memory::format_tag::undef); #endif // If input is in MKL layout, then simply grab the layout; otherwise, @@ -1191,7 +1191,7 @@ class MklConvOp : public OpKernel { } stream cpu_stream(cpu_engine_); - CHECK_EQ(net.size(), net_args.size()); + DCHECK_EQ(net.size(), net_args.size()); for (size_t i = 0; i < net.size(); ++i) { net.at(i).execute(cpu_stream, net_args.at(i)); } @@ -1709,7 +1709,7 @@ class MklQuantizedConv2DOp net_args.push_back({{MKLDNN_ARG_FROM, *input_bias_}, {MKLDNN_ARG_TO, *scaled_bias_}}); - CHECK_EQ(net.size(), net_args.size()); + DCHECK_EQ(net.size(), net_args.size()); stream cpu_stream(this->cpu_engine_); for (size_t i = 0; i < net.size(); ++i) { @@ -1930,7 +1930,7 @@ class MklQuantizedConv2DSumReluOp net.push_back(mkldnn::reorder(reorder_desc)); net_args.push_back({{MKLDNN_ARG_FROM, *summand_}, {MKLDNN_ARG_TO, *dst_}}); - CHECK_EQ(net.size(), net_args.size()); + DCHECK_EQ(net.size(), net_args.size()); stream cpu_stream(this->cpu_engine_); for (size_t i = 0; i < net.size(); ++i) { From fdf9ee647ef267f847d11173d3c391e57762a9c9 Mon Sep 17 00:00:00 2001 From: Bhavani Subramanian Date: Wed, 10 Jul 2019 11:05:12 -0700 Subject: [PATCH 3/6] Ran Clang format checks. 
--- tensorflow/core/kernels/mkl_conv_ops.cc | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc index b9ef04413c9..d7a457e3729 100644 --- a/tensorflow/core/kernels/mkl_conv_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_ops.cc @@ -134,7 +134,7 @@ class MklConvFwdPrimitive : public MklPrimitive { static_cast(const_cast(dst_data))); #ifdef ENABLE_MKLDNN_V1 DCHECK_EQ(context_.fwd_primitives.size(), - context_.fwd_primitives_args.size()); + context_.fwd_primitives_args.size()); for (size_t i = 0; i < context_.fwd_primitives.size(); ++i) { context_.fwd_primitives.at(i).execute(*context_.fwd_stream, context_.fwd_primitives_args.at(i)); @@ -166,7 +166,7 @@ class MklConvFwdPrimitive : public MklPrimitive { static_cast(const_cast(dst_data))); #ifdef ENABLE_MKLDNN_V1 DCHECK_EQ(context_.fwd_primitives.size(), - context_.fwd_primitives_args.size()); + context_.fwd_primitives_args.size()); for (size_t i = 0; i < context_.fwd_primitives.size(); ++i) { context_.fwd_primitives.at(i).execute(*context_.fwd_stream, context_.fwd_primitives_args.at(i)); @@ -387,13 +387,15 @@ class MklConvFwdPrimitive : public MklPrimitive { {{MKLDNN_ARG_SRC, *context_.src_mem}, {MKLDNN_ARG_WEIGHTS, *context_.filter_mem}, {MKLDNN_ARG_BIAS, *context_.bias_mem}, - {MKLDNN_ARG_DST, *context_.dst_mem}}); + { MKLDNN_ARG_DST, + *context_.dst_mem }}); } else { context_.conv_fwd.reset(new convolution_forward(*context_.fwd_pd)); context_.fwd_primitives_args.push_back( {{MKLDNN_ARG_SRC, *context_.src_mem}, {MKLDNN_ARG_WEIGHTS, *context_.filter_mem}, - {MKLDNN_ARG_DST, *context_.dst_mem}}); + { MKLDNN_ARG_DST, + *context_.dst_mem }}); } context_.fwd_primitives.push_back(*context_.conv_fwd); return; @@ -804,7 +806,7 @@ class MklConvOp : public OpKernel { cpu_engine_); } else { filter.CheckReorderToOpMem( - conv_fwd_pd->weights_desc(), + conv_fwd_pd->weights_desc(), filter.GetTensorBuffer(filter_out_tensor), cpu_engine_); } filter_data = @@ -1181,13 +1183,15 @@ class MklConvOp : public OpKernel { net_args.push_back({{MKLDNN_ARG_SRC, src->GetOpMem()}, {MKLDNN_ARG_WEIGHTS, filter->GetOpMem()}, {MKLDNN_ARG_BIAS, bias->GetOpMem()}, - {MKLDNN_ARG_DST, output->GetOpMem()}}); + { MKLDNN_ARG_DST, + output->GetOpMem() }}); } else { DCHECK(!fuse_biasadd_); net.push_back(convolution_forward(conv_prim_desc)); net_args.push_back({{MKLDNN_ARG_SRC, src->GetOpMem()}, {MKLDNN_ARG_WEIGHTS, filter->GetOpMem()}, - {MKLDNN_ARG_DST, output->GetOpMem()}}); + { MKLDNN_ARG_DST, + output->GetOpMem() }}); } stream cpu_stream(cpu_engine_); @@ -1707,7 +1711,8 @@ class MklQuantizedConv2DOp scaled_bias_->get_desc(), bias_attr); net.push_back(mkldnn::reorder(reorder_desc)); net_args.push_back({{MKLDNN_ARG_FROM, *input_bias_}, - {MKLDNN_ARG_TO, *scaled_bias_}}); + { MKLDNN_ARG_TO, + *scaled_bias_ }}); DCHECK_EQ(net.size(), net_args.size()); @@ -1929,7 +1934,8 @@ class MklQuantizedConv2DSumReluOp net.push_back(mkldnn::reorder(reorder_desc)); net_args.push_back({{MKLDNN_ARG_FROM, *summand_}, - {MKLDNN_ARG_TO, *dst_}}); + { MKLDNN_ARG_TO, + *dst_ }}); DCHECK_EQ(net.size(), net_args.size()); stream cpu_stream(this->cpu_engine_); From 7c0bcbe0d19ead1b699c9fdd9e9746d72668e653 Mon Sep 17 00:00:00 2001 From: Bhavani Subramanian Date: Mon, 29 Jul 2019 13:12:45 -0700 Subject: [PATCH 4/6] Addressed review comments. 
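The duplicated #ifdef ENABLE_MKLDNN_V1 blocks are collapsed behind a set of version-bridging macros (ENGINE_CPU, MEMORY_FORMAT, PRIMITIVE_DESC_*, REORDER_PD_CONSTRUCTOR, ...) defined at the top of mkl_conv_ops.cc, so most call sites are written once and compile against either library version. A minimal sketch of the pattern using two of those macros; MakeCpuEngine and AnySrcDesc are illustrative helpers, not functions from this patch:

    #include "mkldnn.hpp"
    using mkldnn::engine;
    using mkldnn::memory;

    #ifdef ENABLE_MKLDNN_V1
    #define ENGINE_CPU engine::kind::cpu
    #define MEMORY_FORMAT mkldnn::memory::format_tag
    #else
    #define ENGINE_CPU engine::cpu
    #define MEMORY_FORMAT mkldnn::memory::format
    #endif  // ENABLE_MKLDNN_V1

    // A single call site now builds against MKL-DNN v0.x and v1.x alike.
    engine MakeCpuEngine() { return engine(ENGINE_CPU, 0); }

    memory::desc AnySrcDesc(const memory::dims& src_dims) {
      // format/format_tag `any` lets the convolution primitive pick the layout.
      return memory::desc(src_dims, memory::data_type::f32, MEMORY_FORMAT::any);
    }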
--- tensorflow/core/graph/mkl_layout_pass.cc | 6 +- tensorflow/core/kernels/mkl_conv_ops.cc | 855 +++++++++-------------- tensorflow/core/util/mkl_util.h | 27 +- 3 files changed, 371 insertions(+), 517 deletions(-) diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc index 2a5b50398e8..8a1975ec2ca 100644 --- a/tensorflow/core/graph/mkl_layout_pass.cc +++ b/tensorflow/core/graph/mkl_layout_pass.cc @@ -391,7 +391,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { {csinfo_.conjugate_transpose, mkl_op_registry::GetMklOpName(csinfo_.conjugate_transpose), CopyAttrsTranspose, AlwaysRewrite, kRewriteForOpNameChange}); -#endif // ENABLE_MKLDNN_V1 +#endif // !ENABLE_MKLDNN_V1 rinfo_.push_back({csinfo_.conv2d, mkl_op_registry::GetMklOpName(csinfo_.conv2d), CopyAttrsConvCheckConstFilter, AlwaysRewrite, @@ -651,7 +651,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { rinfo_.push_back( {csinfo_.requantize, mkl_op_registry::GetMklOpName(csinfo_.requantize), CopyAttrsRequantize, AlwaysRewrite, kRewriteForLayoutPropagation}); -#endif // ENABLE_MKLDNN_V1 +#endif // !ENABLE_MKLDNN_V1 // Disable these two MKL operators for now due to some test failures caused // by these two ops /* @@ -765,7 +765,7 @@ rinfo_.push_back({csinfo_.tanh_grad, // CheckForMklOp FuseConv3D, CopyAttrsConv}); -#endif // ENABLE_MKLDNN_V1 +#endif // !ENABLE_MKLDNN_V1 } // Standard interface to run pass diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc index 0e68a46e45b..b5b7d6bf4d7 100644 --- a/tensorflow/core/kernels/mkl_conv_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_ops.cc @@ -47,15 +47,97 @@ limitations under the License. #include "tensorflow/core/util/padding.h" #include "tensorflow/core/util/tensor_format.h" +using mkldnn::convolution_forward; +using mkldnn::memory; using mkldnn::prop_kind; using mkldnn::stream; -using mkldnn::convolution_forward; -#ifndef ENABLE_MKLDNN_V1 -using mkldnn::convolution_direct; -#endif namespace tensorflow { +#ifdef ENABLE_MKLDNN_V1 +#define ADD_MD add_md +#define ALGORITHM mkldnn::algorithm +#define ALGORITHM_UNDEF ALGORITHM::undef +#define CPU_STREAM(engine) stream(engine) +#define DATA_WITH_ENGINE(data, engine) data, engine +#define DST_MD dst_md +#define ENGINE_CPU engine::kind::cpu +#define GET_DESC get_desc() +#define GET_MEMORY_DESC_CONSTRUCTOR(dims, type, fm) \ + { {dims}, MklDnnType(), memory::format_tag::fm } +#define GET_SRC_DESC_FROM_OP_PD(op_pd) op_pd->src_desc() +#define GET_WEIGHTS_DESC_FROM_OP_PD(op_pd) op_pd->weights_desc() +#define GET_WEIGHTS_FORMAT_FROM_OP_PD(op_pd, op_primitive) \ + GET_WEIGHTS_DESC_FROM_OP_PD(op_pd) +#define IS_FILTER_REORDER_NEEDED(filter_md, op_pd, op_primitive) \ + filter_md != op_pd->weights_desc() +#define IS_SRC_REORDER_NEEDED(src_md, op_pd, op_primitive) \ + src_md != op_pd->src_desc() +#define MEMORY_CONSTRUCTOR(mem_desc, engine, data) \ + memory(mem_desc, engine, data) +#define MEMORY_CONSTRUCTOR_USING_MEM_PD(dims, type, fm, engine, data) \ + memory(GET_MEMORY_DESC_CONSTRUCTOR(dims, type, fm), engine, data) +#define MEMORY_CONSTRUCTOR_WITHOUT_DATA(mem_desc, engine) \ + memory(mem_desc, engine) +#define MEMORY_DESC memory::desc +#define MEMORY_FORMAT mkldnn::memory::format_tag +#define MEMORY_PD_CONSTRUCTOR(dims, type, fm, engine) \ + memory::desc({dims}, MklDnnType(), memory::format_tag::fm) +#define MEMORY_PD_WITHOUT_DATA(md, engine) md, engine +#define MKL_TENSOR_FORMAT MklTensorFormat +#define MKL_TENSOR_FORMAT_BLOCKED 
MklTensorFormat::FORMAT_BLOCKED +#define MKL_TENSOR_FORMAT_IN_C MKL_TENSOR_FORMAT +#define PRIMITIVE_DESC_BIAS bias_desc() +#define PRIMITIVE_DESC_DST dst_desc() +#define PRIMITIVE_DESC_SRC src_desc() +#define PRIMITIVE_DESC_WEIGHTS weights_desc() +#define REORDER_PD_CONSTRUCTOR(src_md, dst_md, engine) \ + mkldnn::reorder::primitive_desc(engine, src_md, engine, dst_md) +#define REORDER_PD_CONSTRUCTOR_WITH_ATTR(src_md, dst_md, engine, prim_attr) \ + mkldnn::reorder::primitive_desc(engine, src_md, engine, dst_md, prim_attr) +#define SUMMAND_MD summand_md +#else +#define ADD_MD add_pd +#define ALGORITHM mkldnn +#define ALGORITHM_UNDEF ALGORITHM::algorithm_undef +#define CPU_STREAM(engine) stream(stream::kind::eager) +#define DATA_WITH_ENGINE(data, engine) data +#define DST_MD dst_pd +#define ENGINE_CPU engine::cpu +#define GET_DESC get_primitive_desc() +#define GET_MEMORY_DESC_CONSTRUCTOR(dims, type, fm) \ + { {dims}, MklDnnType(), memory::format::fm } +#define GET_SRC_DESC_FROM_OP_PD(op_pd) op_pd.get()->src_primitive_desc() +#define GET_WEIGHTS_DESC_FROM_OP_PD(op_pd) op_pd.get()->weights_primitive_desc() +#define GET_WEIGHTS_FORMAT_FROM_OP_PD(op_pd, op_primitive) \ + op_primitive->GetFilterMemoryFormat() +#define IS_FILTER_REORDER_NEEDED(filter_md, op_pd, op_primitive) \ + filter_md.data.format != op_primitive->GetFilterMemoryFormat() +#define IS_SRC_REORDER_NEEDED(src_md, op_pd, op_primitive) \ + src_md.data.format != op_primitive->GetSrcMemoryFormat() +#define MEMORY_CONSTRUCTOR(mem_pd, engine, data) memory(mem_pd, data) +#define MEMORY_CONSTRUCTOR_USING_MEM_PD(dims, type, fm, engine, data) \ + memory({GET_MEMORY_DESC_CONSTRUCTOR(dims, type, fm), engine}, data) +#define MEMORY_CONSTRUCTOR_WITHOUT_DATA(mem_pd, engine) memory(mem_pd) +#define MEMORY_DESC memory::format +#define MEMORY_FORMAT mkldnn::memory::format +#define MEMORY_PD_CONSTRUCTOR(dims, type, fm, engine) \ + memory::primitive_desc(GET_MEMORY_DESC_CONSTRUCTOR(dims, type, fm), engine) +#define MEMORY_PD_WITHOUT_DATA(pd, engine) pd +#define MKL_TENSOR_FORMAT memory::format +#define MKL_TENSOR_FORMAT_BLOCKED memory::format::blocked +#define MKL_TENSOR_FORMAT_IN_C mkldnn_memory_format_t +#define PRIMITIVE_DESC_BIAS bias_primitive_desc() +#define PRIMITIVE_DESC_DST dst_primitive_desc() +#define PRIMITIVE_DESC_SRC src_primitive_desc() +#define PRIMITIVE_DESC_WEIGHTS weights_primitive_desc() +#define REORDER_PD_CONSTRUCTOR(src_pd, dst_pd, engine) \ + mkldnn::reorder::primitive_desc(src_pd, dst_pd) +#define REORDER_PD_CONSTRUCTOR_WITH_ATTR(src_pd, dst_pd, engine, prim_attr) \ + mkldnn::reorder::primitive_desc(src_pd, dst_pd, prim_attr) +#define SUMMAND_MD summand_pd +#endif // ENABLE_MKLDNN_V1 + // This structure aggregates multiple inputs to Conv2DFwd* methods. 
struct MklConvFwdParams { memory::dims src_dims; @@ -95,25 +177,14 @@ typedef mkldnn::convolution_forward::primitive_desc ConvFwdPd; template class MklConvFwdPrimitive : public MklPrimitive { public: -#ifdef ENABLE_MKLDNN_V1 explicit MklConvFwdPrimitive(const MklConvFwdParams& convFwdDims) - : cpu_engine_(engine::kind::cpu, 0) { - context_.fwd_stream.reset(new stream(cpu_engine_)); - // Create conv primitive + : cpu_engine_(ENGINE_CPU, 0) { + context_.fwd_stream.reset(new CPU_STREAM(cpu_engine_)); + // Create convolution primitive if (context_.conv_fwd == nullptr) { Setup(convFwdDims); } } -#else - explicit MklConvFwdPrimitive(const MklConvFwdParams& convFwdDims) - : cpu_engine_(engine::cpu, 0) { - context_.fwd_stream.reset(new stream(stream::kind::eager)); - // Create conv primitive - if (context_.conv_fwd == nullptr) { - Setup(convFwdDims); - } - } -#endif ~MklConvFwdPrimitive() {} @@ -128,8 +199,10 @@ class MklConvFwdPrimitive : public MklPrimitive { static_cast(const_cast(src_data))); context_.filter_mem->set_data_handle( static_cast(const_cast(filter_data))); - context_.bias_mem->set_data_handle( - static_cast(const_cast(bias_data))); + if (bias_data != nullptr) { + context_.bias_mem->set_data_handle( + static_cast(const_cast(bias_data))); + } context_.dst_mem->set_data_handle( static_cast(const_cast(dst_data))); #ifdef ENABLE_MKLDNN_V1 @@ -141,15 +214,15 @@ class MklConvFwdPrimitive : public MklPrimitive { } #else context_.fwd_stream->submit(context_.fwd_primitives); -#endif +#endif // ENABLE_MKLDNN_V1 - // After exec, set data handle back + // After execution, set data handle back context_.src_mem->set_data_handle(DummyData); context_.filter_mem->set_data_handle(DummyData); - context_.bias_mem->set_data_handle(DummyData); + if (bias_data != nullptr) { + context_.bias_mem->set_data_handle(DummyData); + } context_.dst_mem->set_data_handle(DummyData); - - return; } // Convolution forward execute without bias @@ -158,36 +231,15 @@ class MklConvFwdPrimitive : public MklPrimitive { // dst_data: output data buffer of dst void Execute(const Tinput* src_data, const Tfilter* filter_data, const Toutput* dst_data) { - context_.src_mem->set_data_handle( - static_cast(const_cast(src_data))); - context_.filter_mem->set_data_handle( - static_cast(const_cast(filter_data))); - context_.dst_mem->set_data_handle( - static_cast(const_cast(dst_data))); -#ifdef ENABLE_MKLDNN_V1 - DCHECK_EQ(context_.fwd_primitives.size(), - context_.fwd_primitives_args.size()); - for (size_t i = 0; i < context_.fwd_primitives.size(); ++i) { - context_.fwd_primitives.at(i).execute(*context_.fwd_stream, - context_.fwd_primitives_args.at(i)); - } -#else - context_.fwd_stream->submit(context_.fwd_primitives); -#endif - - // After execution, set data handle back - context_.src_mem->set_data_handle(DummyData); - context_.filter_mem->set_data_handle(DummyData); - context_.dst_mem->set_data_handle(DummyData); + Execute(src_data, filter_data, nullptr, dst_data); } #ifndef ENABLE_MKLDNN_V1 - // In MKL-DNN v1.0, memory format tags only provide a partial description - // of the memory layout. Hence, these functions are disabled for v1.0. + // In MKL-DNN v1.x, memory format tags only provide a partial description + // of the memory layout. Hence, these functions are disabled for v1.x. 
memory::format GetSrcMemoryFormat() const { return context_.src_fmt; } - memory::format GetFilterMemoryFormat() const { return context_.filter_fmt; } -#endif +#endif // !ENABLE_MKLDNN_V1 std::shared_ptr GetPrimitiveDesc() const { return context_.fwd_pd; @@ -200,15 +252,15 @@ class MklConvFwdPrimitive : public MklPrimitive { // Expected memory format for this primitive instance memory::format src_fmt; memory::format filter_fmt; -#endif +#endif // !ENABLE_MKLDNN_V1 - // MKLDNN memory + // MKL-DNN memory std::shared_ptr src_mem; std::shared_ptr filter_mem; std::shared_ptr bias_mem; std::shared_ptr dst_mem; - // Desc & prmitive desc + // Desc & primitive desc std::shared_ptr fwd_desc; // Memory desc @@ -226,14 +278,14 @@ class MklConvFwdPrimitive : public MklPrimitive { #ifdef ENABLE_MKLDNN_V1 std::vector> fwd_primitives_args; -#endif +#endif // ENABLE_MKLDNN_V1 ConvFwdContext() : #ifndef ENABLE_MKLDNN_V1 src_fmt(memory::format::any), filter_fmt(memory::format::any), -#endif +#endif // !ENABLE_MKLDNN_V1 src_mem(nullptr), filter_mem(nullptr), bias_mem(nullptr), @@ -251,57 +303,28 @@ class MklConvFwdPrimitive : public MklPrimitive { void Setup(const MklConvFwdParams& convFwdDims) { // Create memory descriptors for convolution data w/ no specified format context_.src_md.reset(new memory::desc( -#ifdef ENABLE_MKLDNN_V1 - {convFwdDims.src_dims}, MklDnnType(), memory::format_tag::any)); -#else - {convFwdDims.src_dims}, MklDnnType(), memory::format::any)); -#endif + {convFwdDims.src_dims}, MklDnnType(), MEMORY_FORMAT::any)); context_.filter_md.reset(new memory::desc( -#ifdef ENABLE_MKLDNN_V1 - {convFwdDims.filter_dims}, MklDnnType(), - memory::format_tag::any)); -#else - {convFwdDims.filter_dims}, MklDnnType(), memory::format::any)); -#endif + {convFwdDims.filter_dims}, MklDnnType(), MEMORY_FORMAT::any)); context_.dst_md.reset(new memory::desc( -#ifdef ENABLE_MKLDNN_V1 - {convFwdDims.dst_dims}, MklDnnType(), - memory::format_tag::any)); -#else - {convFwdDims.dst_dims}, MklDnnType(), memory::format::any)); -#endif + {convFwdDims.dst_dims}, MklDnnType(), MEMORY_FORMAT::any)); if (!convFwdDims.bias_dims.empty()) context_.bias_md.reset(new memory::desc( -#ifdef ENABLE_MKLDNN_V1 - {convFwdDims.bias_dims}, MklDnnType(), - memory::format_tag::any)); -#else - {convFwdDims.bias_dims}, MklDnnType(), memory::format::any)); -#endif + {convFwdDims.bias_dims}, MklDnnType(), MEMORY_FORMAT::any)); - // Create a convolution + // Create a convolution descriptor if (!convFwdDims.bias_dims.empty()) { context_.fwd_desc.reset(new convolution_forward::desc( -#ifdef ENABLE_MKLDNN_V1 - prop_kind::forward, mkldnn::algorithm::convolution_direct, - *context_.src_md, -#else - prop_kind::forward, convolution_direct, *context_.src_md, -#endif + prop_kind::forward, ALGORITHM::convolution_direct, *context_.src_md, *context_.filter_md, *context_.bias_md, *context_.dst_md, convFwdDims.strides, convFwdDims.dilations, convFwdDims.padding_left, convFwdDims.padding_right, padding_kind::zero)); } else { context_.fwd_desc.reset(new convolution_forward::desc( -#ifdef ENABLE_MKLDNN_V1 - prop_kind::forward, mkldnn::algorithm::convolution_direct, - *context_.src_md, -#else - prop_kind::forward, convolution_direct, *context_.src_md, -#endif + prop_kind::forward, ALGORITHM::convolution_direct, *context_.src_md, *context_.filter_md, *context_.dst_md, convFwdDims.strides, convFwdDims.dilations, convFwdDims.padding_left, convFwdDims.padding_right, padding_kind::zero)); @@ -325,7 +348,7 @@ class MklConvFwdPrimitive : public MklPrimitive { op_alpha, 
#else post_ops.append_eltwise(op_scale, post_op_param.alg, op_alpha, -#endif +#endif // ENABLE_MKLDNN_V1 op_beta); } else if (post_op_param.name == "sum") { DCHECK_EQ(post_op_param.param.size(), 1); @@ -357,31 +380,21 @@ class MklConvFwdPrimitive : public MklPrimitive { context_.filter_fmt = static_cast( context_.fwd_pd.get()->weights_primitive_desc().desc().data.format); -#endif +#endif // !ENABLE_MKLDNN_V1 -#ifdef ENABLE_MKLDNN_V1 // Create memory primitive based on dummy data - context_.src_mem.reset( - new memory(context_.fwd_pd.get()->src_desc(), cpu_engine_, DummyData)); - context_.filter_mem.reset(new memory(context_.fwd_pd.get()->weights_desc(), - cpu_engine_, DummyData)); - context_.dst_mem.reset( - new memory(context_.fwd_pd.get()->dst_desc(), cpu_engine_, DummyData)); -#else - context_.src_mem.reset( - new memory(context_.fwd_pd.get()->src_primitive_desc(), DummyData)); - context_.filter_mem.reset( - new memory(context_.fwd_pd.get()->weights_primitive_desc(), DummyData)); - context_.dst_mem.reset( - new memory(context_.fwd_pd.get()->dst_primitive_desc(), DummyData)); -#endif + context_.src_mem.reset(new MEMORY_CONSTRUCTOR( + context_.fwd_pd.get()->PRIMITIVE_DESC_SRC, cpu_engine_, DummyData)); + context_.filter_mem.reset(new MEMORY_CONSTRUCTOR( + context_.fwd_pd.get()->PRIMITIVE_DESC_WEIGHTS, cpu_engine_, DummyData)); + context_.dst_mem.reset(new MEMORY_CONSTRUCTOR( + context_.fwd_pd.get()->PRIMITIVE_DESC_DST, cpu_engine_, DummyData)); -#ifdef ENABLE_MKLDNN_V1 // Create convolution primitive and add it to net if (!convFwdDims.bias_dims.empty()) { - context_.bias_mem.reset(new memory( - {{convFwdDims.bias_dims}, MklDnnType(), memory::format_tag::x}, - cpu_engine_, DummyData)); + context_.bias_mem.reset(new MEMORY_CONSTRUCTOR_USING_MEM_PD( + convFwdDims.bias_dims, Tbias, x, cpu_engine_, DummyData)); +#ifdef ENABLE_MKLDNN_V1 context_.conv_fwd.reset(new convolution_forward(*context_.fwd_pd)); context_.fwd_primitives_args.push_back( {{MKLDNN_ARG_SRC, *context_.src_mem}, @@ -397,15 +410,7 @@ class MklConvFwdPrimitive : public MklPrimitive { { MKLDNN_ARG_DST, *context_.dst_mem }}); } - context_.fwd_primitives.push_back(*context_.conv_fwd); - return; #else - // Create convolution primitive and add it to net - if (!convFwdDims.bias_dims.empty()) { - context_.bias_mem.reset(new memory( - {{{convFwdDims.bias_dims}, MklDnnType(), memory::format::x}, - cpu_engine_}, - DummyData)); context_.conv_fwd.reset(new convolution_forward( *context_.fwd_pd, *context_.src_mem, *context_.filter_mem, *context_.bias_mem, *context_.dst_mem)); @@ -414,10 +419,8 @@ class MklConvFwdPrimitive : public MklPrimitive { new convolution_forward(*context_.fwd_pd, *context_.src_mem, *context_.filter_mem, *context_.dst_mem)); } - +#endif // ENABLE_MKLDNN_V1 context_.fwd_primitives.push_back(*context_.conv_fwd); - return; -#endif } struct ConvFwdContext context_; @@ -683,7 +686,7 @@ class MklConvOp : public OpKernel { auto mkl_fmt_tag = MklTensorFormatToMklDnnDataFormat(tf_fmt); // NOTE: `mkl_fmt_tag` will be `format_tag::undef` for ReLU DCHECK_NE(mkl_fmt_tag, memory::format_tag::undef); -#endif +#endif // ENABLE_MKLDNN_V1 // If input is in MKL layout, then simply grab the layout; otherwise, // construct TF layout for input. 
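// Illustrative sketch, assuming the MKL-DNN v1.x API (ENABLE_MKLDNN_V1): how the
// "relu" and "sum" entries collected in MklConvFwdParams::post_op_params become
// fused-op attributes on the convolution primitive descriptor. The scale, alpha
// and beta constants below are placeholders, not values taken from this patch.
#include "mkldnn.hpp"

mkldnn::primitive_attr MakeFusedConvAttr() {
  mkldnn::post_ops post_ops;
  // Fused ReLU: eltwise post-op (scale = 1, alpha = negative slope = 0, beta = 0).
  post_ops.append_eltwise(1.0f, mkldnn::algorithm::eltwise_relu, 0.0f, 0.0f);
  // Fused residual add: accumulate the convolution result into the existing
  // contents of the destination buffer, scaled by 1.
  post_ops.append_sum(1.0f);
  mkldnn::primitive_attr post_ops_attr;
  post_ops_attr.set_post_ops(post_ops);
  // The attribute is then passed when building the primitive descriptor, e.g.
  // ConvFwdPd(*fwd_desc, post_ops_attr, cpu_engine_).
  return post_ops_attr;
}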
@@ -699,21 +702,15 @@ class MklConvOp : public OpKernel { : memory::desc(src_dims, MklDnnType(), mkl_fmt_tag); #else : memory::desc(src_dims, MklDnnType(), tf_fmt); -#endif +#endif // ENABLE_MKLDNN_V1 src.SetUsrMem(src_md, &src_tensor); -#ifdef ENABLE_MKLDNN_V1 // Although filter shape (filter_dims) required is in MKL-DNN order, // the layout is Tensorflow's layout (HWIO) and (HWIGO) for // depthwise/group convolutions. - auto filter_format = is_conv2d ? (is_depthwise ? memory::format_tag::hwigo - : memory::format_tag::hwio) - : memory::format_tag::dhwio; -#else - auto filter_format = is_conv2d ? (is_depthwise ? memory::format::hwigo - : memory::format::hwio) - : memory::format::dhwio; -#endif + auto filter_format = is_conv2d ? (is_depthwise ? MEMORY_FORMAT::hwigo + : MEMORY_FORMAT::hwio) + : MEMORY_FORMAT::dhwio; DCHECK(!filter_mkl_shape.IsMklTensor()); auto filter_md = @@ -722,7 +719,7 @@ class MklConvOp : public OpKernel { : memory::desc(filter_dims, MklDnnType(), filter_format); filter.SetUsrMem(filter_md, &filter_tensor); - // MKLDNN dilations start from 0. + // MKL-DNN dilations start from 0. for (int i = 0; i < dilations.size(); ++i) --dilations[i]; // In some cases, primitive descriptor could potentially contain @@ -772,87 +769,48 @@ class MklConvOp : public OpKernel { // Check whether src and filter need to be reordered Tinput* src_data = nullptr; + if (IS_SRC_REORDER_NEEDED(src_md, conv_fwd_pd, conv_fwd)) { + // Reorder src + src.SetUsrMem(src_md, &src_tensor); + src.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA( + GET_SRC_DESC_FROM_OP_PD(conv_fwd_pd), cpu_engine_)); + src_data = static_cast(src.GetOpMem().get_data_handle()); + } else { + src_data = static_cast( + const_cast(src_tensor.flat().data())); + } + + Tfilter* filter_data = nullptr; + if (IS_FILTER_REORDER_NEEDED(filter_md, conv_fwd_pd, conv_fwd)) { + bool is_filter_cached = false; + // If filter is a constant, we can avoid the conversion of filter from + // Tensorflow format to MKL format by caching the filter when it is + // converted for the first time. This cached filter can then be reused + // in subsequent iterations. + if (is_filter_const_) { + if (IsFilterCacheEmpty(context)) { + // Cache filter if it is not already cached. + CacheFilter(context, conv_fwd_pd, filter_data, filter_tensor, #ifdef ENABLE_MKLDNN_V1 - if (src_md != conv_fwd_pd->src_desc()) { - // Reorder src - src.SetUsrMem(src_md, &src_tensor); - src.CheckReorderToOpMem(conv_fwd_pd->src_desc(), cpu_engine_); - src_data = static_cast(src.GetOpMem().get_data_handle()); - } else { - src_data = static_cast( - const_cast(src_tensor.flat().data())); - } - - Tfilter* filter_data = nullptr; - if (filter_md != conv_fwd_pd->weights_desc()) { - bool is_filter_cached = false; - // If filter is a constant, we can avoid the conversion of filter from - // Tensorflow format to MKL format by caching the filter when it is - // converted for the first time. This cached filter can then be reused - // in subsequent iterations. - if (is_filter_const_) { - if (IsFilterCacheEmpty(context)) { - // Cache filter if it is not already cached. 
- CacheFilter(context, conv_fwd_pd, filter_data, filter_tensor, filter, filter_md, filter_mkl_shape); - } - filter_data = GetCachedFilter(context, conv_fwd_pd->weights_desc()); - is_filter_cached = (filter_data != nullptr); - } - if (!is_filter_cached) { - filter.SetUsrMem(filter_md, &filter_tensor); - if (filter_out_tensor == nullptr) { - filter.CheckReorderToOpMem(conv_fwd_pd->weights_desc(), - cpu_engine_); - } else { - filter.CheckReorderToOpMem( - conv_fwd_pd->weights_desc(), - filter.GetTensorBuffer(filter_out_tensor), cpu_engine_); - } - filter_data = - static_cast(filter.GetOpMem().get_data_handle()); - } - } else { - filter_data = static_cast( - const_cast(filter_tensor.flat().data())); - } #else - if (src_md.data.format != conv_fwd->GetSrcMemoryFormat()) { - // Reorder src - src.SetUsrMem(src_md, &src_tensor); - src.CheckReorderToOpMem(conv_fwd_pd.get()->src_primitive_desc()); - src_data = static_cast(src.GetOpMem().get_data_handle()); - } else { - src_data = static_cast( - const_cast(src_tensor.flat().data())); - } - - Tfilter* filter_data = nullptr; - if (filter_md.data.format != conv_fwd->GetFilterMemoryFormat()) { - bool is_filter_cached = false; - // If filter is a constant, we can avoid the conversion of filter from - // Tensorflow format to MKL format by caching the filter when it is - // converted for the first time. This cached filter can then be reused - // in subsequent iterations. - if (is_filter_const_) { - if (IsFilterCacheEmpty(context)) { - // Cache filter if it is not already cached. - CacheFilter(context, conv_fwd_pd, filter_data, filter_tensor, filter, filter_md); +#endif // ENABLE_MKLDNN_V1 } - filter_data = - GetCachedFilter(context, conv_fwd->GetFilterMemoryFormat()); + filter_data = GetCachedFilter( + context, GET_WEIGHTS_FORMAT_FROM_OP_PD(conv_fwd_pd, conv_fwd)); is_filter_cached = (filter_data != nullptr); } if (!is_filter_cached) { filter.SetUsrMem(filter_md, &filter_tensor); if (filter_out_tensor == nullptr) { - filter.CheckReorderToOpMem( - conv_fwd_pd.get()->weights_primitive_desc()); + filter.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA( + GET_WEIGHTS_DESC_FROM_OP_PD(conv_fwd_pd), cpu_engine_)); } else { filter.CheckReorderToOpMem( - conv_fwd_pd.get()->weights_primitive_desc(), - filter.GetTensorBuffer(filter_out_tensor)); + GET_WEIGHTS_DESC_FROM_OP_PD(conv_fwd_pd), + DATA_WITH_ENGINE(filter.GetTensorBuffer(filter_out_tensor), + cpu_engine_)); } filter_data = static_cast(filter.GetOpMem().get_data_handle()); @@ -861,7 +819,6 @@ class MklConvOp : public OpKernel { filter_data = static_cast( const_cast(filter_tensor.flat().data())); } -#endif // Execute convolution if (fuse_biasadd_) { @@ -962,7 +919,7 @@ class MklConvOp : public OpKernel { // NOTE: Fusion of BiasAdd is handled directly inside MklConvOp by // checking `fuse_biasadd_` flag. 
if (fuse_add_) { - params.post_op_params.push_back({"sum", mkldnn::algorithm_undef, {1.0}}); + params.post_op_params.push_back({"sum", ALGORITHM_UNDEF, {1.0}}); } if (fuse_activation_) { params.post_op_params.push_back( @@ -980,61 +937,38 @@ class MklConvOp : public OpKernel { return nullptr; } + virtual void AllocateOutputTensor(OpKernelContext* context, + const ConvFwdPd& conv_prim_desc, + const memory::dims& output_dims_mkl_order, + MKL_TENSOR_FORMAT output_tf_format, + Tensor** output_tensor) { + DCHECK(output_tensor); #ifdef ENABLE_MKLDNN_V1 - virtual void AllocateOutputTensor(OpKernelContext* context, - const ConvFwdPd& conv_prim_desc, - const memory::dims& output_dims_mkl_order, - MklTensorFormat output_tf_format, - Tensor** output_tensor) { - CHECK_NOTNULL(output_tensor); auto dst_md = conv_prim_desc.dst_desc(); - - if (!std::is_same::value) { - dst_md.data.data_type = - static_cast(MklDnnType()); - } - // Allocate shape of Mkl tensor. - MklDnnShape output_mkl_shape; - output_mkl_shape.SetMklTensor(true); - output_mkl_shape.SetMklLayout(&dst_md); - output_mkl_shape.SetElemType(MklDnnType()); - output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(), - output_dims_mkl_order, output_tf_format); - - // Allocate shape of TF tensor. - TensorShape output_tf_shape; - output_tf_shape.AddDim((dst_md.get_size() / sizeof(Toutput))); - - AllocateOutputSetMklShape(context, kOutputIndex_Dst, output_tensor, - output_tf_shape, output_mkl_shape); - } #else - virtual void AllocateOutputTensor(OpKernelContext* context, - const ConvFwdPd& conv_prim_desc, - const memory::dims& output_dims_mkl_order, - memory::format output_tf_format, - Tensor** output_tensor) { - CHECK_NOTNULL(output_tensor); auto dst_pd = conv_prim_desc.dst_primitive_desc(); - auto dst_md = dst_pd.desc(); +#endif // ENABLE_MKLDNN_V1 + if (!std::is_same::value) { dst_md.data.data_type = static_cast(MklDnnType()); +#ifndef ENABLE_MKLDNN_V1 dst_pd = memory::primitive_desc(dst_md, cpu_engine_); +#endif // !ENABLE_MKLDNN_V1 } - // Allocate shape of Mkl tensor. + + // Allocate shape of MKL tensor MklDnnShape output_mkl_shape; output_mkl_shape.SetMklTensor(true); - output_mkl_shape.SetMklLayout(&dst_pd); + output_mkl_shape.SetMklLayout(&DST_MD); output_mkl_shape.SetElemType(MklDnnType()); output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(), output_dims_mkl_order, output_tf_format); - // Allocate shape of TF tensor. + // Allocate shape of TF tensor TensorShape output_tf_shape; - output_tf_shape.AddDim((dst_pd.get_size() / sizeof(Toutput))); - + output_tf_shape.AddDim((DST_MD.get_size() / sizeof(Toutput))); AllocateOutputSetMklShape(context, kOutputIndex_Dst, output_tensor, output_tf_shape, output_mkl_shape); if (fuse_add_) { @@ -1042,37 +976,40 @@ class MklConvOp : public OpKernel { MklDnnShape add_mkl_shape; GetMklShape(context, kInputIndex_Add, &add_mkl_shape); - // Check if need reorder + // Check if reorder is needed if (add_mkl_shape == output_mkl_shape) { - CHECK((*output_tensor)->CopyFrom(add_tensor, output_tf_shape)); + DCHECK((*output_tensor)->CopyFrom(add_tensor, output_tf_shape)); } else { - auto add_md = - add_mkl_shape.IsMklTensor() - ? 
add_mkl_shape.GetMklLayout() - : memory::desc(output_dims_mkl_order, MklDnnType(), - output_mkl_shape.GetTfDataFormat()); - auto add_pd = memory::primitive_desc(add_md, this->cpu_engine_); - void* add_buf = static_cast( - const_cast(add_tensor.flat().data())); - void* dst_buf = - static_cast((*output_tensor)->flat().data()); - auto add = new memory(add_pd, add_buf); - auto dst = new memory(dst_pd, dst_buf); - auto reorder_desc = mkldnn::reorder::primitive_desc(add_pd, dst_pd); - - std::vector net; - net.push_back(mkldnn::reorder(reorder_desc, *add, *dst)); - stream(stream::kind::eager).submit(net).wait(); + if (add_mkl_shape.IsMklTensor()) { + auto add_md = add_mkl_shape.GetMklLayout(); + } else { +#ifdef ENABLE_MKLDNN_V1 + auto output_format_tag = MklTensorFormatToMklDnnDataFormat( + output_mkl_shape.GetTfDataFormat()); + DCHECK_NE(output_format_tag, memory::format_tag::undef); + auto add_md = memory::desc(output_dims_mkl_order, + MklDnnType(), output_format_tag); +#else + auto add_md = + memory::desc(output_dims_mkl_order, MklDnnType(), + output_mkl_shape.GetTfDataFormat()); + auto add_pd = memory::primitive_desc(add_md, this->cpu_engine_); +#endif // ENABLE_MKLDNN_V1 + void* add_buf = static_cast( + const_cast(add_tensor.flat().data())); + void* dst_buf = + static_cast((*output_tensor)->flat().data()); + auto add = new MEMORY_CONSTRUCTOR(ADD_MD, this->cpu_engine_, add_buf); + auto dst = new MEMORY_CONSTRUCTOR(DST_MD, this->cpu_engine_, dst_buf); + auto reorder_desc = + REORDER_PD_CONSTRUCTOR(ADD_MD, DST_MD, this->cpu_engine_); + CreateAndExecuteReorder(reorder_desc, *add, *dst, this->cpu_engine_); + } } } } -#endif -#ifdef ENABLE_MKLDNN_V1 - engine cpu_engine_ = engine(engine::kind::cpu, 0); -#else - engine cpu_engine_ = engine(engine::cpu, 0); -#endif + engine cpu_engine_ = engine(ENGINE_CPU, 0); private: std::vector strides_; @@ -1092,7 +1029,7 @@ class MklConvOp : public OpKernel { bool fuse_add_ = false; float relu_up_bound_ = 0.0; - mkldnn::algorithm activation_alg_ = mkldnn::algorithm_undef; + mkldnn::algorithm activation_alg_ = ALGORITHM_UNDEF; int input_index_pad_ = 2; @@ -1101,40 +1038,62 @@ class MklConvOp : public OpKernel { const int kOutputIndex_Dst = 0, kOutputIndex_Filter = 1; const int kDilationH = 0, kDilationW = 1; + MKL_TENSOR_FORMAT_IN_C GetFilterTfDataFormat( + const MklDnnShape* filter_mkl_shape, + const ConvFwdPd& conv_prim_desc) const { #ifdef ENABLE_MKLDNN_V1 + DCHECK(filter_mkl_shape); + return filter_mkl_shape->GetTfDataFormat(); +#else + return conv_prim_desc.weights_primitive_desc().desc().data.format; +#endif // ENABLE_MKLDNN_V1 + } + // Allocate persistent tensors for cached filter data and // cached filter memory descriptor (data format) void AllocatePersistentTensor(OpKernelContext* context, const ConvFwdPd& conv_prim_desc, Tensor** filter_tensor, - const MklDnnShape& filter_mkl_shape) { + const MklDnnShape* filter_mkl_shape) { DCHECK(filter_tensor); TensorShape filter_tf_shape; filter_tf_shape.AddDim( - (conv_prim_desc.weights_desc().get_size() / sizeof(Tfilter))); + (conv_prim_desc.PRIMITIVE_DESC_WEIGHTS.get_size() / sizeof(Tfilter))); OP_REQUIRES_OK(context, context->allocate_persistent( DataTypeToEnum::value, filter_tf_shape, &cached_filter_data_ptensor_, filter_tensor)); Tensor* second_tensor = nullptr; TensorShape filter_mkl_format; - filter_mkl_format.AddDim(sizeof(filter_mkl_shape.GetTfDataFormat()) / - sizeof(DT_INT32)); + filter_mkl_format.AddDim( + sizeof(GetFilterTfDataFormat(filter_mkl_shape, conv_prim_desc)) / + sizeof(DT_INT32)); 
OP_REQUIRES_OK(context, context->allocate_persistent( DT_INT32, filter_mkl_format, &cached_filter_md_ptensor_, &second_tensor)); - second_tensor->scalar()() = - static_cast(filter_mkl_shape.GetTfDataFormat()); + second_tensor->scalar()() = +#ifdef ENABLE_MKLDNN_V1 + static_cast( + GetFilterTfDataFormat(filter_mkl_shape, conv_prim_desc)); +#else + GetFilterTfDataFormat(filter_mkl_shape, conv_prim_desc); +#endif // ENABLE_MKLDNN_V1 + } + + void AllocatePersistentTensor(OpKernelContext* context, + const ConvFwdPd& conv_prim_desc, + Tensor** filter_tensor) { + AllocatePersistentTensor(context, conv_prim_desc, filter_tensor, nullptr); } void AllocateFilterOutputTensor(OpKernelContext* context, const ConvFwdPd& conv_prim_desc, const memory::dims& filter_dims_tf_order, Tensor** filter_tensor) { - CHECK_NOTNULL(filter_tensor); - auto filter_md = conv_prim_desc.weights_desc(); + DCHECK(filter_tensor); + auto filter_md = conv_prim_desc.PRIMITIVE_DESC_WEIGHTS; - // Allocate shape of Mkl tensor. + // Allocate shape of MKL tensor MklDnnShape filter_mkl_shape; filter_mkl_shape.SetMklTensor(true); filter_mkl_shape.SetMklLayout(&filter_md); @@ -1145,7 +1104,7 @@ class MklConvOp : public OpKernel { // is stored in the MKL data. filter_mkl_shape.SetTfLayout(filter_dims_tf_order.size(), filter_dims_tf_order, - MklTensorFormat::FORMAT_UNDEF); + MKL_TENSOR_FORMAT_BLOCKED); // Allocate the data space for the filter to propagate as TF tensor. TensorShape filter_tf_shape; @@ -1162,20 +1121,22 @@ class MklConvOp : public OpKernel { MklDnnData* bias, MklDnnData* output, Tensor* filter_out_tensor) { - CHECK_NOTNULL(filter_out_tensor); + DCHECK(filter_out_tensor); // Create reorders between user layout and MKL layout if it is needed and // add it to the net before convolution. No need to check for output // reorder as we propagate output layout to the next layer. - src->CheckReorderToOpMem(conv_prim_desc.src_desc(), cpu_engine_); + src->CheckReorderToOpMem( + MEMORY_PD_WITHOUT_DATA(conv_prim_desc.PRIMITIVE_DESC_SRC, cpu_engine_)); - // rather than re-order to a temp buffer, reorder directly to the + // Rather than re-ordering to a temp buffer, reorder directly to the // filter output tensor - filter->CheckReorderToOpMem(conv_prim_desc.weights_desc(), + filter->CheckReorderToOpMem(conv_prim_desc.PRIMITIVE_DESC_WEIGHTS, filter->GetTensorBuffer(filter_out_tensor)); // Create convolution primitive and add it to net. 
std::vector net; +#ifdef ENABLE_MKLDNN_V1 std::vector> net_args; if (bias) { DCHECK(fuse_biasadd_); @@ -1194,85 +1155,12 @@ class MklConvOp : public OpKernel { output->GetOpMem() }}); } stream cpu_stream(cpu_engine_); - DCHECK_EQ(net.size(), net_args.size()); for (size_t i = 0; i < net.size(); ++i) { net.at(i).execute(cpu_stream, net_args.at(i)); } cpu_stream.wait(); - } #else - void AllocatePersistentTensor(OpKernelContext* context, - const ConvFwdPd& conv_prim_desc, - Tensor** filter_tensor) { - DCHECK(filter_tensor); - TensorShape filter_tf_shape; - filter_tf_shape.AddDim( - (conv_prim_desc.weights_primitive_desc().get_size() / sizeof(Tfilter))); - OP_REQUIRES_OK(context, context->allocate_persistent( - DataTypeToEnum::value, filter_tf_shape, - &cached_filter_data_ptensor_, filter_tensor)); - - Tensor* second_tensor = nullptr; - TensorShape filter_mkl_format; - filter_mkl_format.AddDim( - sizeof(conv_prim_desc.weights_primitive_desc().desc().data.format) / - sizeof(DT_INT32)); - OP_REQUIRES_OK(context, context->allocate_persistent( - DT_INT32, filter_mkl_format, - &cached_filter_md_ptensor_, &second_tensor)); - second_tensor->scalar()() = - conv_prim_desc.weights_primitive_desc().desc().data.format; - } - - void AllocateFilterOutputTensor(OpKernelContext* context, - const ConvFwdPd& conv_prim_desc, - const memory::dims& filter_dims_tf_order, - Tensor** filter_tensor) { - CHECK_NOTNULL(filter_tensor); - auto filter_pd = conv_prim_desc.weights_primitive_desc(); - - // Allocate shape of Mkl tensor. - MklDnnShape filter_mkl_shape; - filter_mkl_shape.SetMklTensor(true); - filter_mkl_shape.SetMklLayout(&filter_pd); - filter_mkl_shape.SetElemType(MklDnnType()); - - // The format of the filter is actually OIhw8i8o, but TF doesn't support - // this format. Just use format::blocked for now because the layout - // is stored in the MKL data. - filter_mkl_shape.SetTfLayout(filter_dims_tf_order.size(), - filter_dims_tf_order, memory::format::blocked); - - // Allocate the data space for the filter to propagate as TF tensor. - TensorShape filter_tf_shape; - filter_tf_shape.AddDim((filter_pd.get_size() / sizeof(Tfilter))); - - AllocateOutputSetMklShape(context, kOutputIndex_Filter, filter_tensor, - filter_tf_shape, filter_mkl_shape); - } - - // Prepare and execute net - checks for input and output reorders. - void PrepareAndExecuteNet(const ConvFwdPd& conv_prim_desc, - MklDnnData* src, - MklDnnData* filter, - MklDnnData* bias, - MklDnnData* output, - Tensor* filter_out_tensor) { - CHECK_NOTNULL(filter_out_tensor); - - // Create reorders between user layout and MKL layout if it is needed and - // add it to the net before convolution. No need to check for output - // reorder as we propagate output layout to the next layer. - src->CheckReorderToOpMem(conv_prim_desc.src_primitive_desc()); - - // rather than re-order to a temp buffer, reorder directly to the - // filter output tensor - filter->CheckReorderToOpMem(conv_prim_desc.weights_primitive_desc(), - filter->GetTensorBuffer(filter_out_tensor)); - - // Create convolution primitive and add it to net. 
- std::vector net; if (bias) { DCHECK(fuse_biasadd_); net.push_back(convolution_forward(conv_prim_desc, src->GetOpMem(), @@ -1284,10 +1172,9 @@ class MklConvOp : public OpKernel { filter->GetOpMem(), output->GetOpMem())); } - stream(stream::kind::eager).submit(net).wait(); +#endif // ENABLE_MKLDNN_V1 } -#endif // LOCKS_EXCLUDED annotation ensures that the lock (mu_) cannot // be acquired before entering the function, since it is acquired @@ -1299,9 +1186,9 @@ class MklConvOp : public OpKernel { return (cached_filter_data_tensor.NumElements() == 0); } +// Cache the converted filter in a persistent tensor. +// Only one thread can execute this method at any given time. #ifdef ENABLE_MKLDNN_V1 - // Cache the converted filter in a persistent tensor. - // Only one thread can execute this method at any given time. void CacheFilter(OpKernelContext* context, const std::shared_ptr& conv_fwd_pd, Tfilter* filter_data, const Tensor& filter_tensor, @@ -1324,14 +1211,30 @@ class MklConvOp : public OpKernel { Tensor* filter_tensor_ptr = nullptr; AllocatePersistentTensor(context, *conv_fwd_pd, &filter_tensor_ptr, - filter_mkl_shape); + &filter_mkl_shape); void* cached_filter_data = filter.GetTensorBuffer(filter_tensor_ptr); size_t cached_filter_data_size = filter.GetOpMem().get_desc().get_size(); memcpy(cached_filter_data, filter_data, cached_filter_data_size); } + + bool AreMemoryDescriptorsEqual(const memory::desc& filter_md, + const Tensor& cached_filter_md) { + auto filter_md_data = filter_md.data; + const char* filter_data = reinterpret_cast(&filter_md_data); + + auto cached_filter_md_data = cached_filter_md.scalar()(); + const char* cached_filter_data = + reinterpret_cast(&cached_filter_md_data); + + for (size_t i = 0; i < sizeof(filter_md_data); ++i) { + if (*filter_data++ != *cached_filter_data++) { + return false; + } + } + return true; + } + #else - // Cache the converted filter in a persistent tensor. - // Only one thread can execute this method at any given time. void CacheFilter(OpKernelContext* context, const std::shared_ptr& conv_fwd_pd, Tfilter* filter_data, const Tensor& filter_tensor, @@ -1358,66 +1261,31 @@ class MklConvOp : public OpKernel { filter.GetOpMem().get_primitive_desc().get_size(); memcpy(cached_filter_data, filter_data, cached_filter_data_size); } -#endif - -#ifdef ENABLE_MKLDNN_V1 - bool AreMemoryDescriptorsEqual(const memory::desc& filter_md, - const Tensor& cached_filter_md) { - auto filter_md_data = filter_md.data; - const char* filter_data = reinterpret_cast(&filter_md_data); - - auto cached_filter_md_data = cached_filter_md.scalar()(); - const char* cached_filter_data = - reinterpret_cast(&cached_filter_md_data); - - for (size_t i = 0; i < sizeof(filter_md_data); ++i) { - if (*filter_data++ != *cached_filter_data++) { - return false; - } - } - return true; - } +#endif // ENABLE_MKLDNN_V1 Tfilter* GetCachedFilter(OpKernelContext* context, - const memory::desc& filter_md) LOCKS_EXCLUDED(mu_) { + const MEMORY_DESC& filter_md) LOCKS_EXCLUDED(mu_) { tf_shared_lock lock(mu_); const Tensor& cached_filter_data = *cached_filter_data_ptensor_.AccessTensor(context); const Tensor& cached_filter_md = *cached_filter_md_ptensor_.AccessTensor(context); - // Check if the memory descriptor of the cached weights is same as - // filter_mf. If so, we can used the cached weights; otherwise - // return NULL. +// Check if the memory descriptor of the cached weights is same as +// filter_md. If so, we can used the cached weights; otherwise +// return NULL. 
+#ifdef ENABLE_MKLDNN_V1 if (cached_filter_md.scalar().size() && AreMemoryDescriptorsEqual(filter_md, cached_filter_md)) { - return static_cast( - const_cast(cached_filter_data.flat().data())); - } - return nullptr; - } #else - Tfilter* GetCachedFilter(OpKernelContext* context, - const memory::format& filter_mf) - LOCKS_EXCLUDED(mu_) { - tf_shared_lock lock(mu_); - const Tensor& cached_filter_data = - *cached_filter_data_ptensor_.AccessTensor(context); - const Tensor& cached_filter_md = - *cached_filter_md_ptensor_.AccessTensor(context); - - // Check if the memory descriptor of the cached weights is same as - // filter_mf. If so, we can used the cached weights; otherwise - // return NULL. - // TODO (bhavanis): Do we need to cast filter_mf before the check? if (cached_filter_md.scalar().size() && - cached_filter_md.scalar()() == filter_mf) { + cached_filter_md.scalar()() == filter_md) { +#endif // ENABLE_MKLDNN_V1 return static_cast( const_cast(cached_filter_data.flat().data())); } return nullptr; } -#endif }; // Base class for fused convolution forward operations @@ -1448,26 +1316,26 @@ class MklFusedConvOp errors::InvalidArgument( "Fused Conv2D must have one extra argument: bias.")); } else if (fused_ops == std::vector{"Relu"}) { - this->set_fuse_activation(true, mkldnn::eltwise_relu); + this->set_fuse_activation(true, ALGORITHM::eltwise_relu); } else if (fused_ops == std::vector{"Relu6"}) { - this->set_fuse_activation(true, mkldnn::eltwise_bounded_relu, 6.0); + this->set_fuse_activation(true, ALGORITHM::eltwise_bounded_relu, 6.0); } else if (fused_ops == std::vector{"Elu"}) { - this->set_fuse_activation(true, mkldnn::eltwise_elu); + this->set_fuse_activation(true, ALGORITHM::eltwise_elu); } else if (fused_ops == std::vector{"BiasAdd", "Relu"}) { this->set_fuse_biasadd(true); - this->set_fuse_activation(true, mkldnn::eltwise_relu); + this->set_fuse_activation(true, ALGORITHM::eltwise_relu); OP_REQUIRES(context, num_args == 1, errors::InvalidArgument( "Fused Conv2D must have one extra argument: bias.")); } else if (fused_ops == std::vector{"BiasAdd", "Relu6"}) { this->set_fuse_biasadd(true); - this->set_fuse_activation(true, mkldnn::eltwise_bounded_relu, 6.0); + this->set_fuse_activation(true, ALGORITHM::eltwise_bounded_relu, 6.0); OP_REQUIRES(context, num_args == 1, errors::InvalidArgument( "Fused Conv2D must have one extra argument: bias.")); } else if (fused_ops == std::vector{"BiasAdd", "Elu"}) { this->set_fuse_biasadd(true); - this->set_fuse_activation(true, mkldnn::eltwise_elu); + this->set_fuse_activation(true, ALGORITHM::eltwise_elu); OP_REQUIRES(context, num_args == 1, errors::InvalidArgument( "Fused Conv2D must have one extra argument: bias.")); @@ -1481,7 +1349,7 @@ class MklFusedConvOp } else if (fused_ops == std::vector{"BiasAdd", "Add", "Relu"}) { this->set_fuse_biasadd(true); this->set_fuse_add(true); - this->set_fuse_activation(true, mkldnn::eltwise_relu); + this->set_fuse_activation(true, ALGORITHM::eltwise_relu); OP_REQUIRES( context, num_args == 2, errors::InvalidArgument( @@ -1489,7 +1357,7 @@ class MklFusedConvOp } else if (fused_ops == std::vector{"BiasAdd", "Add", "Relu6"}) { this->set_fuse_biasadd(true); this->set_fuse_add(true); - this->set_fuse_activation(true, mkldnn::eltwise_bounded_relu, 6.0); + this->set_fuse_activation(true, ALGORITHM::eltwise_bounded_relu, 6.0); OP_REQUIRES( context, num_args == 2, errors::InvalidArgument( @@ -1497,7 +1365,7 @@ class MklFusedConvOp } else if (fused_ops == std::vector{"BiasAdd", "Add", "Elu"}) { this->set_fuse_biasadd(true); 
this->set_fuse_add(true); - this->set_fuse_activation(true, mkldnn::eltwise_elu); + this->set_fuse_activation(true, ALGORITHM::eltwise_elu); OP_REQUIRES( context, num_args == 2, errors::InvalidArgument( @@ -1653,7 +1521,7 @@ class MklQuantizedConv2DOp (255.0f * 127.0f * output_range); } params.post_op_params.push_back( - {"output_scale", mkldnn::algorithm_undef, scales}); + {"output_scale", ALGORITHM_UNDEF, scales}); } } @@ -1672,10 +1540,6 @@ class MklQuantizedConv2DOp const float* min_filter = min_filter_vector.flat().data(); const float* max_filter = max_filter_vector.flat().data(); - std::vector net; -#ifdef ENABLE_MKLDNN_V1 - std::vector> net_args; -#endif if (bias_enabled) { if (std::is_same::value) { return static_cast( @@ -1697,50 +1561,22 @@ class MklQuantizedConv2DOp } else { bias_attr.set_output_scales(1, scales); } -#ifdef ENABLE_MKLDNN_V1 + auto bias_md = - memory::desc({static_cast(bias_tensor.NumElements())}, - MklDnnType(), memory::format_tag::x); - + MEMORY_PD_CONSTRUCTOR(static_cast(bias_tensor.NumElements()), + Tbias, x, this->cpu_engine_); void* bias_buf = static_cast( const_cast(bias_tensor.flat().data())); - input_bias_ = new memory(bias_md, this->cpu_engine_, bias_buf); - scaled_bias_ = new memory(conv_fwd_pd->bias_desc(), this->cpu_engine_); - auto reorder_desc = mkldnn::reorder::primitive_desc( - this->cpu_engine_, input_bias_->get_desc(), this->cpu_engine_, - scaled_bias_->get_desc(), bias_attr); - net.push_back(mkldnn::reorder(reorder_desc)); - net_args.push_back({{MKLDNN_ARG_FROM, *input_bias_}, - { MKLDNN_ARG_TO, - *scaled_bias_ }}); - - DCHECK_EQ(net.size(), net_args.size()); - - stream cpu_stream(this->cpu_engine_); - for (size_t i = 0; i < net.size(); ++i) { - net.at(i).execute(cpu_stream, net_args.at(i)); - } - cpu_stream.wait(); - - return reinterpret_cast(scaled_bias_->get_data_handle()); -#else - auto bias_pd = - memory::primitive_desc({{static_cast(bias_tensor.NumElements())}, - MklDnnType(), - memory::format::x}, - this->cpu_engine_); - - void* bias_buf = static_cast( - const_cast(bias_tensor.flat().data())); - input_bias_ = new memory(bias_pd, bias_buf); - scaled_bias_ = new memory(conv_fwd_pd->bias_primitive_desc()); - auto reorder_desc = mkldnn::reorder::primitive_desc( - input_bias_->get_primitive_desc(), scaled_bias_->get_primitive_desc(), + input_bias_ = + new MEMORY_CONSTRUCTOR(bias_md, this->cpu_engine_, bias_buf); + scaled_bias_ = new MEMORY_CONSTRUCTOR_WITHOUT_DATA( + conv_fwd_pd->PRIMITIVE_DESC_BIAS, this->cpu_engine_); + auto reorder_desc = REORDER_PD_CONSTRUCTOR_WITH_ATTR( + input_bias_->GET_DESC, scaled_bias_->GET_DESC, this->cpu_engine_, bias_attr); - net.push_back(mkldnn::reorder(reorder_desc, *input_bias_, *scaled_bias_)); - stream(stream::kind::eager).submit(net).wait(); + CreateAndExecuteReorder(reorder_desc, *input_bias_, *scaled_bias_, + this->cpu_engine_); return reinterpret_cast(scaled_bias_->get_data_handle()); -#endif } else { return nullptr; } @@ -1768,7 +1604,7 @@ class MklQuantizedConv2DReluOp MklQuantizedConv2DOp::ExtendConvFwdParams(context, params); params.post_op_params.push_back( - {"activation", mkldnn::eltwise_relu, {1.0, 0.0, 0.0}}); + {"activation", ALGORITHM::eltwise_relu, {1.0, 0.0, 0.0}}); } }; @@ -1825,27 +1661,23 @@ class MklQuantizedConv2DSumReluOp // If it is not then it is DT_INT8 and is scaled appropriately. 
if (summand_type == DT_QUINT8) params.post_op_params.push_back( - {"sum", mkldnn::algorithm_undef, {scale_summand / scale_output}}); + {"sum", ALGORITHM_UNDEF, {scale_summand / scale_output}}); else params.post_op_params.push_back( {"sum", - mkldnn::algorithm_undef, + ALGORITHM_UNDEF, {255.0f * scale_summand / (scale_output * 127.0f)}}); } else { - params.post_op_params.push_back({"sum", mkldnn::algorithm_undef, {1.0}}); + params.post_op_params.push_back({"sum", ALGORITHM_UNDEF, {1.0}}); } params.post_op_params.push_back( - {"activation", mkldnn::eltwise_relu, {1.0, 0.0, 0.0}}); + {"activation", ALGORITHM::eltwise_relu, {1.0, 0.0, 0.0}}); } void AllocateOutputTensor(OpKernelContext* context, const ConvFwdPd& conv_prim_desc, const memory::dims& output_dims_mkl_order, -#ifdef ENABLE_MKLDNN_V1 - MklTensorFormat output_tf_format, -#else - memory::format output_tf_format, -#endif + MKL_TENSOR_FORMAT output_tf_format, Tensor** output_tensor) override { int summand_idx = context->num_inputs() / 2 - 1; if (std::is_same::value) { @@ -1913,56 +1745,26 @@ class MklQuantizedConv2DSumReluOp } else { reorder_attr.set_output_scales(2, scales); } -#ifdef ENABLE_MKLDNN_V1 auto summand_md = summand_mkl_shape.IsMklTensor() ? summand_mkl_shape.GetMklLayout() : memory::desc(output_dims_mkl_order, MklDnnType(), - memory::format_tag::nhwc); - void* summand_buf = - static_cast(const_cast(summand.flat().data())); - void* dst_buf = - static_cast((*output_tensor)->flat().data()); - summand_ = new memory(summand_md, this->cpu_engine_, summand_buf); - dst_ = new memory(conv_prim_desc.dst_desc(), this->cpu_engine_, dst_buf); - auto reorder_desc = mkldnn::reorder::primitive_desc( - this->cpu_engine_, summand_md, this->cpu_engine_, - conv_prim_desc.dst_desc(), reorder_attr); - - std::vector net; - std::vector> net_args; - - net.push_back(mkldnn::reorder(reorder_desc)); - net_args.push_back({{MKLDNN_ARG_FROM, *summand_}, - { MKLDNN_ARG_TO, - *dst_ }}); - DCHECK_EQ(net.size(), net_args.size()); - - stream cpu_stream(this->cpu_engine_); - for (size_t i = 0; i < net.size(); ++i) { - net.at(i).execute(cpu_stream, net_args.at(i)); - } - cpu_stream.wait(); -#else - auto summand_md = - summand_mkl_shape.IsMklTensor() - ? 
summand_mkl_shape.GetMklLayout() - : memory::desc(output_dims_mkl_order, MklDnnType(), - memory::format::nhwc); + MEMORY_FORMAT::nhwc); +#ifndef ENABLE_MKLDNN_V1 auto summand_pd = memory::primitive_desc(summand_md, this->cpu_engine_); +#endif // !ENABLE_MKLDNN_V1 void* summand_buf = static_cast(const_cast(summand.flat().data())); void* dst_buf = static_cast((*output_tensor)->flat().data()); - summand_ = new memory(summand_pd, summand_buf); - dst_ = new memory(conv_prim_desc.dst_primitive_desc(), dst_buf); - auto reorder_desc = mkldnn::reorder::primitive_desc( - summand_pd, conv_prim_desc.dst_primitive_desc(), reorder_attr); - - std::vector net; - net.push_back(mkldnn::reorder(reorder_desc, *summand_, *dst_)); - stream(stream::kind::eager).submit(net).wait(); -#endif + summand_ = + new MEMORY_CONSTRUCTOR(SUMMAND_MD, this->cpu_engine_, summand_buf); + dst_ = new MEMORY_CONSTRUCTOR(conv_prim_desc.PRIMITIVE_DESC_DST, + this->cpu_engine_, dst_buf); + auto reorder_desc = REORDER_PD_CONSTRUCTOR_WITH_ATTR( + SUMMAND_MD, conv_prim_desc.PRIMITIVE_DESC_DST, this->cpu_engine_, + reorder_attr); + CreateAndExecuteReorder(reorder_desc, *summand_, *dst_, this->cpu_engine_); } memory* summand_ = nullptr; @@ -2416,5 +2218,36 @@ TF_CALL_bfloat16(REGISTER_MKL_CPU_2D_FUSED); TF_CALL_float(REGISTER_MKL_CPU_3D); TF_CALL_bfloat16(REGISTER_MKL_CPU_3D); +#undef ADD_MD +#undef ALGORITHM +#undef ALGORITHM_UNDEF +#undef CPU_STREAM +#undef DATA_WITH_ENGINE +#undef DST_MD +#undef ENGINE_CPU +#undef GET_DESC +#undef GET_MEMORY_DESC_CONSTRUCTOR +#undef GET_SRC_DESC_FROM_OP_PD +#undef GET_WEIGHTS_DESC_FROM_OP_PD +#undef GET_WEIGHTS_FORMAT_FROM_OP_PD +#undef IS_FILTER_REORDER_NEEDED +#undef IS_SRC_REORDER_NEEDED +#undef MEMORY_CONSTRUCTOR +#undef MEMORY_CONSTRUCTOR_USING_MEM_PD +#undef MEMORY_CONSTRUCTOR_WITHOUT_DATA +#undef MEMORY_DESC +#undef MEMORY_FORMAT +#undef MEMORY_PD_CONSTRUCTOR +#undef MEMORY_PD_WITHOUT_DATA +#undef MKL_TENSOR_FORMAT +#undef MKL_TENSOR_FORMAT_BLOCKED +#undef MKL_TENSOR_FORMAT_IN_C +#undef PRIMITIVE_DESC_BIAS +#undef PRIMITIVE_DESC_DST +#undef PRIMITIVE_DESC_SRC +#undef PRIMITIVE_DESC_WEIGHTS +#undef REORDER_PD_CONSTRUCTOR +#undef REORDER_PD_CONSTRUCTOR_WITH_ATTR +#undef SUMMAND_MD } // namespace tensorflow #endif // INTEL_MKL diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h index 9cd69e36f2e..56b70f4a433 100644 --- a/tensorflow/core/util/mkl_util.h +++ b/tensorflow/core/util/mkl_util.h @@ -687,9 +687,9 @@ inline Tensor ConvertMklToTF(OpKernelContext* context, const Tensor& mkl_tensor, CHECK(output_tensor.CopyFrom(mkl_tensor, output_shape)); } } catch (mkldnn::error& e) { - string error_msg = "Status: " + std::to_string(e.status) + - ", message: " + string(e.message) + ", in file " + - string(__FILE__) + ":" + std::to_string(__LINE__); + string error_msg = "Status: " + std::to_string(e.status) + ", message: " + + string(e.message) + ", in file " + string(__FILE__) + + ":" + std::to_string(__LINE__); LOG(FATAL) << "Operation received an exception: " << error_msg; } return output_tensor; @@ -1194,6 +1194,27 @@ inline memory::desc CreateBlockedMemDescHelper(const memory::dims& dim, return memory::desc(md); } +inline void CreateAndExecuteReorder(const reorder::primitive_desc& reorder_desc, + const memory& src_mem, + const memory& dst_mem, + const engine& engine) { + std::vector net; +#ifdef ENABLE_MKLDNN_V1 + net.push_back(mkldnn::reorder(reorder_desc)); + std::vector net_args; + net_args.push_back({{MKLDNN_ARG_FROM, src_mem}, {MKLDNN_ARG_TO, dst_mem}}); + 
DCHECK_EQ(net.size(), net_args.size()); + stream cpu_stream(engine); + for (size_t i = 0; i < net.size(); ++i) { + net.at(i).execute(cpu_stream, net_args.at(i)); + } + cpu_stream.wait(); +#else + net.push_back(mkldnn::reorder(reorder_desc, src_mem, dst_mem)); + stream(stream::kind::eager).submit(net).wait(); +#endif // ENABLE_MKLDNN_V1 +} + template inline primitive FindOrCreateReorder(const memory* from, const memory* to); From cae3481620701db362228a680bf9e7fd866c3b61 Mon Sep 17 00:00:00 2001 From: Bhavani Subramanian Date: Mon, 29 Jul 2019 15:31:48 -0700 Subject: [PATCH 5/6] Addressed some more review comments. --- tensorflow/core/kernels/mkl_conv_ops.cc | 10 +---- tensorflow/core/kernels/mkl_conv_ops.h | 59 ++++++++----------------- 2 files changed, 20 insertions(+), 49 deletions(-) diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc index b5b7d6bf4d7..320eabbefe4 100644 --- a/tensorflow/core/kernels/mkl_conv_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_ops.cc @@ -48,7 +48,6 @@ limitations under the License. #include "tensorflow/core/util/tensor_format.h" using mkldnn::convolution_forward; -using mkldnn::memory; using mkldnn::prop_kind; using mkldnn::stream; @@ -1071,13 +1070,8 @@ class MklConvOp : public OpKernel { OP_REQUIRES_OK(context, context->allocate_persistent( DT_INT32, filter_mkl_format, &cached_filter_md_ptensor_, &second_tensor)); - second_tensor->scalar()() = -#ifdef ENABLE_MKLDNN_V1 - static_cast( - GetFilterTfDataFormat(filter_mkl_shape, conv_prim_desc)); -#else - GetFilterTfDataFormat(filter_mkl_shape, conv_prim_desc); -#endif // ENABLE_MKLDNN_V1 + second_tensor->scalar()() = static_cast( + GetFilterTfDataFormat(filter_mkl_shape, conv_prim_desc)); } void AllocatePersistentTensor(OpKernelContext* context, diff --git a/tensorflow/core/kernels/mkl_conv_ops.h b/tensorflow/core/kernels/mkl_conv_ops.h index 30f2e528745..d6f1f2db96a 100644 --- a/tensorflow/core/kernels/mkl_conv_ops.h +++ b/tensorflow/core/kernels/mkl_conv_ops.h @@ -42,13 +42,19 @@ limitations under the License. #ifndef ENABLE_MKLDNN_V1 using mkldnn::convolution_direct; -#endif +#endif // !ENABLE_MKLDNN_V1 using mkldnn::convolution_forward; using mkldnn::prop_kind; using mkldnn::stream; namespace tensorflow { +#ifdef ENABLE_MKLDNN_V1 +#define MKLDNN_SIZE_DTYPE long int +#else +#define MKLDNN_SIZE_DTYPE int +#endif // ENABLE_MKLDNN_V1 + class MklDnnConvUtil { protected: OpKernelContext* context_; // We don't own this. @@ -138,13 +144,8 @@ class MklDnnConvUtil { CHECK_BOUNDS(input_cols_raw, "Input cols too large"); int input_cols = static_cast(input_cols_raw); -#ifdef ENABLE_MKLDNN_V1 // MKL-DNN always requires input in NCHW format Conv2D. - std::vector mkldnn_sizes(4, -1); -#else - // MKL-DNN always requires input in NCHW format Conv2D. - std::vector mkldnn_sizes(4, -1); -#endif + std::vector mkldnn_sizes(4, -1); mkldnn_sizes[MklDnnDims::Dim_N] = input_batch; mkldnn_sizes[MklDnnDims::Dim_C] = input_depth; mkldnn_sizes[MklDnnDims::Dim_H] = input_rows; @@ -167,13 +168,8 @@ class MklDnnConvUtil { CHECK_BOUNDS(input_cols_raw, "Input cols too large"); int input_cols = static_cast(input_cols_raw); -#ifdef ENABLE_MKLDNN_V1 // MKL-DNN always requires input in NCDHW format for Conv3D. - std::vector mkldnn_sizes(5, -1); -#else - // MKL-DNN always requires input in NCDHW format for Conv3D. 
- std::vector mkldnn_sizes(5, -1); -#endif + std::vector mkldnn_sizes(5, -1); mkldnn_sizes[MklDnnDims3D::Dim3d_N] = input_batch; mkldnn_sizes[MklDnnDims3D::Dim3d_C] = input_depth; mkldnn_sizes[MklDnnDims3D::Dim3d_D] = input_planes; @@ -236,11 +232,7 @@ class MklDnnConvUtil { // GOIHW = (group, out_depth, in_depth, rows, cols) // Specifically for depthwise G=filter_indepth, O=filter_outdepth, I=1 if (is_depthwise) { -#ifdef ENABLE_MKLDNN_V1 - std::vector mkldnn_sizes(5, -1); -#else - std::vector mkldnn_sizes(5, -1); -#endif + std::vector mkldnn_sizes(5, -1); mkldnn_sizes[MKL_GROUP_FILTER_DIM_G] = filter_in_depth; mkldnn_sizes[MKL_GROUP_FILTER_DIM_O] = filter_out_depth; mkldnn_sizes[MKL_GROUP_FILTER_DIM_I] = 1; @@ -249,11 +241,7 @@ class MklDnnConvUtil { *filter_dims = mkldnn_sizes; } else { -#ifdef ENABLE_MKLDNN_V1 - std::vector mkldnn_sizes(4, -1); -#else - std::vector mkldnn_sizes(4, -1); -#endif + std::vector mkldnn_sizes(4, -1); mkldnn_sizes[MklDnnDims::Dim_O] = filter_out_depth; mkldnn_sizes[MklDnnDims::Dim_I] = filter_in_depth; mkldnn_sizes[MklDnnDims::Dim_H] = filter_rows; @@ -279,15 +267,9 @@ class MklDnnConvUtil { int filter_out_depth = static_cast(filter_shape.dim_size(TF_3DFILTER_DIM_O)); -#ifdef ENABLE_MKLDNN_V1 // MKL-DNN always needs filter in OIDHW format. // OIDHW = (out_depth, in_depth, planes, rows, cols) - std::vector mkldnn_sizes(5, -1); -#else - // MKL-DNN always needs filter in OIDHW format. - // OIDHW = (out_depth, in_depth, planes, rows, cols) - std::vector mkldnn_sizes(5, -1); -#endif + std::vector mkldnn_sizes(5, -1); mkldnn_sizes[MklDnnDims3D::Dim3d_O] = filter_out_depth; mkldnn_sizes[MklDnnDims3D::Dim3d_I] = filter_in_depth; mkldnn_sizes[MklDnnDims3D::Dim3d_D] = filter_planes; @@ -479,24 +461,15 @@ class MklDnnConvUtil { *output_dims_tf_order = TFShapeToMklDnnDims(out_shape); if (is_conv2d) { -#ifdef ENABLE_MKLDNN_V1 // For Conv2D, MKL-DNN always needs output in NCHW format. - std::vector mkldnn_sizes(4, -1); -#else - // For Conv2D, MKL-DNN always needs output in NCHW format. - std::vector mkldnn_sizes(4, -1); -#endif + std::vector mkldnn_sizes(4, -1); mkldnn_sizes[MklDnnDims::Dim_N] = out_batch; mkldnn_sizes[MklDnnDims::Dim_C] = out_depth; mkldnn_sizes[MklDnnDims::Dim_H] = static_cast(out_rows); mkldnn_sizes[MklDnnDims::Dim_W] = static_cast(out_cols); *output_dims_mkl_order = mkldnn_sizes; } else { -#ifdef ENABLE_MKLDNN_V1 - std::vector mkldnn_sizes(5, -1); -#else - std::vector mkldnn_sizes(5, -1); -#endif + std::vector mkldnn_sizes(5, -1); mkldnn_sizes[MklDnnDims3D::Dim3d_N] = out_batch; mkldnn_sizes[MklDnnDims3D::Dim3d_C] = out_depth; mkldnn_sizes[MklDnnDims3D::Dim3d_D] = static_cast(out_planes); @@ -658,6 +631,10 @@ class MklDummyOp : public OpKernel { } }; +#ifdef ENABLE_MKLDNN_V1 +#undef MKLDNN_SIZE_DTYPE +#endif // ENABLE_MKLDNN_V1 + } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_MKL_CONV_OPS_H_ From ab4af31785b014036516708c879a2d4ef60d1364 Mon Sep 17 00:00:00 2001 From: Bhavani Subramanian Date: Mon, 29 Jul 2019 16:04:02 -0700 Subject: [PATCH 6/6] Removed ifdef guard for MKLDNN_SIZE_DTYPE since it will always be defined. 
--- tensorflow/core/kernels/mkl_conv_ops.h | 2 -- 1 file changed, 2 deletions(-) diff --git a/tensorflow/core/kernels/mkl_conv_ops.h b/tensorflow/core/kernels/mkl_conv_ops.h index d6f1f2db96a..4e4aaec9d72 100644 --- a/tensorflow/core/kernels/mkl_conv_ops.h +++ b/tensorflow/core/kernels/mkl_conv_ops.h @@ -631,9 +631,7 @@ class MklDummyOp : public OpKernel { } }; -#ifdef ENABLE_MKLDNN_V1 #undef MKLDNN_SIZE_DTYPE -#endif // ENABLE_MKLDNN_V1 } // namespace tensorflow
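
Patches 5 and 6 converge on a single MKLDNN_SIZE_DTYPE for the dimension vectors because the element type expected by memory::dims differs between the two library versions (int in v0.x, long int in v1.0), as the replaced std::vector<int> / std::vector<long int> pairs show. A small sketch of the resulting pattern, with a hypothetical helper name (make_conv2d_input_dims is illustrative and not part of the patch):

#include <vector>

#ifdef ENABLE_MKLDNN_V1
#define MKLDNN_SIZE_DTYPE long int
#else
#define MKLDNN_SIZE_DTYPE int
#endif  // ENABLE_MKLDNN_V1

// Sketch only: build an NCHW-ordered size vector whose element type matches
// memory::dims for the MKL-DNN version being compiled against. Indices 0..3
// correspond to the Dim_N, Dim_C, Dim_H, Dim_W order used by
// GetInputSizeInMklOrder.
std::vector<MKLDNN_SIZE_DTYPE> make_conv2d_input_dims(int batch, int depth,
                                                      int rows, int cols) {
  std::vector<MKLDNN_SIZE_DTYPE> mkldnn_sizes(4, -1);
  mkldnn_sizes[0] = batch;
  mkldnn_sizes[1] = depth;
  mkldnn_sizes[2] = rows;
  mkldnn_sizes[3] = cols;
  return mkldnn_sizes;
}

Defining the alias once in the header and #undef'ing it at the end, as these two patches do, avoids repeating the version #ifdef at every call site that builds a size vector.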