diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc index d5d1a7d6712..cdb841e06d7 100644 --- a/tensorflow/core/graph/mkl_layout_pass.cc +++ b/tensorflow/core/graph/mkl_layout_pass.cc @@ -365,7 +365,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass { rinfo_.push_back({csinfo_.addn, mkl_op_registry::GetMklOpName(csinfo_.addn), CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation}); -#ifndef ENABLE_MKLDNN_V1 rinfo_.push_back({csinfo_.add, mkl_op_registry::GetMklOpName(csinfo_.add), CopyAttrsAll, RewriteIfAtleastOneMklInput, kRewriteForLayoutPropagation}); @@ -403,12 +402,10 @@ class MklLayoutRewritePass : public GraphOptimizationPass { {csinfo_.conjugate_transpose, mkl_op_registry::GetMklOpName(csinfo_.conjugate_transpose), CopyAttrsAll, AlwaysRewrite, kRewriteForOpNameChange}); -#endif // !ENABLE_MKLDNN_V1 rinfo_.push_back({csinfo_.conv2d, mkl_op_registry::GetMklOpName(csinfo_.conv2d), CopyAttrsConvCheckConstFilter, AlwaysRewrite, kRewriteForLayoutPropagation}); -#ifndef ENABLE_MKLDNN_V1 rinfo_.push_back({csinfo_.conv2d_with_bias, csinfo_.mkl_conv2d_with_bias, CopyAttrsConvCheckConstFilter, AlwaysRewrite, kRewriteForLayoutPropagation}); @@ -477,20 +474,16 @@ class MklLayoutRewritePass : public GraphOptimizationPass { {csinfo_.fused_batch_norm_grad_v3, mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm_grad_v3), CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation}); -#endif // !ENABLE_MKLDNN_V1 - rinfo_.push_back({csinfo_.fused_conv2d, csinfo_.mkl_fused_conv2d, CopyAttrsFusedConv2D, FusedConv2DRewrite, kRewriteForLayoutPropagation}); rinfo_.push_back({csinfo_.fused_matmul, csinfo_.mkl_fused_matmul, CopyAttrsAllCheckConstFilter, FusedMatMulRewrite}); -#ifndef ENABLE_MKLDNN_V1 rinfo_.push_back({csinfo_.identity, mkl_op_registry::GetMklOpName(csinfo_.identity), CopyAttrsAll, RewriteIfAtleastOneMklInput, kRewriteForLayoutPropagation}); - rinfo_.push_back({csinfo_.lrn, mkl_op_registry::GetMklOpName(csinfo_.lrn), CopyAttrsAll, LrnRewrite, kRewriteForLayoutPropagation}); rinfo_.push_back( @@ -654,7 +647,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass { mkl_op_registry::GetMklOpName(csinfo_.quantize_v2), CopyAttrsAll, QuantizeOpRewrite, kRewriteForLayoutPropagation}); -#endif // !ENABLE_MKLDNN_V1 rinfo_.push_back({csinfo_.relu, mkl_op_registry::GetMklOpName(csinfo_.relu), CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation}); @@ -667,11 +659,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass { rinfo_.push_back( {csinfo_.relu6_grad, mkl_op_registry::GetMklOpName(csinfo_.relu6_grad), CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation}); -#ifndef ENABLE_MKLDNN_V1 rinfo_.push_back( {csinfo_.requantize, mkl_op_registry::GetMklOpName(csinfo_.requantize), CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation}); -#endif // !ENABLE_MKLDNN_V1 // Disable these two MKL operators for now due to some test failures caused // by these two ops /* @@ -691,7 +681,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass { mkl_op_registry::GetMklOpName(csinfo_.slice), CopyAttrsAll, RewriteIfAtleastOneMklInput, kRewriteForLayoutPropagation}); -#ifndef ENABLE_MKLDNN_V1 rinfo_.push_back( {csinfo_.softmax, mkl_op_registry::GetMklOpName(csinfo_.softmax), CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation}); @@ -798,7 +787,6 @@ class MklLayoutRewritePass : public GraphOptimizationPass { // CheckForMklOp FuseMaxPool3D, CopyAttrsPooling}); -#endif // !ENABLE_MKLDNN_V1 } // Standard interface to run 
pass diff --git a/tensorflow/core/kernels/mkl_avgpooling_op.cc b/tensorflow/core/kernels/mkl_avgpooling_op.cc index 7a993820ad2..e36e481ebbf 100644 --- a/tensorflow/core/kernels/mkl_avgpooling_op.cc +++ b/tensorflow/core/kernels/mkl_avgpooling_op.cc @@ -29,7 +29,9 @@ using mkldnn::algorithm; using mkldnn::engine; using mkldnn::error; using mkldnn::memory; +#ifndef ENABLE_MKLDNN_V1 using mkldnn::padding_kind; +#endif using mkldnn::pooling_backward; using mkldnn::pooling_forward; using mkldnn::prop_kind; @@ -108,11 +110,18 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase { pooling_prop_kind = prop_kind::forward_inference; else pooling_prop_kind = prop_kind::forward_training; +#ifdef ENABLE_MKLDNN_V1 MklPoolingParams fwdParams( src_dims, output_dims_mkl_order, filter_dims, strides, padding_left, padding_right, ALGORITHM::pooling_avg_exclude_padding, - pooling_prop_kind, static_cast(input_md.data.format)); - + pooling_prop_kind, + static_cast(this->data_format_mkldnn_)); +#else + MklPoolingParams fwdParams( + src_dims, output_dims_mkl_order, filter_dims, strides, padding_left, + padding_right, ALGORITHM::pooling_avg_exclude_padding, + pooling_prop_kind, static_cast(input_md.data.format)); +#endif pooling_fwd = MklPoolingFwdPrimitiveFactory::Get(fwdParams); // Allocate output tensor. @@ -224,11 +233,19 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase { // Pass prop_kind::forward_training to create a forward primitive // that is used in the backward pass. +#ifdef ENABLE_MKLDNN_V1 MklPoolingParams bwdParams( orig_input_dims_mkl_order, output_dims_mkl_order, filter_dims, strides, padding_left, padding_right, ALGORITHM::pooling_avg_exclude_padding, prop_kind::forward_training, - static_cast(src_md.data.format)); + static_cast(this->data_format_mkldnn_)); +#else + MklPoolingParams bwdParams( + orig_input_dims_mkl_order, output_dims_mkl_order, filter_dims, + strides, padding_left, padding_right, + ALGORITHM::pooling_avg_exclude_padding, prop_kind::forward_training, + static_cast(src_md.data.format)); +#endif MklPoolingBwdPrimitive* pooling_bwd = MklPoolingBwdPrimitiveFactory::Get(bwdParams); diff --git a/tensorflow/core/kernels/mkl_concat_op.cc b/tensorflow/core/kernels/mkl_concat_op.cc index ee1689b4561..314424e8930 100644 --- a/tensorflow/core/kernels/mkl_concat_op.cc +++ b/tensorflow/core/kernels/mkl_concat_op.cc @@ -18,7 +18,6 @@ limitations under the License. #include #include "mkldnn.hpp" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -32,6 +31,7 @@ limitations under the License. 
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/mkl_types.h" #include "tensorflow/core/util/mkl_util.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" using mkldnn::concat; using mkldnn::stream; @@ -266,7 +266,7 @@ class MklConcatFwdPrimitive : public MklPrimitive { explicit MklConcatFwdPrimitive(const MklConcatFwdParams& concat_fwd_dims, const std::vector& srcs_md) : cpu_engine_(ENGINE_CPU, 0) { - context_.fwd_stream.reset(new CPU_STREAM(stream::kind::eager)); + context_.fwd_stream.reset(new CPU_STREAM(cpu_engine_)); // Create concat primitive Setup(concat_fwd_dims, srcs_md); } @@ -292,8 +292,8 @@ class MklConcatFwdPrimitive : public MklPrimitive { } #ifdef ENABLE_MKLDNN_V1 - execute_primitives(context_.fwd_primitives, *context_.fwd_stream, - context_.fwd_primitives_args.at(i)); + execute_primitives(context_.fwd_primitives, context_.fwd_stream, + context_.fwd_primitives_args); #else context_.fwd_stream->submit(context_.fwd_primitives); #endif // ENABLE_MKLDNN_V1 @@ -328,7 +328,7 @@ class MklConcatFwdPrimitive : public MklPrimitive { std::shared_ptr dst_mem; // Memory descriptor - std::vector> src_md; + std::vector src_md; std::shared_ptr dst_md; // Concat primitive descriptor @@ -339,7 +339,7 @@ class MklConcatFwdPrimitive : public MklPrimitive { std::vector fwd_primitives; #ifdef ENABLE_MKLDNN_V1 - std::vector> fwd_primitive_args; + std::vector> fwd_primitives_args; #endif // ENABLE_MKLDNN_V1 ConcatFwdContext() @@ -355,15 +355,14 @@ class MklConcatFwdPrimitive : public MklPrimitive { const std::vector& srcs_md) { // Create memory descriptors for concat with specified srcs format for (size_t i = 0; i < concat_fwd_dims.num_inputs; i++) { - std::shared_ptr source_md( - new memory::desc(srcs_md[i].data)); + mkldnn::memory::desc source_md(memory::desc(srcs_md[i].data)); context_.src_md.push_back(source_md); #ifdef ENABLE_MKLDNN_V1 std::shared_ptr src_mem( - new mkldnn::memory(*source_md, cpu_engine_, DummyData)); + new mkldnn::memory(source_md, cpu_engine_, DummyData)); #else std::shared_ptr src_mpd( - new memory::primitive_desc(*source_md, cpu_engine_)); + new memory::primitive_desc(source_md, cpu_engine_)); context_.src_pd_shdptr.push_back(src_mpd); std::shared_ptr src_mem( @@ -665,8 +664,9 @@ class MklConcatOp : public OpKernel { if (input_tensors[k].NumElements() == 0) continue; auto src_md = mkl_input_shapes[k].GetMklLayout(); srcs[k].SetUsrMem(src_md, &input_tensors[k]); - - if (src_md.data.format != mkl_common_format) { + auto src_tf_fmt = MklTensorFormatToMklDnnDataFormat( + mkl_input_shapes[k].GetTfDataFormat()); + if (src_tf_fmt != mkl_common_format) { memory::dims src_dims(src_md.data.dims, &src_md.data.dims[src_md.data.ndims]); src_md = @@ -935,7 +935,8 @@ class MklConcatOp : public OpKernel { for (int k = 0; k < input_shapes.size(); k++) { auto src_dims = TFShapeToMklDnnDims(input_shapes[k].GetTfShape()); *concat_dim_size += src_dims[concat_dim]; - int fmt = static_cast(input_shapes[k].GetMklLayout().data.format); + int fmt = static_cast( + MklTensorFormatToMklDnnDataFormat(input_shapes[k].GetTfDataFormat())); occurrence_map[fmt] += 1; } @@ -943,7 +944,7 @@ class MklConcatOp : public OpKernel { // this means that all inputs have a same format // return it with is_reorder_needed set false. return static_cast( - input_shapes[0].GetMklLayout().data.format); + MklTensorFormatToMklDnnDataFormat(input_shapes[0].GetTfDataFormat())); } // Input tensors have different formats. Thus, reorder is needed. 
@@ -970,7 +971,7 @@ class MklConcatOp : public OpKernel { .TypeConstraint("T") \ .HostMemory("concat_dim") \ .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ - MklConcatOp) \ + MklConcatOp); \ REGISTER_KERNEL_BUILDER( \ Name("_MklConcatV2") \ .Device(DEVICE_CPU) \ @@ -978,7 +979,7 @@ class MklConcatOp : public OpKernel { .TypeConstraint("Tidx") \ .HostMemory("axis") \ .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ - MklConcatOp) + MklConcatOp); TF_CALL_float(REGISTER_MKL_CPU); TF_CALL_bfloat16(REGISTER_MKL_CPU); @@ -988,14 +989,14 @@ REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConcatV2") .TypeConstraint("T") .HostMemory("axis") .Label(mkl_op_registry::kMklQuantizedOpLabel), - MklConcatOp) + MklConcatOp); REGISTER_KERNEL_BUILDER(Name("_MklQuantizedConcatV2") .Device(DEVICE_CPU) .TypeConstraint("T") .HostMemory("axis") .Label(mkl_op_registry::kMklQuantizedOpLabel), - MklConcatOp) + MklConcatOp); #undef REGISTER_CONCAT_MKL } // namespace tensorflow diff --git a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc index 71d59fb7971..9ce28c79a28 100644 --- a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc @@ -62,13 +62,19 @@ struct MklConvBwdFilterParams { memory::dims dilations; memory::dims padding_left; memory::dims padding_right; +#ifndef ENABLE_MKLDNN_V1 padding_kind padding; +#endif // !ENABLE_MKLDNN_V1 MklConvBwdFilterParams(memory::dims src_dims, memory::dims diff_filter_dims, memory::dims diff_bias_dims, memory::dims diff_dst_dims, memory::dims strides, memory::dims dilations, memory::dims padding_left, +#ifndef ENABLE_MKLDNN_V1 memory::dims padding_right, padding_kind padding) +#else + memory::dims padding_right) +#endif // !ENABLE_MKLDNN_V1 : src_dims(src_dims), diff_filter_dims(diff_filter_dims), diff_bias_dims(diff_bias_dims), @@ -76,8 +82,14 @@ struct MklConvBwdFilterParams { strides(strides), dilations(dilations), padding_left(padding_left), +#ifndef ENABLE_MKLDNN_V1 padding_right(padding_right), - padding(padding) {} + padding(padding) { + } +#else + padding_right(padding_right) { + } +#endif // !ENABLE_MKLDNN_V1 }; template @@ -241,8 +253,12 @@ class MklConvBwdFilterPrimitive : public MklPrimitive { prop_kind::forward, ALGORITHM::convolution_direct, *context_.src_md, *context_.diff_filter_md, *context_.diff_dst_md, convBwdFilterDims.strides, convBwdFilterDims.dilations, +#ifndef ENABLE_MKLDNN_V1 convBwdFilterDims.padding_left, convBwdFilterDims.padding_right, convBwdFilterDims.padding)); +#else + convBwdFilterDims.padding_left, convBwdFilterDims.padding_right)); +#endif // !ENABLE_MKLDNN_V1 context_.fwd_pd.reset(new ConvFwdPd(*context_.fwd_desc, cpu_engine_)); // Create descriptor and primitive descriptor for convolution bwd filter. 
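The #ifdef blocks added throughout this file (and in the conv input and forward kernels below) stem from a single API difference: DNNL 1.x convolution descriptors no longer take a trailing padding_kind argument, since zero padding is implied. A minimal, self-contained sketch of the same pattern, using stand-in types rather than the real mkldnn/TensorFlow ones:

#include <vector>

using Dims = std::vector<int>;
enum class PaddingKind { zero };  // stand-in for mkldnn::padding_kind

struct ConvParamsSketch {
  Dims strides, padding_left, padding_right;
#ifndef ENABLE_MKLDNN_V1
  // MKL-DNN 0.x: the padding kind travels with the params and is forwarded
  // to the convolution descriptor.
  PaddingKind padding;
  ConvParamsSketch(Dims s, Dims l, Dims r, PaddingKind p)
      : strides(s), padding_left(l), padding_right(r), padding(p) {}
#else
  // DNNL 1.x: no padding kind; zero padding is implied by the descriptor.
  ConvParamsSketch(Dims s, Dims l, Dims r)
      : strides(s), padding_left(l), padding_right(r) {}
#endif  // !ENABLE_MKLDNN_V1
};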
@@ -252,14 +268,22 @@ class MklConvBwdFilterPrimitive : public MklPrimitive { *context_.diff_filter_md, *context_.diff_bias_md, *context_.diff_dst_md, convBwdFilterDims.strides, convBwdFilterDims.dilations, convBwdFilterDims.padding_left, +#ifndef ENABLE_MKLDNN_V1 convBwdFilterDims.padding_right, convBwdFilterDims.padding)); +#else + convBwdFilterDims.padding_right)); +#endif // !ENABLE_MKLDNN_V1 } else { context_.bwd_filter_desc.reset(new ConvBwdFilterDesc( ALGORITHM::convolution_direct, *context_.src_md, *context_.diff_filter_md, *context_.diff_dst_md, convBwdFilterDims.strides, convBwdFilterDims.dilations, +#ifndef ENABLE_MKLDNN_V1 convBwdFilterDims.padding_left, convBwdFilterDims.padding_right, convBwdFilterDims.padding)); +#else + convBwdFilterDims.padding_left, convBwdFilterDims.padding_right)); +#endif // !ENABLE_MKLDNN_V1 } context_.bwd_filter_pd.reset(new ConvBwdFilterPd( *context_.bwd_filter_desc, cpu_engine_, *context_.fwd_pd)); @@ -495,11 +519,14 @@ class MklConvCustomBackpropFilterOp // The default dilation factor for each dimension is 1 in TF and // 0 in MKL-DNN. for (int i = 0; i < dilations.size(); ++i) --dilations[i]; - MklConvBwdFilterParams convBwdFilterDims( fwd_src_dims, fwd_filter_dims, diff_bias_dims, diff_dst_dims, strides, +#ifndef ENABLE_MKLDNN_V1 dilations, padding_left, padding_right, TFPaddingToMklDnnPadding(this->padding_)); +#else + dilations, padding_left, padding_right); +#endif // !ENABLE_MKLDNN_V1 // MKL-DNN allocates large buffers when a conv gradient filter primtive is // created. So we don't cache conv backward primitives when the env diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc index cf0683ffd4c..97dde74ea8f 100644 --- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc @@ -65,20 +65,32 @@ struct MklConvBwdInputParams { memory::dims dilations; memory::dims padding_left; memory::dims padding_right; +#ifndef ENABLE_MKLDNN_V1 padding_kind padding; +#endif // !ENABLE_MKLDNN_V1 MklConvBwdInputParams(memory::dims diff_src_dims, memory::dims filter_dims, memory::dims diff_dst_dims, memory::dims strides, memory::dims dilations, memory::dims padding_left, +#ifndef ENABLE_MKLDNN_V1 memory::dims padding_right, padding_kind padding) +#else + memory::dims padding_right) +#endif // !ENABLE_MKLDNN_V1 : diff_src_dims(diff_src_dims), filter_dims(filter_dims), diff_dst_dims(diff_dst_dims), strides(strides), dilations(dilations), padding_left(padding_left), +#ifndef ENABLE_MKLDNN_V1 padding_right(padding_right), - padding(padding) {} + padding(padding) { + } +#else + padding_right(padding_right) { + } +#endif // !ENABLE_MKLDNN_V1 }; template @@ -211,14 +223,22 @@ class MklConvBwdInputPrimitive : public MklPrimitive { ALGORITHM::convolution_direct, *context_.diff_src_md, *context_.filter_md, *context_.diff_dst_md, convBwdInputDims.strides, convBwdInputDims.dilations, convBwdInputDims.padding_left, +#ifndef ENABLE_MKLDNN_V1 convBwdInputDims.padding_right, convBwdInputDims.padding)); +#else + convBwdInputDims.padding_right)); +#endif // !ENABLE_MKLDNN_V1 context_.fwd_desc.reset(new ConvFwdDesc( prop_kind::forward, ALGORITHM::convolution_direct, *context_.diff_src_md, *context_.filter_md, *context_.diff_dst_md, convBwdInputDims.strides, convBwdInputDims.dilations, +#ifndef ENABLE_MKLDNN_V1 convBwdInputDims.padding_left, convBwdInputDims.padding_right, convBwdInputDims.padding)); +#else + convBwdInputDims.padding_left, 
convBwdInputDims.padding_right)); +#endif // !ENABLE_MKLDNN_V1 // Create primitive descriptors for conv fwd and conv bwd input. context_.fwd_pd.reset(new ConvFwdPd(*context_.fwd_desc, cpu_engine_)); @@ -440,11 +460,14 @@ class MklConvCustomBackpropInputOp // The default dilation factor for each dimension is 1 in TF and // 0 in MKL-DNN. for (int i = 0; i < dilations.size(); ++i) --dilations[i]; - MklConvBwdInputParams convBwdInputDims( fwd_src_dims, fwd_filter_dims, diff_dst_dims, strides, dilations, +#ifndef ENABLE_MKLDNN_V1 padding_left, padding_right, TFPaddingToMklDnnPadding(this->padding_)); +#else + padding_left, padding_right); +#endif // !ENABLE_MKLDNN_V1 // We don't cache those primitives if the environment variable // TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE is true and if primitve descriptor diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc index a7cc178bf86..00ec871f582 100644 --- a/tensorflow/core/kernels/mkl_conv_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_ops.cc @@ -24,8 +24,8 @@ limitations under the License. #include #include -#include "mkldnn.hpp" #include "absl/strings/str_join.h" +#include "mkldnn.hpp" #include "tensorflow/core/framework/bounds_check.h" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" @@ -239,13 +239,21 @@ class MklConvFwdPrimitive : public MklPrimitive { prop_kind::forward, ALGORITHM::convolution_direct, *context_.src_md, *context_.filter_md, *context_.bias_md, *context_.dst_md, convFwdDims.strides, convFwdDims.dilations, convFwdDims.padding_left, +#ifndef ENABLE_MKLDNN_V1 convFwdDims.padding_right, padding_kind::zero)); +#else + convFwdDims.padding_right)); +#endif // !ENABLE_MKLDNN_V1 } else { context_.fwd_desc.reset(new convolution_forward::desc( prop_kind::forward, ALGORITHM::convolution_direct, *context_.src_md, *context_.filter_md, *context_.dst_md, convFwdDims.strides, convFwdDims.dilations, convFwdDims.padding_left, +#ifndef ENABLE_MKLDNN_V1 convFwdDims.padding_right, padding_kind::zero)); +#else + convFwdDims.padding_right)); +#endif // !ENABLE_MKLDNN_V1 } context_.fwd_pd.reset(new ConvFwdPd(*context_.fwd_desc, cpu_engine_)); @@ -261,12 +269,7 @@ class MklConvFwdPrimitive : public MklPrimitive { float op_scale = post_op_param.param[0]; float op_alpha = post_op_param.param[1]; float op_beta = post_op_param.param[2]; -#ifdef ENABLE_MKLDNN_V1 - post_ops.append_eltwise(op_scale, mkldnn::algorithm::eltwise_relu, - op_alpha, -#else post_ops.append_eltwise(op_scale, post_op_param.alg, op_alpha, -#endif // ENABLE_MKLDNN_V1 op_beta); } else if (post_op_param.name == "sum") { DCHECK_EQ(post_op_param.param.size(), 1); @@ -1033,6 +1036,7 @@ class MklConvOp : public OpKernel { &cached_filter_data_ptensor_, filter_tensor)); Tensor* second_tensor = nullptr; +#ifndef ENABLE_MKLDNN_V1 TensorShape filter_mkl_format; filter_mkl_format.AddDim( sizeof(GetFilterTfDataFormat(filter_mkl_shape, conv_prim_desc)) / @@ -1042,6 +1046,21 @@ class MklConvOp : public OpKernel { &cached_filter_md_ptensor_, &second_tensor)); second_tensor->scalar()() = static_cast( GetFilterTfDataFormat(filter_mkl_shape, conv_prim_desc)); +#else + // There is no tensor format in DNNL 1.x. So we cache the complete filter + // descriptor as flat byte array. + TensorShape cached_filter_md_shape; + memory::desc weights_desc = conv_prim_desc.weights_desc(); + // We don't use .get_size() method of memory::desc since it returns size + // required to store primitive's input memory. 
It is much more than size of + // memory::desc itself. + cached_filter_md_shape.AddDim(sizeof(weights_desc) / sizeof(uint8)); + OP_REQUIRES_OK(context, context->allocate_persistent( + DT_UINT8, cached_filter_md_shape, + &cached_filter_md_ptensor_, &second_tensor)); + *reinterpret_cast(second_tensor->flat().data()) = + weights_desc; +#endif // !ENABLE_MKLDNN_V1 } void AllocatePersistentTensor(OpKernelContext* context, @@ -1230,12 +1249,11 @@ class MklConvOp : public OpKernel { const Tensor& cached_filter_md = *cached_filter_md_ptensor_.AccessTensor(context); -// Check if the memory descriptor of the cached weights is same as -// filter_md. If so, we can use the cached weights; otherwise -// return nullptr. + // Check if the memory descriptor of the cached weights is same as + // filter_md. If so, we can use the cached weights; otherwise + // return nullptr. #ifdef ENABLE_MKLDNN_V1 - if (cached_filter_md.scalar().size() && - AreMemoryDescriptorsEqual(filter_md, cached_filter_md)) { + if (filter_md == *static_cast(cached_filter_md.data())) { #else if (cached_filter_md.scalar().size() && cached_filter_md.scalar()() == filter_md) { @@ -1568,7 +1586,7 @@ class MklQuantizedConv2DOp if (!scaled_bias_buf_) AllocTmpBuffer(context, &scaled_bias_tensor_, - conv_fwd_pd->bias_primitive_desc(), + GET_BIAS_DESC_FROM_OP_PD(conv_fwd_pd), &scaled_bias_buf_); if (!scaled_bias_) { scaled_bias_ = new MEMORY_CONSTRUCTOR(bias_md, this->cpu_engine_, diff --git a/tensorflow/core/kernels/mkl_dequantize_op.cc b/tensorflow/core/kernels/mkl_dequantize_op.cc index ca0632b7f12..8737581c726 100644 --- a/tensorflow/core/kernels/mkl_dequantize_op.cc +++ b/tensorflow/core/kernels/mkl_dequantize_op.cc @@ -89,12 +89,26 @@ class MklDequantizeOp : public OpKernel { Tensor* output_tensor = nullptr; MklDnnShape output_mkl_shape; TensorShape output_tf_shape; +#ifndef ENABLE_MKLDNN_V1 memory::desc dst_md = src_mkl_shape.IsMklTensor() ? memory::desc(src_dims, MklDnnType(), - static_cast(src_md.data.format)) + static_cast(src_md.data.format)) : memory::desc(src_dims, MklDnnType(), MEMORY_FORMAT::nhwc); +#else + memory::desc dst_md = memory::desc(); + if (src_mkl_shape.IsMklTensor()) { + dst_md = memory::desc(src_mkl_shape.GetMklLayout().data); + // There is no API in MKL-DNN v1.x to construct memory descriptor with + // same .data field but different type. + dst_md.data.data_type = memory::convert_to_c(MklDnnType()); + } else { + dst_md = + memory::desc(src_dims, MklDnnType(), MEMORY_FORMAT::nhwc); + } +#endif // !ENABLE_MKLDNN_V1 + // If input is MKL shape, output is also MKL shape. // If input is TF shape, output is also TF shape. if (src_mkl_shape.IsMklTensor()) { diff --git a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc index 6427f805874..30023f360cc 100644 --- a/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc +++ b/tensorflow/core/kernels/mkl_fused_batch_norm_op.cc @@ -14,20 +14,25 @@ limitations under the License. 
==============================================================================*/ #ifdef INTEL_MKL #include "mkldnn.hpp" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/util/mkl_types.h" #include "tensorflow/core/util/mkl_util.h" #include "tensorflow/core/util/tensor_format.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +#define GET_FLAG(bn_flag) static_cast(BN_FLAGS::bn_flag) +#define IS_SET(cflag) (context_.flags & GET_FLAG(cflag)) using mkldnn::batch_normalization_backward; using mkldnn::batch_normalization_forward; using mkldnn::prop_kind; using mkldnn::stream; -using mkldnn::use_global_stats; -using mkldnn::use_scale_shift; + +using BatchNormFwdPd = mkldnn::batch_normalization_forward::primitive_desc; +using BatchNormBwdPd = mkldnn::batch_normalization_backward::primitive_desc; namespace tensorflow { using CPUDevice = Eigen::ThreadPoolDevice; @@ -37,23 +42,40 @@ struct MklBatchNormFwdParams { int depth; float eps; bool training; - memory::format src_format; +#ifndef ENABLE_MKLDNN_V1 + MEMORY_FORMAT src_format; +#else + memory::desc src_md; +#endif // !ENABLE_MKLDNN_V1 MklBatchNormFwdParams(const memory::dims& src_dims, int depth, float eps, - bool training, memory::format src_format) +#ifndef ENABLE_MKLDNN_V1 + bool training, MEMORY_FORMAT src_format) : src_dims(src_dims), depth(depth), eps(eps), training(training), - src_format(src_format) {} + src_format(src_format) { + } +#else + bool training, memory::desc src_md) + : src_dims(src_dims), + depth(depth), + eps(eps), + training(training), + src_md(src_md) { + } +#endif // !ENABLE_MKLDNN_V1 }; template class MklFusedBatchNormFwdPrimitive : public MklPrimitive { public: explicit MklFusedBatchNormFwdPrimitive(const MklBatchNormFwdParams& fwdParams) - : cpu_engine_(engine::cpu, 0) { -#ifndef ENABLE_MKLDNN_V1 + : cpu_engine_(ENGINE_CPU, 0) { +#ifdef ENABLE_MKLDNN_V1 + context_.fwd_stream.reset(new CPU_STREAM(cpu_engine_)); +#else context_.fwd_stream.reset( new mkldnn::stream(mkldnn::stream::kind::eager_nostore)); #endif @@ -74,68 +96,85 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive { static_cast(const_cast(src_data))); context_.dst_mem->set_data_handle(static_cast(dst_data)); - if (context_.flags & use_scale_shift) + if (IS_SET(use_scale_shift)) context_.weights_mem->set_data_handle( static_cast(const_cast(weights_data))); if ((context_.pkind == prop_kind::forward_training) || - (context_.flags & use_global_stats)) { + (IS_SET(use_global_stats))) { context_.mean_mem->set_data_handle(static_cast(mean_data)); context_.variance_mem->set_data_handle(static_cast(variance_data)); } - - // execution +#ifdef ENABLE_MKLDNN_V1 + // Execute batch-normalization forward primitives. 
+ execute_primitives(context_.fwd_primitives, context_.fwd_stream, + context_.net_args); +#else context_.fwd_stream->submit(context_.fwd_primitives); +#endif // ENABLE_MKLDNN_V1 context_.src_mem->set_data_handle(DummyData); context_.dst_mem->set_data_handle(DummyData); - if (context_.flags & use_scale_shift) + if (IS_SET(use_scale_shift)) context_.weights_mem->set_data_handle(DummyData); if ((context_.pkind == prop_kind::forward_training) || - (context_.flags & use_global_stats)) { + (IS_SET(use_global_stats))) { context_.mean_mem->set_data_handle(DummyData); context_.variance_mem->set_data_handle(DummyData); } } - memory::primitive_desc GetDstPd() const { - return (*context_.dst_mem).get_primitive_desc(); - } + MEMORY_PRIMITIVE_DESC GetDstPd() const { return context_.dst_mem->GET_DESC; } - mkldnn_memory_format_t GetSrcFmt() const { - return (*context_.src_mem).get_primitive_desc().desc().data.format; +#ifndef ENABLE_MKLDNN_V1 + // In MKL-DNN v1.x, memory format tags only provide a partial description + // of the memory layout. Hence, these functions are disabled for v1.x. + mkldnn_memory_format_t GetSrcMemoryFormat() const { + return context_.src_mem->get_primitive_desc().desc().data.format; } mkldnn_memory_format_t GetDstFmt() const { return (*context_.dst_mem).get_primitive_desc().desc().data.format; } +#endif // !ENABLE_MKLDNN_V1 + + std::shared_ptr GetBatchNormFwdPd() const { + return context_.fwd_pd; + } private: - // Primitive reuse context for BatchNorm fwd op + // Primitive reuse context for BatchNorm forward op. struct BatchNormFwdContext { - // flags indict if it is training or inference mode + // Flags indicating if it is training or inference mode. int64 flags; - // algorithm + // Algorithm kind. mkldnn::prop_kind pkind; - // Mkldnn Memory + // Inputs/outputs memory. std::shared_ptr src_mem; std::shared_ptr weights_mem; std::shared_ptr dst_mem; std::shared_ptr mean_mem; std::shared_ptr variance_mem; - // BatchNorm forward primitive + // Forward BatchNorm primitive descriptor. + std::shared_ptr fwd_pd; + + // BatchNorm forward primitive. std::shared_ptr bn_fwd; std::shared_ptr fwd_stream; std::vector fwd_primitives; +#ifdef ENABLE_MKLDNN_V1 + std::vector> net_args; +#endif // ENABLE_MKLDNN_V1 + BatchNormFwdContext() : flags(0), - pkind(mkldnn::forward_training), + pkind(prop_kind::forward_training), src_mem(nullptr), weights_mem(nullptr), dst_mem(nullptr), @@ -146,83 +185,143 @@ class MklFusedBatchNormFwdPrimitive : public MklPrimitive { }; void Setup(const MklBatchNormFwdParams& fwdParams) { - context_.flags = fwdParams.training ? use_scale_shift - : (use_scale_shift | use_global_stats); + context_.flags = + fwdParams.training + ? GET_FLAG(use_scale_shift) + : (GET_FLAG(use_scale_shift) | GET_FLAG(use_global_stats)); context_.pkind = fwdParams.training ? prop_kind::forward_training : prop_kind::forward_scoring; - // memory desc +#ifdef ENABLE_MKLDNN_V1 + // Memory descriptor + auto src_md = fwdParams.src_md; + // Create forward BatchNorm descriptor and primitive descriptor. 
+ auto fwd_desc = batch_normalization_forward::desc( + context_.pkind, src_md, fwdParams.eps, + static_cast(context_.flags)); +#else + // Memory descriptor auto src_md = memory::desc({fwdParams.src_dims}, MklDnnType(), fwdParams.src_format); - - // fwd desc & primitive desc auto fwd_desc = batch_normalization_forward::desc( context_.pkind, src_md, fwdParams.eps, context_.flags); - auto fwd_pd = - batch_normalization_forward::primitive_desc(fwd_desc, cpu_engine_); +#endif // ENABLE_MKLDNN_V1 - // memory primitive - context_.src_mem.reset(new memory({src_md, cpu_engine_}, DummyData)); - context_.dst_mem.reset(new memory(fwd_pd.dst_primitive_desc(), DummyData)); + context_.fwd_pd.reset(new BatchNormFwdPd(fwd_desc, cpu_engine_)); - if (context_.flags & use_scale_shift) { - auto weights_desc = memory::desc({2, fwdParams.depth}, MklDnnType(), - memory::format::nc); - context_.weights_mem.reset( - new memory({weights_desc, cpu_engine_}, DummyData)); + // Create memory primitive based on dummy data + context_.src_mem.reset(new MEMORY_CONSTRUCTOR( + context_.fwd_pd->PRIMITIVE_DESC_SRC, cpu_engine_, DummyData)); + context_.dst_mem.reset(new MEMORY_CONSTRUCTOR( + context_.fwd_pd->PRIMITIVE_DESC_DST, cpu_engine_, DummyData)); + + memory::dims s_dims = {2, fwdParams.depth}; + memory::dims m_dims = {1, fwdParams.depth}; + if (IS_SET(use_scale_shift)) { + context_.weights_mem.reset(new MEMORY_CONSTRUCTOR_USING_MEM_PD( + s_dims, U, MEMORY_FORMAT::nc, cpu_engine_, DummyData)); } - if (fwdParams.training || (context_.flags & use_global_stats)) { - auto mean_desc = memory::desc({1, fwdParams.depth}, MklDnnType(), - memory::format::nc); - context_.mean_mem.reset(new memory({mean_desc, cpu_engine_}, DummyData)); + if (fwdParams.training || (IS_SET(use_global_stats))) { + context_.mean_mem.reset(new MEMORY_CONSTRUCTOR_USING_MEM_PD( + m_dims, U, MEMORY_FORMAT::nc, cpu_engine_, DummyData)); - auto variance_desc = - memory::desc({1, fwdParams.depth}, MklDnnType(), memory::nc); - context_.variance_mem.reset( - new memory({variance_desc, cpu_engine_}, DummyData)); + context_.variance_mem.reset(new MEMORY_CONSTRUCTOR_USING_MEM_PD( + m_dims, U, MEMORY_FORMAT::nc, cpu_engine_, DummyData)); } - // BatchNorm forward primitive - if (!fwdParams.training && !(context_.flags & use_global_stats)) { - if ((context_.flags & use_scale_shift) && mkldnn_use_scaleshift) { + // BatchNorm forward primitive. 
+ if (!fwdParams.training && !(IS_SET(use_global_stats))) { +#ifdef ENABLE_MKLDNN_V1 + if ((IS_SET(use_scale_shift)) && mkldnn_use_scaleshift) { + context_.net_args.push_back( + {{MKLDNN_ARG_SRC, *context_.src_mem}, + {MKLDNN_ARG_WEIGHTS, *context_.weights_mem}, + { MKLDNN_ARG_DST, + *context_.dst_mem }}); + } else { + context_.net_args.push_back({{MKLDNN_ARG_SRC, *context_.src_mem}, + { MKLDNN_ARG_DST, + *context_.dst_mem }}); + } + context_.bn_fwd.reset(new batch_normalization_forward(*context_.fwd_pd)); +#else + if ((IS_SET(use_scale_shift)) && GET_FLAG(use_scale_shift)) { context_.bn_fwd.reset(new batch_normalization_forward( - fwd_pd, *context_.src_mem, *context_.weights_mem, + *context_.fwd_pd, *context_.src_mem, *context_.weights_mem, *context_.dst_mem)); } else { context_.bn_fwd.reset(new batch_normalization_forward( - fwd_pd, *context_.src_mem, *context_.dst_mem)); + *context_.fwd_pd, *context_.src_mem, *context_.dst_mem)); } - } else if (context_.flags & use_global_stats) { - if ((context_.flags & use_scale_shift) && mkldnn_use_scaleshift) { +#endif // ENABLE_MKLDNN_V1 + } else if (IS_SET(use_global_stats)) { +#ifdef ENABLE_MKLDNN_V1 + if ((IS_SET(use_scale_shift)) && GET_FLAG(use_scale_shift)) { + context_.net_args.push_back( + {{MKLDNN_ARG_SRC, *context_.src_mem}, + {MKLDNN_ARG_MEAN, *context_.mean_mem}, + {MKLDNN_ARG_VARIANCE, *context_.variance_mem}, + {MKLDNN_ARG_WEIGHTS, *context_.weights_mem}, + { MKLDNN_ARG_DST, + *context_.dst_mem }}); + } else { + context_.net_args.push_back( + {{MKLDNN_ARG_SRC, *context_.src_mem}, + {MKLDNN_ARG_MEAN, *context_.mean_mem}, + {MKLDNN_ARG_VARIANCE, *context_.variance_mem}, + { MKLDNN_ARG_DST, + *context_.dst_mem }}); + } + context_.bn_fwd.reset(new batch_normalization_forward(*context_.fwd_pd)); +#else + if ((IS_SET(use_scale_shift)) && GET_FLAG(use_scale_shift)) { context_.bn_fwd.reset(new batch_normalization_forward( - fwd_pd, *context_.src_mem, (const primitive::at)*context_.mean_mem, + *context_.fwd_pd, *context_.src_mem, + (const primitive::at)*context_.mean_mem, (const primitive::at)*context_.variance_mem, *context_.weights_mem, *context_.dst_mem)); } else { context_.bn_fwd.reset(new batch_normalization_forward( - fwd_pd, *context_.src_mem, (const primitive::at)*context_.mean_mem, + *context_.fwd_pd, *context_.src_mem, + (const primitive::at)*context_.mean_mem, (const primitive::at)*context_.variance_mem, *context_.dst_mem)); } +#endif // ENABLE_MKLDNN_V1 } else { - if ((context_.flags & use_scale_shift) && mkldnn_use_scaleshift) { +#ifdef ENABLE_MKLDNN_V1 + if ((IS_SET(use_scale_shift)) && GET_FLAG(use_scale_shift)) { + context_.net_args.push_back( + {{MKLDNN_ARG_SRC, *context_.src_mem}, + {MKLDNN_ARG_WEIGHTS, *context_.weights_mem}, + {MKLDNN_ARG_DST, *context_.dst_mem}, + {MKLDNN_ARG_MEAN, *context_.mean_mem}, + { MKLDNN_ARG_VARIANCE, + *context_.variance_mem }}); + } else { + context_.net_args.push_back({{MKLDNN_ARG_SRC, *context_.src_mem}, + {MKLDNN_ARG_DST, *context_.dst_mem}, + {MKLDNN_ARG_MEAN, *context_.mean_mem}, + { MKLDNN_ARG_VARIANCE, + *context_.variance_mem }}); + } + context_.bn_fwd.reset(new batch_normalization_forward(*context_.fwd_pd)); +#else + if ((IS_SET(use_scale_shift)) && GET_FLAG(use_scale_shift)) { context_.bn_fwd.reset(new batch_normalization_forward( - fwd_pd, *context_.src_mem, *context_.weights_mem, *context_.dst_mem, - *context_.mean_mem, *context_.variance_mem)); + *context_.fwd_pd, *context_.src_mem, *context_.weights_mem, + *context_.dst_mem, *context_.mean_mem, *context_.variance_mem)); } else { 
context_.bn_fwd.reset(new batch_normalization_forward( - fwd_pd, *context_.src_mem, *context_.dst_mem, *context_.mean_mem, - *context_.variance_mem)); + *context_.fwd_pd, *context_.src_mem, *context_.dst_mem, + *context_.mean_mem, *context_.variance_mem)); } +#endif // ENABLE_MKLDNN_V1 } context_.fwd_primitives.push_back(*context_.bn_fwd); } - mkldnn::memory::desc get_desc_data(const mkldnn::memory& m) const { - return m.get_primitive_desc().desc().data; - } - struct BatchNormFwdContext context_; engine cpu_engine_; }; @@ -284,25 +383,45 @@ struct MklBatchNormBwdParams { int depth; float eps; bool training; - memory::format src_format; +#ifndef ENABLE_MKLDNN_V1 + MEMORY_FORMAT src_format; +#else + memory::desc src_md; + memory::desc diff_dst_md; +#endif // !ENABLE_MKLDNN_V1 MklBatchNormBwdParams(memory::dims src_dims, memory::dims diff_dst_dims, int depth, float eps, bool training, - memory::format src_format) +#ifndef ENABLE_MKLDNN_V1 + MEMORY_FORMAT src_format) : src_dims(src_dims), diff_dst_dims(diff_dst_dims), depth(depth), eps(eps), training(training), - src_format(src_format) {} + src_format(src_format) { + } +#else + memory::desc src_md, memory::desc diff_dst_md) + : src_dims(src_dims), + diff_dst_dims(diff_dst_dims), + depth(depth), + eps(eps), + training(training), + src_md(src_md), + diff_dst_md(diff_dst_md) { + } +#endif // !ENABLE_MKLDNN_V1 }; template class MklFusedBatchNormBwdPrimitive : public MklPrimitive { public: explicit MklFusedBatchNormBwdPrimitive(const MklBatchNormBwdParams& bwdParams) - : cpu_engine_(engine::cpu, 0) { -#ifndef ENABLE_MKLDNN_V1 + : cpu_engine_(ENGINE_CPU, 0) { +#ifdef ENABLE_MKLDNN_V1 + context_.bwd_stream.reset(new CPU_STREAM(cpu_engine_)); +#else context_.bwd_stream.reset( new mkldnn::stream(mkldnn::stream::kind::eager_nostore)); #endif @@ -335,8 +454,7 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive { context_.diff_dst_mem->set_data_handle( static_cast(const_cast(diff_dst_data))); - // TODO: type for weights? - if (context_.flags & use_scale_shift) { + if (IS_SET(use_scale_shift)) { context_.weights_mem->set_data_handle( static_cast(const_cast(weights_data))); context_.diff_weights_mem->set_data_handle( @@ -345,38 +463,53 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive { context_.diff_src_mem->set_data_handle(static_cast(diff_src_data)); - // execution +#ifdef ENABLE_MKLDNN_V1 + // Execute backward batch-normalization primitives. + DCHECK_EQ(context_.bwd_primitives.size(), context_.net_args.size()); + for (size_t i = 0; i < context_.bwd_primitives.size(); ++i) { + context_.bwd_primitives.at(i).execute(*context_.bwd_stream, + context_.net_args.at(i)); + } +#else context_.bwd_stream->submit(context_.bwd_primitives); +#endif // ENABLE_MKLDNN_V1 + // After execution, set data handle back to DummyData. 
context_.src_mem->set_data_handle(DummyData); context_.mean_mem->set_data_handle(DummyData); context_.variance_mem->set_data_handle(DummyData); context_.diff_dst_mem->set_data_handle(DummyData); - if (context_.flags & use_scale_shift) { + if (IS_SET(use_scale_shift)) { context_.weights_mem->set_data_handle(DummyData); context_.diff_weights_mem->set_data_handle(DummyData); } context_.diff_src_mem->set_data_handle(DummyData); } - mkldnn_memory_format_t GetSrcFmt() { - return (*context_.src_mem).get_primitive_desc().desc().data.format; +#ifndef ENABLE_MKLDNN_V1 + mkldnn_memory_format_t GetSrcMemoryFormat() const { + return context_.src_mem->get_primitive_desc().desc().data.format; } - mkldnn_memory_format_t GetDiffDstFmt() { - return (*context_.diff_dst_mem).get_primitive_desc().desc().data.format; + mkldnn_memory_format_t GetDiffDstMemoryFormat() const { + return context_.diff_dst_mem->get_primitive_desc().desc().data.format; + } +#endif // !ENABLE_MKLDNN_V1 + + std::shared_ptr GetBatchNormBwdPd() const { + return context_.bwd_pd; } - memory::primitive_desc GetDiffSrcPd() { - return (*context_.diff_src_mem).get_primitive_desc(); + MEMORY_PRIMITIVE_DESC GetDiffSrcPd() { + return GET_MEMORY_PRIMITIVE_DESC_FROM_MEM_PTR(context_.diff_src_mem); } private: struct BatchNormBwdContext { - // Flags to indicate whether it is training or inference + // Flags to indicate whether it is training or inference. int64 flags; - // MKLDNN memory + // Inputs/output memory. std::shared_ptr src_mem; std::shared_ptr mean_mem; std::shared_ptr variance_mem; @@ -385,11 +518,18 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive { std::shared_ptr diff_weights_mem; std::shared_ptr diff_src_mem; - // Batch Norm primitive + // Backward batch-normalization primitive descriptor. + std::shared_ptr bwd_pd; + + // Backward batch-normalization primitive. std::shared_ptr bn_bwd; std::vector bwd_primitives; std::shared_ptr bwd_stream; +#ifdef ENABLE_MKLDNN_V1 + std::vector> net_args; +#endif // ENABLE_MKLDNN_V1 + BatchNormBwdContext() : src_mem(nullptr), mean_mem(nullptr), @@ -402,60 +542,80 @@ class MklFusedBatchNormBwdPrimitive : public MklPrimitive { }; void Setup(const MklBatchNormBwdParams& bwdParams) { - context_.flags = bwdParams.training ? use_scale_shift - : (use_scale_shift | use_global_stats); + context_.flags = + bwdParams.training + ? GET_FLAG(use_scale_shift) + : (GET_FLAG(use_scale_shift) | GET_FLAG(use_global_stats)); - // memory desc + // Memory descriptors. +#ifndef ENABLE_MKLDNN_V1 auto src_md = memory::desc({bwdParams.src_dims}, MklDnnType(), bwdParams.src_format); auto diff_dst_md = memory::desc({bwdParams.diff_dst_dims}, MklDnnType(), bwdParams.src_format); +#else + auto src_md = bwdParams.src_md; + auto diff_dst_md = bwdParams.diff_dst_md; +#endif // !ENABLE_MKLDNN_V1 auto variance_desc = - memory::desc({1, bwdParams.depth}, MklDnnType(), memory::nc); + memory::desc({1, bwdParams.depth}, MklDnnType(), MEMORY_FORMAT::nc); auto mean_desc = - memory::desc({1, bwdParams.depth}, MklDnnType(), memory::format::nc); + memory::desc({1, bwdParams.depth}, MklDnnType(), MEMORY_FORMAT::nc); auto weights_desc = - memory::desc({2, bwdParams.depth}, MklDnnType(), memory::format::nc); + memory::desc({2, bwdParams.depth}, MklDnnType(), MEMORY_FORMAT::nc); auto diff_weights_desc = weights_desc; - // fwd desc & primitive desc + // Forward batch-normalization descriptor and primitive descriptor. + // Adding this back due to type difference with context.flags + auto bn_flags = + bwdParams.training + ? 
BN_FLAGS::use_scale_shift + : (BN_FLAGS::use_scale_shift | BN_FLAGS::use_global_stats); auto fwd_desc = batch_normalization_forward::desc( - prop_kind::forward_training, src_md, bwdParams.eps, - bwdParams.training ? use_scale_shift - : (use_scale_shift | use_global_stats)); - auto fwd_pd = - batch_normalization_forward::primitive_desc(fwd_desc, cpu_engine_); + prop_kind::forward_training, src_md, bwdParams.eps, bn_flags); + auto fwd_pd = BatchNormFwdPd(fwd_desc, cpu_engine_); - // BatchNorm backward primtive - // + // Backward batch-normalization primitive. // For inference, specify use_global_stats // 1. on fwd propagation, use mean and variance provided as inputs. // 2. on bwd propagation, mean and variance are considered as constants. // Thus, reduce the amount of MKL computation. auto bwd_desc = batch_normalization_backward::desc( - prop_kind::backward, diff_dst_md, src_md, bwdParams.eps, - bwdParams.training ? use_scale_shift - : (use_scale_shift | use_global_stats)); - auto bn_bwd_pd = batch_normalization_backward::primitive_desc( - bwd_desc, cpu_engine_, fwd_pd); + prop_kind::backward, diff_dst_md, src_md, bwdParams.eps, bn_flags); + context_.bwd_pd.reset(new BatchNormBwdPd(bwd_desc, cpu_engine_, fwd_pd)); - // memory primitive - context_.src_mem.reset(new memory({src_md, cpu_engine_}, DummyData)); + // Create memory primitives. + context_.src_mem.reset( + new MEMORY_CONSTRUCTOR_USING_MD(src_md, cpu_engine_, DummyData)); context_.diff_dst_mem.reset( - new memory({diff_dst_md, cpu_engine_}, DummyData)); + new MEMORY_CONSTRUCTOR_USING_MD(diff_dst_md, cpu_engine_, DummyData)); context_.variance_mem.reset( - new memory({variance_desc, cpu_engine_}, DummyData)); - context_.mean_mem.reset(new memory({mean_desc, cpu_engine_}, DummyData)); + new MEMORY_CONSTRUCTOR_USING_MD(variance_desc, cpu_engine_, DummyData)); + context_.mean_mem.reset( + new MEMORY_CONSTRUCTOR_USING_MD(mean_desc, cpu_engine_, DummyData)); context_.weights_mem.reset( - new memory({weights_desc, cpu_engine_}, DummyData)); - context_.diff_weights_mem.reset( - new memory({diff_weights_desc, cpu_engine_}, DummyData)); - context_.diff_src_mem.reset(new memory({src_md, cpu_engine_}, DummyData)); + new MEMORY_CONSTRUCTOR_USING_MD(weights_desc, cpu_engine_, DummyData)); + context_.diff_weights_mem.reset(new MEMORY_CONSTRUCTOR_USING_MD( + diff_weights_desc, cpu_engine_, DummyData)); + context_.diff_src_mem.reset( + new MEMORY_CONSTRUCTOR_USING_MD(src_md, cpu_engine_, DummyData)); +#ifdef ENABLE_MKLDNN_V1 + context_.bn_bwd.reset(new batch_normalization_backward(*context_.bwd_pd)); + context_.net_args.push_back({{MKLDNN_ARG_SRC, *context_.src_mem}, + {MKLDNN_ARG_MEAN, *context_.mean_mem}, + {MKLDNN_ARG_VARIANCE, *context_.variance_mem}, + {MKLDNN_ARG_DIFF_DST, *context_.diff_dst_mem}, + {MKLDNN_ARG_WEIGHTS, *context_.weights_mem}, + {MKLDNN_ARG_DIFF_SRC, *context_.diff_src_mem}, + { MKLDNN_ARG_DIFF_WEIGHTS, + *context_.diff_weights_mem }}); +#else context_.bn_bwd.reset(new batch_normalization_backward( - bn_bwd_pd, *context_.src_mem, *context_.mean_mem, + *context_.bwd_pd, *context_.src_mem, *context_.mean_mem, *context_.variance_mem, *context_.diff_dst_mem, *context_.weights_mem, *context_.diff_src_mem, *context_.diff_weights_mem)); +#endif // ENABLE_MKLDNN_V1 context_.bwd_primitives.push_back(*context_.bn_bwd); } @@ -590,7 +750,7 @@ class MklFusedBatchNormOp : public OpKernel { est_variance_tensor.shape().DebugString())); } - // special case: input with 0 element and 0 batch size + // Handle the special case: input with 0 element 
and 0 batch size. Tensor* dst_tensor = nullptr; if (tf_shape_src.num_elements() == 0) { HandleEmptyInput(context, tf_shape_src, scale_tensor.shape(), @@ -603,10 +763,10 @@ class MklFusedBatchNormOp : public OpKernel { else ExtractParams(context); - // Indices of output tensors + // Index of output tensor(diff_src). const size_t kDstIndex = 0; - // allocate 4 output TF tensors + // Allocate 4 output TF tensors. Tensor* batch_mean_tensor = nullptr; Tensor* batch_variance_tensor = nullptr; Tensor* saved_mean_tensor = nullptr; @@ -621,21 +781,25 @@ class MklFusedBatchNormOp : public OpKernel { else SetMeanVariance(est_mean_tensor, est_variance_tensor); - MklDnnData src(&cpu_engine); - MklDnnData weights(&cpu_engine); + MklDnnData src(&cpu_engine_); + MklDnnData weights(&cpu_engine_); - memory::format format_m; + MEMORY_FORMAT dnn_fmt; + MKL_TENSOR_FORMAT mkl_tensor_fmt; if (dnn_shape_src.IsMklTensor()) { if (dnn_shape_src.IsTensorInNCHWFormat()) { - format_m = memory::format::nchw; + dnn_fmt = MEMORY_FORMAT::nchw; + mkl_tensor_fmt = MKL_TENSOR_FORMAT_NCHW; } else { - format_m = memory::format::nhwc; + dnn_fmt = MEMORY_FORMAT::nhwc; + mkl_tensor_fmt = MKL_TENSOR_FORMAT_NHWC; } } else { - format_m = TFDataFormatToMklDnnDataFormat(tensor_format_); + mkl_tensor_fmt = TFDataFormatToMklDnnDataFormat(tensor_format_); + dnn_fmt = MklTensorFormatToMklDnnDataFormat(mkl_tensor_fmt); } - // set src primitive + // Set src memory descriptor. memory::dims src_dims = dnn_shape_src.IsMklTensor() ? dnn_shape_src.GetSizesAsMklDnnDims() @@ -643,7 +807,7 @@ class MklFusedBatchNormOp : public OpKernel { auto src_md = dnn_shape_src.IsMklTensor() ? dnn_shape_src.GetMklLayout() - : memory::desc(src_dims, MklDnnType(), format_m); + : memory::desc(src_dims, MklDnnType(), dnn_fmt); // MKL-DNN packs scale & shift as "weights": // ...... @@ -665,16 +829,21 @@ class MklFusedBatchNormOp : public OpKernel { reinterpret_cast(variance_values_), depth_ * sizeof(U)); - // get batchnorm op from the pool +#ifdef ENABLE_MKLDNN_V1 + MklBatchNormFwdParams fwdParams(src_dims, depth_, epsilon_, is_training_, + src_md); +#else MklBatchNormFwdParams fwdParams( src_dims, depth_, epsilon_, is_training_, - static_cast(src_md.data.format)); + static_cast(src_md.data.format)); +#endif // ENABLE_MKLDNN_V1 + // Get forward batch-normalization op from the primitive caching pool. MklFusedBatchNormFwdPrimitive* bn_fwd = MklFusedBatchNormFwdPrimitiveFactory::Get(fwdParams); const T* src_data = src_tensor.flat().data(); - // allocate output (dst) tensor; always set it as MKL-DNN layout + // Allocate output (dst) tensor; always set it as MKL-DNN layout MklDnnShape dnn_shape_dst; TensorShape tf_shape_dst; dnn_shape_dst.SetMklTensor(true); @@ -683,7 +852,7 @@ class MklFusedBatchNormOp : public OpKernel { dnn_shape_dst.SetElemType(MklDnnType()); auto ndims = dnn_shape_src.IsMklTensor() ? 
dnn_shape_src.GetDimension() : src_tensor.shape().dims(); - dnn_shape_dst.SetTfLayout(ndims, src_dims, format_m); + dnn_shape_dst.SetTfLayout(ndims, src_dims, mkl_tensor_fmt); tf_shape_dst.AddDim(dst_pd.get_size() / sizeof(T)); AllocateOutputSetMklShape(context, kDstIndex, &dst_tensor, tf_shape_dst, dnn_shape_dst); @@ -693,19 +862,17 @@ class MklFusedBatchNormOp : public OpKernel { U* variance_op_data = saved_variance_tensor->flat().data(); T* dst_data = dst_tensor->flat().data(); - // execution + // Execute bn_fwd->Execute(src_data, weights_op_data, dst_data, mean_op_data, variance_op_data); - // copy batch_mean data + // Copy batch_mean data U* batch_mean_data_tf = batch_mean_tensor->flat().data(); std::memcpy(reinterpret_cast(batch_mean_data_tf), reinterpret_cast(saved_mean_data_tf), depth_ * sizeof(U)); - // TODO(yli135): OpMem is same as usr mem since - // since its format is hard-coded as nc when primitive is created. - // copy batch_variance data with Bessel's correction + // Copy batch_variance data with Bessel's correction. float adjust_factor = 1.0; if (is_training_) { size_t orig_size = src_dims[0] * src_dims[2] * src_dims[3]; @@ -739,8 +906,8 @@ class MklFusedBatchNormOp : public OpKernel { bool is_training_; U* mean_values_; U* variance_values_; - size_t depth_; // batch normalization is done for per channel. - engine cpu_engine = engine(engine::cpu, 0); + size_t depth_; // Batch normalization is performed for per channel. + engine cpu_engine_ = engine(ENGINE_CPU, 0); void ExtractParams(OpKernelContext* context) { const Tensor& input = MklGetInput(context, 0); @@ -755,14 +922,14 @@ class MklFusedBatchNormOp : public OpKernel { void HandleEmptyInput(OpKernelContext* context, TensorShape tf_shape_src, TensorShape tf_shape_scale, Tensor** dst_tensor) { - CHECK_NOTNULL(dst_tensor); + DCHECK(dst_tensor); const size_t kDstIndex = 0; MklDnnShape dnn_shape_dst; dnn_shape_dst.SetMklTensor(false); AllocateOutputSetMklShape(context, kDstIndex, dst_tensor, tf_shape_src, dnn_shape_dst); - CHECK_NOTNULL(*dst_tensor); + DCHECK(*dst_tensor); memset(const_cast((*dst_tensor)->tensor_data().data()), 0, (*dst_tensor)->tensor_data().size()); @@ -782,10 +949,10 @@ class MklFusedBatchNormOp : public OpKernel { Tensor** saved_mean_tensor, Tensor** saved_variance_tensor, Tensor** reserved_space_tensor) { - CHECK_NOTNULL(batch_mean_tensor); - CHECK_NOTNULL(batch_variance_tensor); - CHECK_NOTNULL(saved_mean_tensor); - CHECK_NOTNULL(saved_variance_tensor); + DCHECK(batch_mean_tensor); + DCHECK(batch_variance_tensor); + DCHECK(saved_mean_tensor); + DCHECK(saved_variance_tensor); const size_t kBatchMeanIndex = 1; const size_t kBatchVarianceIndex = 2; @@ -793,36 +960,38 @@ class MklFusedBatchNormOp : public OpKernel { const size_t kSavedVarianceIndex = 4; const size_t kReservedSpaceIndex = 5; - // allocate batch mean output tensor + // Allocate batch mean output tensor. MklDnnShape mkl_shape_batch_mean; mkl_shape_batch_mean.SetMklTensor(false); AllocateOutputSetMklShape(context, kBatchMeanIndex, batch_mean_tensor, tf_shape_scale, mkl_shape_batch_mean); - CHECK_NOTNULL(*batch_mean_tensor); - // set NAN mean value in case of empty input tensor + DCHECK(*batch_mean_tensor); + + // Set NAN mean value in case of empty input tensor int num_elements = tf_shape_scale.num_elements(); auto batch_mean_data = (*batch_mean_tensor)->flat().data(); std::fill_n(batch_mean_data, num_elements, static_cast(NAN)); - // allocate batch variance output tensor + // Allocate batch variance output tensor. 
MklDnnShape mkl_shape_batch_variance; mkl_shape_batch_variance.SetMklTensor(false); AllocateOutputSetMklShape(context, kBatchVarianceIndex, batch_variance_tensor, tf_shape_scale, mkl_shape_batch_variance); - CHECK_NOTNULL(*batch_variance_tensor); - // set NAN variance value in case of empty input tensor + DCHECK(*batch_variance_tensor); + + // Set NAN variance value in case of empty input tensor auto batch_variance_data = (*batch_variance_tensor)->flat().data(); std::fill_n(batch_variance_data, num_elements, static_cast(NAN)); - // Mean and variance (without Bessel's correction) saved for backward // computation to serve as pre-computed mean and variance. MklDnnShape mkl_shape_saved_mean; mkl_shape_saved_mean.SetMklTensor(false); AllocateOutputSetMklShape(context, kSavedMeanIndex, saved_mean_tensor, tf_shape_scale, mkl_shape_saved_mean); - CHECK_NOTNULL(*saved_mean_tensor); - // set 0 mean value in case of empty input tensor + DCHECK(*saved_mean_tensor); + + // Set NAN mean value in case of empty input tensor auto saved_mean_data = (*saved_mean_tensor)->flat().data(); std::fill_n(saved_mean_data, num_elements, static_cast(0)); @@ -831,15 +1000,16 @@ class MklFusedBatchNormOp : public OpKernel { AllocateOutputSetMklShape(context, kSavedVarianceIndex, saved_variance_tensor, tf_shape_scale, mkl_shape_saved_variance); - CHECK_NOTNULL(*saved_variance_tensor); - // set 0 variance value in case of empty input tensor + DCHECK(*saved_variance_tensor); + + // Set NAN variance value in case of empty input tensor auto saved_variance_data = (*saved_variance_tensor)->flat().data(); std::fill_n(saved_variance_data, num_elements, static_cast(0)); // Changes to support reserved_space_3 parameter in FusedBatchNormV3. // TODO: This parameter functionality is not implemented on CPU. // It is used to hold intermediate results. So the allocated - // memory is filled with 0. + // memory is filled with 0s. if (reserved_space) { DCHECK(reserved_space_tensor != nullptr); @@ -935,8 +1105,8 @@ class MklFusedBatchNormGradOp : public OpKernel { errors::InvalidArgument("saved variance must be 1-dimensional", saved_variance_tensor.shape().DebugString())); + // Handle the special case: input with 0 element and 0 batch size. 
Tensor* diff_src_tensor = nullptr; - // special case: input with 0 element and 0 batch size if (tf_shape_src.num_elements() == 0 || tf_shape_diff_dst.num_elements() == 0) { HandleEmptyInput(context, tf_shape_src, scale_tensor.shape(), @@ -952,20 +1122,25 @@ class MklFusedBatchNormGradOp : public OpKernel { ExtractParams(context); } - memory::format format_m; + MEMORY_FORMAT dnn_fmt; + MKL_TENSOR_FORMAT mkl_tensor_fmt; if (dnn_shape_src.IsMklTensor()) { - if (dnn_shape_src.IsTensorInNCHWFormat()) - format_m = memory::format::nchw; - else - format_m = memory::format::nhwc; + if (dnn_shape_src.IsTensorInNCHWFormat()) { + dnn_fmt = MEMORY_FORMAT::nchw; + mkl_tensor_fmt = MKL_TENSOR_FORMAT_NCHW; + } else { + dnn_fmt = MEMORY_FORMAT::nhwc; + mkl_tensor_fmt = MKL_TENSOR_FORMAT_NHWC; + } } else { - format_m = TFDataFormatToMklDnnDataFormat(tensor_format_); + mkl_tensor_fmt = TFDataFormatToMklDnnDataFormat(tensor_format_); + dnn_fmt = MklTensorFormatToMklDnnDataFormat(mkl_tensor_fmt); } - MklDnnData src(&cpu_engine); - MklDnnData diff_dst(&cpu_engine); - MklDnnData weights(&cpu_engine); - MklDnnData diff_weights(&cpu_engine); + MklDnnData src(&cpu_engine_); + MklDnnData diff_dst(&cpu_engine_); + MklDnnData weights(&cpu_engine_); + MklDnnData diff_weights(&cpu_engine_); memory::dims src_dims = dnn_shape_src.IsMklTensor() @@ -977,15 +1152,15 @@ class MklFusedBatchNormGradOp : public OpKernel { : TFShapeToMklDnnDimsInNCHW(diff_dst_tensor.shape(), tensor_format_); - // set src and diff_dst primitive descriptors + // Set src and diff_dst primitive descriptors. memory::desc src_md = dnn_shape_src.IsMklTensor() ? dnn_shape_src.GetMklLayout() - : memory::desc(src_dims, MklDnnType(), format_m); + : memory::desc(src_dims, MklDnnType(), dnn_fmt); memory::desc diff_dst_md = dnn_shape_diff_dst.IsMklTensor() ? 
dnn_shape_diff_dst.GetMklLayout() - : memory::desc(diff_dst_dims, MklDnnType(), format_m); + : memory::desc(diff_dst_dims, MklDnnType(), dnn_fmt); // weights -- MKL DNN packs scales/ shifts as weights in order // of scale, ..., scale, shift, ...., shift @@ -999,38 +1174,42 @@ class MklFusedBatchNormGradOp : public OpKernel { diff_weights.AllocateBuffer(2 * depth_ * sizeof(U)); +#ifdef ENABLE_MKLDNN_V1 + MklBatchNormBwdParams bwdParams(src_dims, diff_dst_dims, depth_, epsilon_, + is_training_, src_md, diff_dst_md); +#else MklBatchNormBwdParams bwdParams( src_dims, diff_dst_dims, depth_, epsilon_, is_training_, - static_cast(src_md.data.format)); + static_cast(src_md.data.format)); +#endif // ENABLE_MKLDNN_V1 MklFusedBatchNormBwdPrimitive* bn_bwd = MklFusedBatchNormBwdPrimitiveFactory::Get(bwdParams); const T* src_data = src_tensor.flat().data(); const T* diff_dst_data = diff_dst_tensor.flat().data(); // Check if diff_dst input needs to be reordered - if (diff_dst_md.data.format != bn_bwd->GetDiffDstFmt()) { + std::shared_ptr bn_bwd_pd = bn_bwd->GetBatchNormBwdPd(); + if (IS_DIFF_DST_REORDER_NEEDED(diff_dst_md, bn_bwd_pd, bn_bwd)) { diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor); - auto diff_dst_target = memory::primitive_desc( - {{diff_dst_dims}, - MklDnnType(), - static_cast(bn_bwd->GetDiffDstFmt())}, - cpu_engine); - diff_dst.CheckReorderToOpMem(diff_dst_target); - diff_dst_data = const_cast( - reinterpret_cast(diff_dst.GetOpMem().get_data_handle())); + diff_dst.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA( + GET_DIFF_DST_DESC_FROM_OP_PD(bn_bwd_pd), cpu_engine_)); + diff_dst_data = static_cast(diff_dst.GetOpMem().get_data_handle()); + } else { + diff_dst_data = + static_cast(const_cast(diff_dst_tensor.flat().data())); } // Indices of output tensors - const size_t kDiffSrcIndex = 0; // index of diff_src tensor + const size_t kDiffSrcIndex = 0; - // allocate output tensor: diff_src, always set as MKL-DNN layout + // Allocate output tensor diff_src, always set as MKL-DNN layout. MklDnnShape dnn_shape_diff_src; TensorShape tf_shape_diff_src; dnn_shape_diff_src.SetMklTensor(true); auto diff_src_pd = bn_bwd->GetDiffSrcPd(); dnn_shape_diff_src.SetMklLayout(&diff_src_pd); dnn_shape_diff_src.SetElemType(MklDnnType()); - dnn_shape_diff_src.SetTfLayout(src_dims.size(), src_dims, format_m); + dnn_shape_diff_src.SetTfLayout(src_dims.size(), src_dims, mkl_tensor_fmt); dnn_shape_diff_src.SetTfDimOrder(src_dims.size(), tensor_format_); tf_shape_diff_src.AddDim(diff_src_pd.get_size() / sizeof(T)); AllocateOutputSetMklShape(context, kDiffSrcIndex, &diff_src_tensor, @@ -1054,13 +1233,13 @@ class MklFusedBatchNormGradOp : public OpKernel { weights_data, diff_src_data, diff_weights_data, res_space_data); - // allocate output TF tensors: diff_scale and diff_shift + // Allocate output TF tensors diff_scale and diff_shift. Tensor* diff_scale_tensor = nullptr; Tensor* diff_shift_tensor = nullptr; AllocateTFOutputs(context, scale_tensor.shape(), &diff_scale_tensor, &diff_shift_tensor); - // copy data: diff_scale and diff_shift + // Copy data for tensors diff_scale and diff_shift. auto diff_scale_data = diff_scale_tensor->flat().data(); auto diff_shift_data = diff_shift_tensor->flat().data(); std::memcpy(reinterpret_cast(diff_scale_data), @@ -1082,9 +1261,9 @@ class MklFusedBatchNormGradOp : public OpKernel { private: float epsilon_; TensorFormat tensor_format_; - size_t depth_; // batch normalization is done for per channel. + size_t depth_; // Batch normalization is performed for per channel. 
bool is_training_; - engine cpu_engine = engine(engine::cpu, 0); + engine cpu_engine_ = engine(ENGINE_CPU, 0); void ExtractParams(OpKernelContext* context) { const Tensor& input = MklGetInput(context, 0); @@ -1114,20 +1293,21 @@ class MklFusedBatchNormGradOp : public OpKernel { TensorShape tf_shape_scale_shift, Tensor** diff_scale_tensor, Tensor** diff_shift_tensor) { - CHECK_NOTNULL(diff_scale_tensor); - CHECK_NOTNULL(diff_shift_tensor); + DCHECK(diff_scale_tensor); + DCHECK(diff_shift_tensor); const size_t kDiffScaleIndex = 1; const size_t kDiffShiftIndex = 2; const size_t kP1Index = 3; const size_t kP2Index = 4; - // separate out scale and shift grad and copy to individual tensors + // Separate out scale and shift grad and copy to individual tensors MklDnnShape mkl_shape_diff_scale; mkl_shape_diff_scale.SetMklTensor(false); AllocateOutputSetMklShape(context, kDiffScaleIndex, diff_scale_tensor, tf_shape_scale_shift, mkl_shape_diff_scale); - CHECK_NOTNULL(*diff_scale_tensor); + DCHECK(*diff_scale_tensor); + auto diff_scale_data = (*diff_scale_tensor)->flat().data(); std::fill_n(diff_scale_data, (*diff_scale_tensor)->shape().num_elements(), static_cast(0)); @@ -1136,7 +1316,8 @@ class MklFusedBatchNormGradOp : public OpKernel { mkl_shape_diff_shift.SetMklTensor(false); AllocateOutputSetMklShape(context, kDiffShiftIndex, diff_shift_tensor, tf_shape_scale_shift, mkl_shape_diff_shift); - CHECK_NOTNULL(*diff_shift_tensor); + DCHECK(*diff_shift_tensor); + auto diff_shift_data = (*diff_shift_tensor)->flat().data(); std::fill_n(diff_shift_data, (*diff_shift_tensor)->shape().num_elements(), static_cast(0)); @@ -1148,12 +1329,16 @@ class MklFusedBatchNormGradOp : public OpKernel { mkl_shape_p.SetMklTensor(false); AllocateOutputSetMklShape(context, kP1Index, &p1_tensor, TensorShape({}), mkl_shape_p); +#ifndef ENABLE_MKLDNN_V1 std::fill_n(p1_tensor->flat().data(), p1_tensor->shape().num_elements(), static_cast(0)); +#endif // !ENABLE_MKLDNN_V1 AllocateOutputSetMklShape(context, kP2Index, &p2_tensor, TensorShape({}), mkl_shape_p); +#ifndef ENABLE_MKLDNN_V1 std::fill_n(p2_tensor->flat().data(), p2_tensor->shape().num_elements(), static_cast(0)); +#endif // !ENABLE_MKLDNN_V1 } memory::dims GetMeanVarianceDims() { return memory::dims({1, depth_}); } @@ -1240,4 +1425,7 @@ REGISTER_MKL_FUSED_BATCHNORM_GRAD_V3_CPU(bfloat16, float); } // namespace tensorflow +#undef GET_FLAG +#undef IS_SET + #endif // INTEL_MKL diff --git a/tensorflow/core/kernels/mkl_matmul_op_fused.cc b/tensorflow/core/kernels/mkl_matmul_op_fused.cc index a794f7567bb..18f6667fd1e 100644 --- a/tensorflow/core/kernels/mkl_matmul_op_fused.cc +++ b/tensorflow/core/kernels/mkl_matmul_op_fused.cc @@ -105,6 +105,7 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase { memory::dims weight_dims = memory::dims({channel, k}); memory::dims bias_dims = memory::dims({channel}); memory::dims dst_dims = memory::dims({batch, channel}); + MEMORY_FORMAT src_format = MEMORY_FORMAT::nc; MEMORY_FORMAT weight_format = transpose_b_ ? MEMORY_FORMAT::oi : MEMORY_FORMAT::io; @@ -112,7 +113,7 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase { // 1. const, let MKL-DNN determine format because it will be cached; // 2. var, keep the original format to avoid reordering. MklDnnMatMulFwdParams matmul_params( - src_dims, weight_dims, bias_dims, dst_dims, + src_dims, weight_dims, bias_dims, dst_dims, src_format, (this->is_weight_const_) ? MEMORY_FORMAT::any : weight_format); // Extend the basic parameters for data types and fusions. 
@@ -152,44 +153,44 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase { MklDnnData src_mkl(&(this->cpu_engine_)); MklDnnData weight_mkl(&(this->cpu_engine_)); - if (src_mkl_shape.IsMklTensor()) { - memory::desc input_md = src_mkl_shape.GetMklLayout(); -#ifdef ENABLE_MKLDNN_V1 - if (input_md != matmul_pd->src_desc()) { -#else - if (input_md.data.format != MKL_TENSOR_FORMAT_NC) { -#endif // ENABLE_MKLDNN_V1 - src_mkl.SetUsrMem(input_md, src_data); - src_mkl.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA( - matmul_pd.get()->PRIMITIVE_DESC_SRC, this->cpu_engine_)); - src_data = reinterpret_cast(src_mkl.GetOpMem().get_data_handle()); - } + auto src_md = src_mkl_shape.IsMklTensor() + ? src_mkl_shape.GetMklLayout() + : memory::desc(src_dims, MklDnnType(), src_format); + + if (IS_SRC_REORDER_NEEDED(src_md, matmul_pd, matmul_prim)) { + src_mkl.SetUsrMem(src_md, src_data); + src_mkl.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA( + matmul_pd.get()->PRIMITIVE_DESC_SRC, this->cpu_engine_)); + src_data = reinterpret_cast(src_mkl.GetOpMem().get_data_handle()); } // Get cached data when weight is const. - memory::format expected_format = matmul_prim->GetWeightMemoryFormat(); - DCHECK(expected_format != weight_format && this->is_weight_const_); - if (this->is_weight_const_) { + const memory::desc weight_md = + memory::desc(weight_dims, MklDnnType(), weight_format); + if (IS_WEIGHTS_REORDER_NEEDED(weight_md, matmul_pd, matmul_prim)) { T* cached_weight_data = nullptr; - if (this->IsWeightCacheEmpty(ctx)) { - auto weight_md = - memory::desc(weight_dims, MklDnnType(), weight_format); - this->CacheWeight(ctx, matmul_pd, cached_weight_data, weight_tensor, - weight_mkl, weight_md); + + if (this->is_weight_const_) { + if (this->IsWeightCacheEmpty(ctx)) { + this->CacheWeight(ctx, matmul_pd, cached_weight_data, weight_tensor, + weight_mkl, weight_md); + } +#ifdef ENABLE_MKLDNN_V1 + cached_weight_data = this->GetCachedWeight( + ctx, GET_WEIGHTS_DESC_FROM_OP_PD(matmul_pd)); +#else + cached_weight_data = this->GetCachedWeight( + ctx, GET_WEIGHTS_DESC_FROM_OP_PD(matmul_pd).desc()); +#endif } - cached_weight_data = this->GetCachedWeight(ctx, expected_format); // Cache weight may fail when it gets different format in different // iteration. Fallback to reoder if it happens. - // TODO: Fix this slow path. + // Also do a general reorder if the weight isn't const.
if (cached_weight_data != nullptr) { weight_data = cached_weight_data; } else { - memory::desc input_md = - memory::desc(weight_dims, MklDnnType(), weight_format); - - //>>>>>>> master - weight_mkl.SetUsrMem(input_md, weight_data); + weight_mkl.SetUsrMem(weight_md, weight_data); weight_mkl.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA( matmul_pd.get()->PRIMITIVE_DESC_WEIGHTS, this->cpu_engine_)); weight_data = @@ -210,23 +211,21 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase { void ExtendMklDnnMatMulFwdParams(OpKernelContext* ctx, MklDnnMatMulFwdParams& params) { -#ifndef ENABLE_MKLDNN_V1 if (fused_ops_.size() == 2) { string post_op = fused_ops_[1]; if (post_op == "Relu") { - params.post_op_params.push_back({"relu", { 1.0, 0.0, 0.0 }}); + params.post_op_params.push_back({"relu", {1.0, 0.0, 0.0}}); } else if (post_op == "Relu6") { - params.post_op_params.push_back({"relu6", { 1.0, 6.0, 0.0 }}); + params.post_op_params.push_back({"relu6", {1.0, 6.0, 0.0}}); } else if (post_op == "Elu") { - params.post_op_params.push_back({"elu", { 1.0, 1.0, 0.0 }}); + params.post_op_params.push_back({"elu", {1.0, 1.0, 0.0}}); } else { OP_REQUIRES_OK( ctx, errors::InvalidArgument( "Unsupported post-argument in MklFusedMatMul: ", post_op)); } } -#endif // !ENABLE_MKLDNN_V1 } private: diff --git a/tensorflow/core/kernels/mkl_matmul_ops_common.h b/tensorflow/core/kernels/mkl_matmul_ops_common.h index e912613392e..ca90a24a1cd 100644 --- a/tensorflow/core/kernels/mkl_matmul_ops_common.h +++ b/tensorflow/core/kernels/mkl_matmul_ops_common.h @@ -41,7 +41,8 @@ struct MklDnnMatMulFwdParams { memory::dims weight_dims; memory::dims bias_dims; memory::dims dst_dims; - MEMORY_FORMAT weight_fmt; + MEMORY_FORMAT src_format; + MEMORY_FORMAT weight_format; string dtypes = string(""); struct PostOpParam { string name; @@ -51,12 +52,14 @@ struct MklDnnMatMulFwdParams { MklDnnMatMulFwdParams(memory::dims src_dims, memory::dims weight_dims, memory::dims bias_dims, memory::dims dst_dims, - MEMORY_FORMAT weight_fmt = MEMORY_FORMAT::any) + MEMORY_FORMAT src_format = MEMORY_FORMAT::any, + MEMORY_FORMAT weight_format = MEMORY_FORMAT::any) : src_dims(src_dims), weight_dims(weight_dims), bias_dims(bias_dims), dst_dims(dst_dims), - weight_fmt(weight_fmt) {} + src_format(src_format), + weight_format(weight_format) {} }; // With quantization, input, weight, bias, and output can have different types. @@ -182,15 +185,11 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive { // format. 
context_.src_md.reset(new memory::desc({matmul_fwd_params.src_dims}, MklDnnType(), - MEMORY_FORMAT::any)); + matmul_fwd_params.src_format)); context_.weight_md.reset(new memory::desc({matmul_fwd_params.weight_dims}, MklDnnType(), -#ifdef ENABLE_MKLDNN_V1 - MEMORY_FORMAT::any)); -#else - matmul_fwd_params.weight_fmt)); -#endif // ENABLE_MKLDNN_V1 + matmul_fwd_params.weight_format)); context_.dst_md.reset(new memory::desc({matmul_fwd_params.dst_dims}, MklDnnType(), @@ -438,49 +437,61 @@ class MklDnnMatMulOpBase : public OpKernel { // reorder and cache the weight weight.SetUsrMem(weight_md, &weight_tensor); - weight.CheckReorderToOpMem(matmul_fwd_pd.get()->weights_primitive_desc()); + weight.CheckReorderToOpMem(MEMORY_PD_WITHOUT_DATA( + matmul_fwd_pd.get()->PRIMITIVE_DESC_WEIGHTS, cpu_engine_)); weight_data = static_cast(weight.GetOpMem().get_data_handle()); Tensor* weight_tensor_ptr = nullptr; + size_t size = matmul_fwd_pd.get()->PRIMITIVE_DESC_WEIGHTS.get_size(); TensorShape weight_tf_shape; - weight_tf_shape.AddDim( - (matmul_fwd_pd.get()->weights_primitive_desc().get_size() / - sizeof(Tweight))); + weight_tf_shape.AddDim(size / sizeof(Tweight)); OP_REQUIRES_OK(context, context->allocate_persistent( DataTypeToEnum::value, weight_tf_shape, &weight_oi_, &weight_tensor_ptr)); void* weight_oi_t_data = weight.GetTensorBuffer(weight_tensor_ptr); - size_t weight_size = weight.GetOpMem().get_primitive_desc().get_size(); - memcpy(weight_oi_t_data, weight_data, weight_size); + memcpy(weight_oi_t_data, weight_data, size); - // cache the memory descriptor +// cache the memory descriptor +#ifdef ENABLE_MKLDNN_V1 + auto expected_md = GET_WEIGHTS_DESC_FROM_OP_PD(matmul_fwd_pd); +#else + auto expected_md = GET_WEIGHTS_DESC_FROM_OP_PD(matmul_fwd_pd).desc(); +#endif Tensor* weight_md_tensor_ptr = nullptr; TensorShape weight_mkl_format; - weight_mkl_format.AddDim(1); + weight_mkl_format.AddDim(sizeof(expected_md) / sizeof(Tweight)); - OP_REQUIRES_OK(context, context->allocate_persistent( - DT_INT32, weight_mkl_format, &weight_oi_md_, - &weight_md_tensor_ptr)); - weight_md_tensor_ptr->scalar()() = - matmul_fwd_pd.get()->weights_primitive_desc().desc().data.format; + OP_REQUIRES_OK( + context, context->allocate_persistent(DataTypeToEnum::value, + weight_mkl_format, &weight_oi_md_, + &weight_md_tensor_ptr)); + *reinterpret_cast( + weight_md_tensor_ptr->flat().data()) = expected_md; } Tweight* GetCachedWeight(OpKernelContext* context, - const memory::format& weight_mf) + const memory::desc& expected_md) LOCKS_EXCLUDED(mu_) { tf_shared_lock lock(mu_); const Tensor& weight_t = *weight_oi_.AccessTensor(context); const Tensor& weight_md_t = *weight_oi_md_.AccessTensor(context); - // Check if the memory descriptor of the cached weight is same as - // weight_mf. if so use the cached memory, else return NULL - if (weight_md_t.scalar().size() && - weight_md_t.scalar()() == weight_mf) { - return static_cast( - const_cast(weight_t.flat().data())); + // Check if the memory descriptor of the cached weight is same as + // expected_md. 
if so use the cached memory, else return NULL + if (weight_md_t.flat().size()) { + const memory::desc& stored_md = + *(static_cast(weight_md_t.data())); +#ifdef ENABLE_MKLDNN_V1 + if (stored_md == expected_md) { +#else + if (stored_md.data.format == expected_md.data.format) { +#endif + return static_cast( + const_cast(weight_t.flat().data())); + } } return nullptr; } @@ -527,7 +538,8 @@ void dnnl_gemm_exec(const dnnl::desc& a_md, const dnnl::desc& b_md, dnnl::stream s(cpu_engine); matmul_prim.execute(s, {{DNNL_ARG_SRC, a_memory}, {DNNL_ARG_WEIGHTS, b_memory}, - {DNNL_ARG_DST, c_memory}}); + { DNNL_ARG_DST, + c_memory }}); s.wait(); } diff --git a/tensorflow/core/kernels/mkl_maxpooling_op.cc b/tensorflow/core/kernels/mkl_maxpooling_op.cc index 313ff32e7df..de4974f659e 100644 --- a/tensorflow/core/kernels/mkl_maxpooling_op.cc +++ b/tensorflow/core/kernels/mkl_maxpooling_op.cc @@ -32,7 +32,9 @@ using mkldnn::algorithm; using mkldnn::engine; using mkldnn::error; using mkldnn::memory; +#ifndef ENABLE_MKLDNN_V1 using mkldnn::padding_kind; +#endif using mkldnn::pooling_backward; using mkldnn::pooling_forward; using mkldnn::prop_kind; @@ -136,21 +138,31 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase { pooling_prop_kind = prop_kind::forward_inference; else pooling_prop_kind = prop_kind::forward_training; +#ifdef ENABLE_MKLDNN_V1 MklPoolingParams fwdParams( src_dims, output_dims_mkl_order, filter_dims, strides, padding_left, padding_right, ALGORITHM::pooling_max, pooling_prop_kind, - static_cast(input_md.data.format)); + static_cast(this->data_format_mkldnn_)); +#else + MklPoolingParams fwdParams( + src_dims, output_dims_mkl_order, filter_dims, strides, padding_left, + padding_right, ALGORITHM::pooling_max, pooling_prop_kind, + static_cast(input_md.data.format)); +#endif pooling_fwd = MklPoolingFwdPrimitiveFactory::Get(fwdParams); - // Allocate output tensor. this->AllocateOutputTensor(context, *(pooling_fwd->GetPoolingFwdPd()), output_dims_mkl_order, this->tensor_format_mkldnn_, &output_tensor); OP_REQUIRES_OK(context, context->status()); +#ifndef ENABLE_MKLDNN_V1 dnn_data_output.SetUsrMem(output_dims_mkl_order, - pooling_fwd->GetDstMemoryFormat(), - output_tensor); - + this->data_format_mkldnn_, output_tensor); +#else + dnn_data_output.SetUsrMem( + GET_DST_DESC_FROM_OP_PD(pooling_fwd->GetPoolingFwdPd()), + output_tensor); +#endif // !ENABLE_MKLDNN_V1 const T* src_data = input_tensor.flat().data(); T* dst_data = output_tensor->flat().data(); @@ -183,7 +195,6 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase { OP_REQUIRES_OK(context, context->status()); T* ws_data = static_cast(dnn_data_wksp.GetOpMem().get_data_handle()); - // Execute pooling op. 
pooling_fwd->Execute(src_data, dst_data, ws_data); } @@ -285,11 +296,19 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase { : memory::desc(orig_input_dims_mkl_order, MklDnnType(), this->data_format_mkldnn_); +#ifdef ENABLE_MKLDNN_V1 MklPoolingParams bwdParams( orig_input_dims_mkl_order, output_dims_mkl_order, filter_dims, strides, padding_left, padding_right, ALGORITHM::pooling_max, prop_kind::forward_training, - static_cast(src_md.data.format)); + static_cast(this->data_format_mkldnn_)); +#else + MklPoolingParams bwdParams( + orig_input_dims_mkl_order, output_dims_mkl_order, filter_dims, + strides, padding_left, padding_right, ALGORITHM::pooling_max, + prop_kind::forward_training, + static_cast(src_md.data.format)); +#endif MklPoolingBwdPrimitive* pooling_bwd = MklPoolingBwdPrimitiveFactory::Get(bwdParams); diff --git a/tensorflow/core/kernels/mkl_pooling_ops_common.cc b/tensorflow/core/kernels/mkl_pooling_ops_common.cc index b90a74c13af..29bcaf5e67c 100644 --- a/tensorflow/core/kernels/mkl_pooling_ops_common.cc +++ b/tensorflow/core/kernels/mkl_pooling_ops_common.cc @@ -54,17 +54,25 @@ void MklPoolingFwdPrimitive::Setup(const MklPoolingParams& fwdParams) { context_.dst_md.reset(new memory::desc({fwdParams.dst_dims}, MklDnnType(), MEMORY_FORMAT::any)); +#ifndef ENABLE_MKLDNN_V1 // Create a pooling descriptor. context_.fwd_desc.reset(new pooling_forward::desc( fwdParams.prop_kind, fwdParams.alg_kind, *context_.src_md, *context_.dst_md, fwdParams.strides, fwdParams.filter_dims, fwdParams.padding_left, fwdParams.padding_right, padding_kind::zero)); +#else + context_.fwd_desc.reset(new pooling_forward::desc( + fwdParams.prop_kind, fwdParams.alg_kind, *context_.src_md, + *context_.dst_md, fwdParams.strides, fwdParams.filter_dims, + fwdParams.padding_left, fwdParams.padding_right)); +#endif // !ENABLE_MKLDNN_V1 context_.fwd_pd.reset( new pooling_forward::primitive_desc(*context_.fwd_desc, cpu_engine_)); - #ifndef ENABLE_MKLDNN_V1 context_.dst_fmt = static_cast( context_.fwd_pd.get()->PRIMITIVE_DESC_DST.desc().data.format); +#else + context_.dst_fmt = static_cast(MEMORY_FORMAT::any); #endif // ENABLE_MKLDNN_V1 // Create MKL-DNN internal memory object with dummy data. @@ -72,7 +80,6 @@ void MklPoolingFwdPrimitive::Setup(const MklPoolingParams& fwdParams) { context_.fwd_pd.get()->PRIMITIVE_DESC_SRC, cpu_engine_, DummyData)); context_.dst_mem.reset(new MEMORY_CONSTRUCTOR( context_.fwd_pd.get()->PRIMITIVE_DESC_DST, cpu_engine_, DummyData)); - // For max pooling, need to return workspace (ws) for backward computing. if (fwdParams.alg_kind == ALGORITHM::pooling_max && fwdParams.prop_kind == prop_kind::forward_training) { @@ -160,10 +167,11 @@ void MklPoolingBwdPrimitive::Setup(const MklPoolingParams& bwdParams) { // Create memory descriptor. 
context_.diff_src_md.reset(new memory::desc( - {bwdParams.src_dims}, MklDnnType(), memory::format::any)); + {bwdParams.src_dims}, MklDnnType(), MEMORY_FORMAT::any)); context_.diff_dst_md.reset(new memory::desc( {bwdParams.dst_dims}, MklDnnType(), bwdParams.src_format)); +#ifndef ENABLE_MKLDNN_V1 context_.bwd_desc.reset(new pooling_backward::desc( bwdParams.alg_kind, *context_.diff_src_md, *context_.diff_dst_md, bwdParams.strides, bwdParams.filter_dims, bwdParams.padding_left, @@ -175,6 +183,18 @@ void MklPoolingBwdPrimitive::Setup(const MklPoolingParams& bwdParams) { bwdParams.prop_kind, bwdParams.alg_kind, *context_.diff_src_md, *context_.diff_dst_md, bwdParams.strides, bwdParams.filter_dims, bwdParams.padding_left, bwdParams.padding_right, padding_kind::zero)); +#else + context_.bwd_desc.reset(new pooling_backward::desc( + bwdParams.alg_kind, *context_.diff_src_md, *context_.diff_dst_md, + bwdParams.strides, bwdParams.filter_dims, bwdParams.padding_left, + bwdParams.padding_right)); + // Create a forward primitive, + // which will be used as a hint for creating backward primitive. + context_.fwd_desc.reset(new pooling_forward::desc( + bwdParams.prop_kind, bwdParams.alg_kind, *context_.diff_src_md, + *context_.diff_dst_md, bwdParams.strides, bwdParams.filter_dims, + bwdParams.padding_left, bwdParams.padding_right)); +#endif // !ENABLE_MKLDNN_V1 context_.fwd_pd.reset( new pooling_forward::primitive_desc(*context_.fwd_desc, cpu_engine_)); context_.bwd_pd.reset(new pooling_backward::primitive_desc( diff --git a/tensorflow/core/kernels/mkl_pooling_ops_common.h b/tensorflow/core/kernels/mkl_pooling_ops_common.h index 55f5efb536a..8d18c95a542 100644 --- a/tensorflow/core/kernels/mkl_pooling_ops_common.h +++ b/tensorflow/core/kernels/mkl_pooling_ops_common.h @@ -47,13 +47,13 @@ struct MklPoolingParams { memory::dims padding_right; mkldnn::algorithm alg_kind; mkldnn::prop_kind prop_kind; - memory::format src_format; + MEMORY_FORMAT src_format; MklPoolingParams(memory::dims src_dims, memory::dims dst_dims, memory::dims filter_dims, memory::dims strides, memory::dims padding_left, memory::dims padding_right, mkldnn::algorithm alg_kind, mkldnn::prop_kind prop_kind, - memory::format src_format) + MEMORY_FORMAT src_format) : src_dims(src_dims), dst_dims(dst_dims), filter_dims(filter_dims), diff --git a/tensorflow/core/kernels/mkl_qmatmul_op.cc b/tensorflow/core/kernels/mkl_qmatmul_op.cc index 1ee01bfa85c..8e4000d2ba3 100644 --- a/tensorflow/core/kernels/mkl_qmatmul_op.cc +++ b/tensorflow/core/kernels/mkl_qmatmul_op.cc @@ -269,11 +269,11 @@ class MklDnnQuantizedMatMulOp : public MklDnnMatMulOpBase { } #ifdef ENABLE_MKLDNN_V1 weight_data = this->GetCachedWeight( - context, static_cast(weight_mkl_shape.GetTfDataFormat())); + context, GET_WEIGHTS_DESC_FROM_OP_PD(matmul_fwd_pd)); #else weight_data = this->GetCachedWeight( - context, matmul_fwd->GetWeightMemoryFormat()); -#endif // ENABLE_MKLDNN_V1 + context, GET_WEIGHTS_DESC_FROM_OP_PD(matmul_fwd_pd).desc()); +#endif is_weight_cached = (weight_data != nullptr); } diff --git a/tensorflow/core/kernels/mkl_relu_op.cc b/tensorflow/core/kernels/mkl_relu_op.cc index d20b72d879a..2756882c144 100644 --- a/tensorflow/core/kernels/mkl_relu_op.cc +++ b/tensorflow/core/kernels/mkl_relu_op.cc @@ -19,7 +19,6 @@ limitations under the License. 
#include #include "mkldnn.hpp" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -27,6 +26,7 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/util/mkl_types.h" #include "tensorflow/core/util/mkl_util.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" using mkldnn::algorithm; using mkldnn::eltwise_forward; @@ -164,7 +164,11 @@ class MklEltwiseFwdPrimitive : public MklPrimitive { context_.src_md.reset(new memory::desc(fwdParams.src_md.data)); context_.src_mpd.reset( +#ifdef ENABLE_MKLDNN_V1 + new MEMORY_PRIMITIVE_DESC(*context_.src_md)); +#else new MEMORY_PD_CONSTRUCTOR_2_PARAMS(*context_.src_md, cpu_engine_)); +#endif // Create an eltwise forward descriptor and primitive descriptor context_.fwd_desc.reset(new eltwise_forward::desc( @@ -397,7 +401,6 @@ class MklEltwiseBwdPrimitive : public MklPrimitive { // Create memory descriptors for eltwise data w/ no specified format context_.src_md.reset(new memory::desc(bwdParams.common_md.data)); context_.diff_dst_md.reset(new memory::desc(bwdParams.common_md.data)); - context_.src_mpd.reset( new MEMORY_PD_CONSTRUCTOR_2_PARAMS(*context_.src_md, cpu_engine_)); context_.diff_dst_mpd.reset( diff --git a/tensorflow/core/kernels/mkl_softmax_op.cc b/tensorflow/core/kernels/mkl_softmax_op.cc index a2b7491e3fd..258733fc8aa 100644 --- a/tensorflow/core/kernels/mkl_softmax_op.cc +++ b/tensorflow/core/kernels/mkl_softmax_op.cc @@ -18,7 +18,6 @@ limitations under the License. #ifdef INTEL_MKL #include "mkldnn.hpp" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -27,6 +26,7 @@ limitations under the License. 
#include "tensorflow/core/util/mkl_types.h" #include "tensorflow/core/util/mkl_util.h" #include "tensorflow/core/util/tensor_format.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" using mkldnn::prop_kind; using mkldnn::softmax_forward; @@ -65,7 +65,7 @@ class MklSoftmaxPrimitive : public MklPrimitive { #ifdef ENABLE_MKLDNN_V1 execute_primitives(context_.fwd_primitives, context_.fwd_stream, - context_.net_args); + context_.fwd_net_args); #else context_.fwd_stream->submit(context_.fwd_primitives); #endif diff --git a/tensorflow/core/ops/mkl_nn_ops.cc b/tensorflow/core/ops/mkl_nn_ops.cc index 6ec60341321..e46b7436066 100644 --- a/tensorflow/core/ops/mkl_nn_ops.cc +++ b/tensorflow/core/ops/mkl_nn_ops.cc @@ -1259,6 +1259,7 @@ REGISTER_OP("_MklFusedBatchNormV3") .Attr("U: {float}") .Attr("epsilon: float = 0.0001") .Attr(GetConvnetDataFormatAttrString()) + .Attr("exponential_avg_factor: float = 1.0") .Attr("is_training: bool = true") .SetShapeFn(shape_inference::FusedBatchNormShape) .Doc( diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index 84f25347a86..01bc3e8190c 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -2524,6 +2524,7 @@ REGISTER_OP("_MklFusedBatchNorm") .Attr("T: numbertype") .Attr("epsilon: float = 0.0001") .Attr("data_format: string = 'NHWC'") + .Attr("exponential_avg_factor: float = 1.0") .Attr("is_training: bool = true") .SetShapeFn([](InferenceContext* c) { ShapeHandle x; @@ -2673,6 +2674,7 @@ REGISTER_OP("_MklFusedBatchNormV2") .Attr("U: {float}") .Attr("epsilon: float = 0.0001") .Attr(GetConvnetDataFormatAttrString()) + .Attr("exponential_avg_factor: float = 1.0") .Attr("is_training: bool = true") .SetShapeFn(shape_inference::FusedBatchNormShape); diff --git a/tensorflow/core/util/mkl_types.h b/tensorflow/core/util/mkl_types.h index e05bee3cc8a..cdf313d585e 100644 --- a/tensorflow/core/util/mkl_types.h +++ b/tensorflow/core/util/mkl_types.h @@ -23,6 +23,7 @@ namespace tensorflow { #define ADD_MD add_md #define ALGORITHM mkldnn::algorithm #define ALGORITHM_UNDEF ALGORITHM::undef +#define BN_FLAGS mkldnn::normalization_flags #define CPU_STREAM(engine) stream(engine) #define DATA_WITH_ENGINE(data, engine) data, engine #define DST_MD dst_md @@ -41,6 +42,8 @@ namespace tensorflow { GET_MEMORY_DESC_FROM_MEM_PTR(mem_ptr) #define GET_MEMORY_SIZE_FROM_MD(md, engine) md.get_size() #define GET_SRC_DESC_FROM_OP_PD(op_pd) op_pd->src_desc() +#define GET_DST_DESC_FROM_OP_PD(op_pd) op_pd->dst_desc() +#define GET_BIAS_DESC_FROM_OP_PD(op_pd) op_pd->bias_desc() #define GET_DIFF_DST_DESC_FROM_OP_PD(op_pd) op_pd->diff_dst_desc() #define GET_WORKSPACE_DESC_FROM_OP_PD(op_pd) op_pd->workspace_desc() #define GET_TENSOR_FORMAT(fmt) MklTensorFormatToMklDnnDataFormat(fmt) @@ -112,6 +115,7 @@ namespace tensorflow { #define TENSOR_FORMAT_NHWC MKL_TENSOR_FORMAT_NHWC #define TENSOR_MAX_DIMS MKLDNN_MAX_NDIMS #define GET_USR_MEM_PRIM_DESC(src) src.GetUsrMemDesc() +#define BN_FLAGS mkldnn::normalization_flags #else @@ -136,6 +140,8 @@ namespace tensorflow { #define GET_MEMORY_SIZE_FROM_MD(md, engine) \ memory::primitive_desc(md, engine).get_size() #define GET_SRC_DESC_FROM_OP_PD(op_pd) op_pd.get()->src_primitive_desc() +#define GET_DST_DESC_FROM_OP_PD(op_pd) op_pd.get()->dst_primitive_desc() +#define GET_BIAS_DESC_FROM_OP_PD(op_pd) op_pd.get()->bias_primitive_desc() #define GET_DIFF_DST_DESC_FROM_OP_PD(op_pd) \ op_pd.get()->diff_dst_primitive_desc() #define GET_WORKSPACE_DESC_FROM_OP_PD(op_pd) \ @@ -210,6 +216,7 @@ namespace tensorflow 
{ #define TENSOR_FORMAT TensorFormat #define TENSOR_FORMAT_NHWC FORMAT_NHWC #define GET_USR_MEM_PRIM_DESC(src) src.GetUsrMemPrimDesc() +#define BN_FLAGS mkldnn #endif // ENABLE_MKLDNN_V1 } // namespace tensorflow diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h index b5f6b2a705e..fb4aba3e89e 100644 --- a/tensorflow/core/util/mkl_util.h +++ b/tensorflow/core/util/mkl_util.h @@ -42,7 +42,9 @@ limitations under the License. using mkldnn::engine; using mkldnn::memory; +#ifndef ENABLE_MKLDNN_V1 using mkldnn::padding_kind; +#endif using mkldnn::primitive; using mkldnn::reorder; using mkldnn::stream; @@ -1224,10 +1226,12 @@ inline memory::dims CalculateTFStrides(const memory::dims& dims_tf_order) { return strides; } +#ifndef ENABLE_MKLDNN_V1 inline padding_kind TFPaddingToMklDnnPadding(Padding pad) { // MKL-DNN only supports zero padding. return padding_kind::zero; } +#endif /// Helper function to create memory descriptor in Blocked format /// @@ -2157,7 +2161,7 @@ inline bool IsConv1x1StrideNot1(memory::dims filter_dims, } #ifdef ENABLE_MKLDNN_V1 -void execute_primitives( +inline void execute_primitives( std::vector& primitives, std::shared_ptr stream, std::vector>& net_args) { DCHECK_EQ(primitives.size(), net_args.size()); diff --git a/tensorflow/core/util/mkl_util_test.cc b/tensorflow/core/util/mkl_util_test.cc index eb8e15dfee2..8513fa986f9 100644 --- a/tensorflow/core/util/mkl_util_test.cc +++ b/tensorflow/core/util/mkl_util_test.cc @@ -16,6 +16,7 @@ limitations under the License. #ifdef INTEL_MKL #include "tensorflow/core/util/mkl_util.h" +#include "tensorflow/core/util/mkl_types.h" #include "tensorflow/core/platform/test.h" @@ -23,7 +24,7 @@ namespace tensorflow { namespace { TEST(MklUtilTest, MklDnnTfShape) { - auto cpu_engine = engine(engine::cpu, 0); + auto cpu_engine = engine(ENGINE_CPU, 0); MklDnnData a(&cpu_engine); const int N = 1, C = 2, H = 3, W = 4; @@ -31,7 +32,7 @@ TEST(MklUtilTest, MklDnnTfShape) { MklDnnShape a_mkldnn_shape; a_mkldnn_shape.SetMklTensor(true); // Create TF layout in NCHW. - a_mkldnn_shape.SetTfLayout(a_dims.size(), a_dims, memory::format::nchw); + a_mkldnn_shape.SetTfLayout(a_dims.size(), a_dims, MKL_TENSOR_FORMAT_NCHW); TensorShape a_tf_shape_nchw({N, C, H, W}); TensorShape a_tf_shape_nhwc({N, H, W, C}); TensorShape a_mkldnn_tf_shape = a_mkldnn_shape.GetTfShape(); @@ -43,7 +44,7 @@ TEST(MklUtilTest, MklDnnTfShape) { MklDnnShape b_mkldnn_shape; b_mkldnn_shape.SetMklTensor(true); // Create TF layout in NHWC. - b_mkldnn_shape.SetTfLayout(b_dims.size(), b_dims, memory::format::nhwc); + b_mkldnn_shape.SetTfLayout(b_dims.size(), b_dims, MKL_TENSOR_FORMAT_NHWC); TensorShape b_tf_shape_nhwc({N, H, W, C}); TensorShape b_tf_shape_nchw({N, C, H, W}); TensorShape b_mkldnn_tf_shape = b_mkldnn_shape.GetTfShape(); @@ -55,7 +56,7 @@ TEST(MklUtilTest, MklDnnTfShape) { TEST(MklUtilTest, MklDnnBlockedFormatTest) { // Let's create 2D tensor of shape {3, 4} with 3 being innermost dimension // first (case 1) and then it being outermost dimension (case 2). 
- auto cpu_engine = engine(engine::cpu, 0); + auto cpu_engine = engine(ENGINE_CPU, 0); // Setting for case 1 MklDnnData a(&cpu_engine); @@ -67,7 +68,9 @@ TEST(MklUtilTest, MklDnnBlockedFormatTest) { EXPECT_EQ(a_md1.data.ndims, 2); EXPECT_EQ(a_md1.data.dims[0], 3); EXPECT_EQ(a_md1.data.dims[1], 4); +#ifndef ENABLE_MKLDNN_V1 EXPECT_EQ(a_md1.data.format, mkldnn_blocked); +#endif // !ENABLE_MKLDNN_V1 // Setting for case 2 MklDnnData b(&cpu_engine); @@ -79,7 +82,9 @@ TEST(MklUtilTest, MklDnnBlockedFormatTest) { EXPECT_EQ(b_md2.data.ndims, 2); EXPECT_EQ(b_md2.data.dims[0], 3); EXPECT_EQ(b_md2.data.dims[1], 4); +#ifndef ENABLE_MKLDNN_V1 EXPECT_EQ(b_md2.data.format, mkldnn_blocked); +#endif // !ENABLE_MKLDNN_V1 } TEST(MklUtilTest, LRUCacheTest) { diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index f596e7c264c..8b8248399d5 100755 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -174,12 +174,12 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): tf_http_archive( name = "mkl_dnn_v1", - build_file = clean_dep("//third_party/mkl_dnn:mkldnn.BUILD"), - sha256 = "fcc2d951f7170eade0cfdd0d8d1d58e3e7785bd326bca6555f3722f8cba71811", - strip_prefix = "mkl-dnn-1.0-pc2", + build_file = clean_dep("//third_party/mkl_dnn:mkldnn_v1.BUILD"), + sha256 = "27fd9da9720c452852f1226581e7914efcf74e1ff898468fdcbe1813528831ba", + strip_prefix = "mkl-dnn-1.0", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/github.com/intel/mkl-dnn/archive/v1.0-pc2.tar.gz", - "https://github.com/intel/mkl-dnn/archive/v1.0-pc2.tar.gz", + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/intel/mkl-dnn/archive/v1.0.tar.gz", + "https://github.com/intel/mkl-dnn/archive/v1.0.tar.gz", ], ) diff --git a/third_party/mkl_dnn/mkldnn_v1.BUILD b/third_party/mkl_dnn/mkldnn_v1.BUILD new file mode 100644 index 00000000000..517abca3ebb --- /dev/null +++ b/third_party/mkl_dnn/mkldnn_v1.BUILD @@ -0,0 +1,135 @@ +exports_files(["LICENSE"]) + +load( + "@org_tensorflow//third_party/mkl_dnn:build_defs.bzl", + "if_mkl_open_source_only", + "if_mkl_v1_open_source_only", +) +load( + "@org_tensorflow//third_party:common.bzl", + "template_rule", +) + +config_setting( + name = "clang_linux_x86_64", + values = { + "cpu": "k8", + "define": "using_clang=true", + }, +) + +template_rule( + name = "mkldnn_config_h", + src = "include/mkldnn_config.h.in", + out = "include/mkldnn_config.h", + substitutions = { + "#cmakedefine MKLDNN_CPU_RUNTIME MKLDNN_RUNTIME_${MKLDNN_CPU_RUNTIME_CURRENT}": "#define MKLDNN_CPU_RUNTIME MKLDNN_RUNTIME_OMP", + "#cmakedefine MKLDNN_GPU_RUNTIME MKLDNN_RUNTIME_${MKLDNN_GPU_RUNTIME}": "#define MKLDNN_GPU_RUNTIME MKLDNN_RUNTIME_NONE", + }, +) + +# Create the file mkldnn_version.h with MKL-DNN version numbers. +# Currently, the version numbers are hard coded here. If MKL-DNN is upgraded then +# the version numbers have to be updated manually. The version numbers can be +# obtained from the PROJECT_VERSION settings in CMakeLists.txt. The variable is +# set to "version_major.version_minor.version_patch". The git hash version can +# be set to NA. +# TODO(agramesh1) Automatically get the version numbers from CMakeLists.txt. 
+ +template_rule( + name = "mkldnn_version_h", + src = "include/mkldnn_version.h.in", + out = "include/mkldnn_version.h", + substitutions = { + "@MKLDNN_VERSION_MAJOR@": "1", + "@MKLDNN_VERSION_MINOR@": "0", + "@MKLDNN_VERSION_PATCH@": "0", + "@MKLDNN_VERSION_HASH@": "N/A", + }, +) + +cc_library( + name = "mkl_dnn", + srcs = glob([ + "src/common/*.cpp", + "src/common/*.hpp", + "src/cpu/*.cpp", + "src/cpu/*.hpp", + "src/cpu/**/*.cpp", + "src/cpu/**/*.hpp", + "src/cpu/xbyak/*.h", + ]) + if_mkl_v1_open_source_only([ + ":mkldnn_config_h", + ]) + [":mkldnn_version_h"], + hdrs = glob(["include/*"]), + copts = [ + "-fexceptions", + "-DUSE_MKL", + "-DUSE_CBLAS", + ] + if_mkl_open_source_only([ + "-UUSE_MKL", + "-UUSE_CBLAS", + ]) + if_mkl_v1_open_source_only([ + "-UUSE_MKL", + "-UUSE_CBLAS", + ]) + select({ + "@org_tensorflow//tensorflow:linux_x86_64": [ + "-fopenmp", # only works with gcc + ], + # TODO(ibiryukov): enable openmp with clang by including libomp as a + # dependency. + ":clang_linux_x86_64": [], + "//conditions:default": [], + }), + includes = [ + "include", + "src", + "src/common", + "src/cpu", + "src/cpu/gemm", + "src/cpu/xbyak", + ], + visibility = ["//visibility:public"], + deps = select({ + "@org_tensorflow//tensorflow:linux_x86_64": [ + "@mkl_linux//:mkl_headers", + "@mkl_linux//:mkl_libs_linux", + ], + "@org_tensorflow//tensorflow:macos": [ + "@mkl_darwin//:mkl_headers", + "@mkl_darwin//:mkl_libs_darwin", + ], + "@org_tensorflow//tensorflow:windows": [ + "@mkl_windows//:mkl_headers", + "@mkl_windows//:mkl_libs_windows", + ], + "//conditions:default": [], + }), +) + +cc_library( + name = "mkldnn_single_threaded", + srcs = glob([ + "src/common/*.cpp", + "src/common/*.hpp", + "src/cpu/*.cpp", + "src/cpu/*.hpp", + "src/cpu/**/*.cpp", + "src/cpu/**/*.hpp", + "src/cpu/xbyak/*.h", + ]) + [":mkldnn_config_h"], + hdrs = glob(["include/*"]), + copts = [ + "-fexceptions", + "-DMKLDNN_THR=MKLDNN_THR_SEQ", # Disables threading. + ], + includes = [ + "include", + "src", + "src/common", + "src/cpu", + "src/cpu/gemm", + "src/cpu/xbyak", + ], + visibility = ["//visibility:public"], +)
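The kernels touched above stay buildable against both MKL-DNN v0.x and v1.x by routing every API difference through the macros in mkl_types.h (GET_WEIGHTS_DESC_FROM_OP_PD, MEMORY_PD_WITHOUT_DATA, ENGINE_CPU, and so on), selected by ENABLE_MKLDNN_V1. The following is a minimal standalone sketch of that dispatch pattern only; DescV1, PdV0, OpPdV0, and OpPdV1 are hypothetical toy stand-ins, not the real mkldnn classes, and the macro bodies are simplified relative to mkl_types.h.

// Sketch of the ENABLE_MKLDNN_V1 macro-dispatch pattern used throughout this
// patch, with toy types standing in for mkldnn primitive descriptors.
#include <iostream>
#include <memory>
#include <string>

struct DescV1 { std::string layout; };   // stand-in for a v1.x memory::desc
struct OpPdV1 {
  DescV1 weights_desc() const { return {"blocked_v1"}; }
};

struct PdV0 { std::string layout; };     // stand-in for a v0.x memory::primitive_desc
struct OpPdV0 {
  PdV0 weights_primitive_desc() const { return {"blocked_v0"}; }
};

// One macro name per query; the definition picked at compile time hides
// whether the underlying call is the v1.x desc API or the v0.x primitive_desc API.
#ifdef ENABLE_MKLDNN_V1
using OpPd = OpPdV1;
#define GET_WEIGHTS_DESC_FROM_OP_PD(op_pd) (op_pd)->weights_desc()
#else
using OpPd = OpPdV0;
#define GET_WEIGHTS_DESC_FROM_OP_PD(op_pd) (op_pd)->weights_primitive_desc()
#endif

int main() {
  auto pd = std::make_shared<OpPd>();
  // The kernel-side call site looks the same regardless of which branch is built.
  std::cout << GET_WEIGHTS_DESC_FROM_OP_PD(pd).layout << "\n";
  return 0;
}

Under this pattern a kernel writes GET_WEIGHTS_DESC_FROM_OP_PD(pd) once and picks up whichever API the build enables, which is why most hunks in this patch only swap concrete mkldnn calls for the corresponding macros.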