From edfb6806fe35a0892a0e60a0176a1fb26fa2d4ba Mon Sep 17 00:00:00 2001 From: TengLu Date: Tue, 20 Aug 2019 10:31:46 +0800 Subject: [PATCH 1/4] Enable FP32 FusedMatMul for MKL-DNN. --- tensorflow/core/graph/mkl_layout_pass.cc | 23 ++ tensorflow/core/graph/mkl_layout_pass_test.cc | 47 +++ .../core/grappler/optimizers/remapper.cc | 21 +- tensorflow/core/kernels/BUILD | 8 +- tensorflow/core/kernels/mkl_fused_ops_test.cc | 196 +++++++-- .../core/kernels/mkl_matmul_op_fused.cc | 194 +++++++++ .../core/kernels/mkl_matmul_ops_common.h | 371 ++++++++++++++++++ tensorflow/core/kernels/mkl_qmatmul_op.cc | 353 +---------------- tensorflow/core/ops/mkl_nn_ops.cc | 26 ++ tensorflow/core/util/mkl_util.h | 10 +- 10 files changed, 849 insertions(+), 400 deletions(-) create mode 100644 tensorflow/core/kernels/mkl_matmul_op_fused.cc create mode 100644 tensorflow/core/kernels/mkl_matmul_ops_common.h diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc index 5ba4fc6cfbe..39c8b781085 100644 --- a/tensorflow/core/graph/mkl_layout_pass.cc +++ b/tensorflow/core/graph/mkl_layout_pass.cc @@ -273,6 +273,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { csinfo_.fused_batch_norm_v3 = "FusedBatchNormV3"; csinfo_.fused_batch_norm_grad_v3 = "FusedBatchNormGradV3"; csinfo_.fused_conv2d = "_FusedConv2D"; + csinfo_.fused_matmul = "_FusedMatMul"; csinfo_.identity = "Identity"; csinfo_.leakyrelu = "LeakyRelu"; csinfo_.leakyrelu_grad = "LeakyReluGrad"; @@ -294,6 +295,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { csinfo_.mkl_depthwise_conv2d_grad_filter = "_MklDepthwiseConv2dNativeBackpropFilter"; csinfo_.mkl_fused_conv2d = "_MklFusedConv2D"; + csinfo_.mkl_fused_matmul = "_MklFusedMatMul"; csinfo_.mkl_pad_with_conv2d = "_MklPadWithConv2D"; csinfo_.mkl_pad_with_fused_conv2d = "_MklPadWithFusedConv2D"; csinfo_.pad = "Pad"; @@ -472,6 +474,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass { rinfo_.push_back({csinfo_.fused_conv2d, csinfo_.mkl_fused_conv2d, CopyAttrsFusedConv2D, FusedConv2DRewrite, kRewriteForLayoutPropagation}); + rinfo_.push_back({csinfo_.fused_matmul, csinfo_.mkl_fused_matmul, + CopyAttrsAll, FusedMatMulRewrite}); rinfo_.push_back({csinfo_.identity, mkl_op_registry::GetMklOpName(csinfo_.identity), CopyAttrsAll, RewriteIfAtleastOneMklInput, @@ -897,6 +901,7 @@ rinfo_.push_back({csinfo_.tanh_grad, string fused_batch_norm_v3; string fused_batch_norm_grad_v3; string fused_conv2d; + string fused_matmul; string identity; string leakyrelu; string leakyrelu_grad; @@ -916,6 +921,7 @@ rinfo_.push_back({csinfo_.tanh_grad, string mkl_depthwise_conv2d_grad_input; string mkl_depthwise_conv2d_grad_filter; string mkl_fused_conv2d; + string mkl_fused_matmul; string mkl_pad_with_conv2d; string mkl_pad_with_fused_conv2d; string mul; @@ -1445,6 +1451,22 @@ rinfo_.push_back({csinfo_.tanh_grad, return false; } + // Rewrite relu for _FusedMatMul. + // @return - true (no transpose attribute for input 1 and only has 1 post op); + // false otherwise. + static bool FusedMatMulRewrite(const Node* n) { + bool trans_a; + std::vector fused_ops; + + // Do not rewrite with transpose attribute because reorder has performance + // impact. + TF_CHECK_OK(GetNodeAttr(n->def(), "transpose_a", &trans_a)); + // Do not rewrite with more than 1 post op because MKL-DNN doesn't support. + TF_CHECK_OK(GetNodeAttr(n->def(), "fused_ops", &fused_ops)); + + return !trans_a && fused_ops.size() == 1; + } + // Check if we are performing pooling on depth or batch. 
If it is, then we // do not rewrite MaxPool node to Mkl version. // @return - true (if it is not a depth/batch wise pooling case); @@ -3528,6 +3550,7 @@ MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const { n->type_string() != csinfo_.pad_with_fused_conv2d && n->type_string() != csinfo_.conv2d_grad_filter_with_bias && n->type_string() != csinfo_.fused_conv2d && + n->type_string() != csinfo_.fused_matmul && !mkl_op_registry::IsMklOp(mkl_op_registry::GetMklOpName(n->type_string()), T)) { return nullptr; diff --git a/tensorflow/core/graph/mkl_layout_pass_test.cc b/tensorflow/core/graph/mkl_layout_pass_test.cc index b69a30e8274..9f4e2105644 100644 --- a/tensorflow/core/graph/mkl_layout_pass_test.cc +++ b/tensorflow/core/graph/mkl_layout_pass_test.cc @@ -1819,6 +1819,53 @@ TEST_F(MklLayoutPassTest, NodeRewrite_FusedConv2D_Negative2) { "D(_FusedConv2D);E(Zeta)|A->D;B->D:1;C->D:2;C->E:1;D->E"); } +// Test set: _FusedMatMul -> MklFusedMatMul rewrite tests +TEST_F(MklLayoutPassTest, NodeRewrite_FusedMatMul_Postive) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'Input'}" + "node { name: 'C' op: 'Input'}" + "node { name: 'D' op: '_FusedMatMul'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " attr { key: 'transpose_a' value { b: false } }" + " attr { key: 'transpose_b' value { b: false } }" + " attr { key: 'num_args' value { i: 1 } }" + " attr { key: 'fused_ops' value { list: {s: 'BiasAdd'} } }" + " attr { key: 'epsilon' value { f: 0.001 }}" + " input: ['A', 'B', 'C']}" + "node { name: 'Z' op: 'Zeta'" + " attr {key: 'T' value { type: DT_FLOAT } }" + " input: ['D', 'C']}"); + EXPECT_EQ(DoMklLayoutOptimizationPass(), + "A(Input);B(Input);C(Input);D(_MklFusedMatMul);DMT/_0(Const);" + "DMT/_1(Const);DMT/_2(Const);Z(Zeta)" + "|A->D;A:control->DMT/_0:control;A:control->DMT/_1:control;" + "A:control->DMT/_2:control;B->D:1;C->D:2;C->Z:1;D->Z;DMT/_0->D:3;" + "DMT/_1->D:4;DMT/_2->D:5"); +} + +// Test set: _FusedMatMul -> MklFusedMatMul rewrite tests +TEST_F(MklLayoutPassTest, NodeRewrite_FusedMatMul_Negative) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'Input'}" + "node { name: 'C' op: 'Input'}" + "node { name: 'D' op: '_FusedMatMul'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " attr { key: 'transpose_a' value { b: true } }" + " attr { key: 'transpose_b' value { b: false } }" + " attr { key: 'num_args' value { i: 1 } }" + " attr { key: 'fused_ops' value { list: {s: 'BiasAdd'} } }" + " attr { key: 'epsilon' value { f: 0.001 }}" + " input: ['A', 'B', 'C']}" + "node { name: 'Z' op: 'Zeta'" + " attr {key: 'T' value { type: DT_FLOAT } }" + " input: ['D', 'C']}"); + EXPECT_EQ(DoMklLayoutOptimizationPass(), + "A(Input);B(Input);C(Input);D(_FusedMatMul);Z(Zeta)" + "|A->D;B->D:1;C->D:2;C->Z:1;D->Z"); +} + // Merge test for PadWithFusedConv2D Op with BiasAdd fusion // padding is VALID type // A = input(image), B = input(paddings), C = Pad(A, B) = input of conv2D, diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc index 766e8a1056e..ab04f50494c 100644 --- a/tensorflow/core/grappler/optimizers/remapper.cc +++ b/tensorflow/core/grappler/optimizers/remapper.cc @@ -272,13 +272,7 @@ bool IsGpuCompatibleConv2D(const NodeDef* conv2d) { bool IsCpuCompatibleMatMul(const NodeDef* matmul) { DCHECK(IsMatMul(*matmul)) << "Expected MatMul op"; -#ifndef INTEL_MKL - // Temporarily disable Matmul fusions if MKL is enabled. - // TODO(Intel) renable Matmul fusions when enabled by MKL DNN. 
return NodeIsOnCpu(matmul) && IsCpuCompatibleDataType(matmul); -#else - return false; -#endif // !INTEL_MKL } // Checks if we can rewrite a pattern to the `_Fused{Conv2D,MatMul}` on CPU. @@ -1221,14 +1215,13 @@ Status AddFusedBatchNormExNode(RemapperContext* ctx, const NodeDef& activation = graph->node(matched.activation); VLOG(2) << "Fuse " << activation.op() << " with FusedBatchNorm:" - << " activation=" << activation.name() << " side_input=" - << (matched.side_input != kMissingIndex - ? graph->node(matched.side_input).name() - : "") - << " invalidated=" - << (matched.invalidated != kMissingIndex - ? graph->node(matched.invalidated).name() - : "") + << " activation=" << activation.name() + << " side_input=" << (matched.side_input != kMissingIndex + ? graph->node(matched.side_input).name() + : "") + << " invalidated=" << (matched.invalidated != kMissingIndex + ? graph->node(matched.invalidated).name() + : "") << " fused_batch_norm=" << fused_batch_norm.name(); // Replace FusedBatchNorm with _FusedBatchNormEx + + . diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index e1e23eea133..12e7a65efce 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -3803,7 +3803,11 @@ tf_kernel_library( tf_mkl_kernel_library( name = "mkl_matmul_op", - srcs = ["mkl_matmul_op.cc"], + srcs = [ + "mkl_matmul_op.cc", + "mkl_matmul_op_fused.cc" + ], + hdrs = ["mkl_matmul_ops_common.h"], deps = MATH_DEPS + mkl_deps(), ) @@ -7514,6 +7518,7 @@ tf_mkl_kernel_library( name = "mkl_qmatmul_op", srcs = ["mkl_qmatmul_op.cc"], hdrs = [ + "mkl_matmul_ops_common.h", "mkl_quantized_conv_ops.h", "no_op.h", ], @@ -7828,6 +7833,7 @@ tf_cc_test_mkl( ":conv_ops", ":image", ":mkl_conv_op", + ":mkl_matmul_op", ":mkl_tfconv_op", ":ops_testutil", ":ops_util", diff --git a/tensorflow/core/kernels/mkl_fused_ops_test.cc b/tensorflow/core/kernels/mkl_fused_ops_test.cc index 30bc16931d7..1a612612513 100644 --- a/tensorflow/core/kernels/mkl_fused_ops_test.cc +++ b/tensorflow/core/kernels/mkl_fused_ops_test.cc @@ -152,6 +152,33 @@ class CommonTestUtilities : public OpsTestBase { test::ExpectClose(conv_2d, fused_conv_2d, 1e-5); } + + static void VerifyFusedMatrixClose(int depth, int batch, int weight_count, + const std::vector& fused_ops, + const FusedGraphRunner& run_default, + const FusedGraphRunner& run_fused) { + DataType dtype = DataTypeToEnum::v(); + + Tensor input(dtype, {batch, depth}); + input.flat() = input.flat().setRandom(); + + Tensor weight(dtype, {depth, weight_count}); + weight.flat() = weight.flat().setRandom(); + + Tensor bias(dtype, {weight_count}); + bias.flat() = bias.flat().setRandom(); + + Tensor output; + Tensor fused_output; + + run_default(input, weight, bias, fused_ops, &output); + run_fused(input, weight, bias, fused_ops, &fused_output); + + ASSERT_EQ(output.dtype(), fused_output.dtype()); + ASSERT_EQ(output.shape(), fused_output.shape()); + + test::ExpectClose(output, fused_output, 1e-5); + } }; // Testing MKL's fused convolution ops @@ -263,25 +290,24 @@ class MklFusedConv2DOpTest : public OpsTestBase { int depth = kDepth, int image_width = kImageWidth, int image_height = kImageHeight, int image_batch_count = kImageBatchCount) { - const FusedGraphRunner run_default = - [this](const Tensor& input_data, const Tensor& filter_data, - const Tensor& bias_data, const std::vector& fused_ops, - Tensor* out) { - RunConv2DUnfused(input_data, filter_data, bias_data, fused_ops, out); - }; + const FusedGraphRunner run_default = [this]( + const Tensor& input_data, const 
Tensor& filter_data, + const Tensor& bias_data, const std::vector& fused_ops, + Tensor* out) { + RunConv2DUnfused(input_data, filter_data, bias_data, fused_ops, out); + }; - const FusedGraphRunner run_fused = - [this](const Tensor& input_data, const Tensor& filter_data, - const Tensor& bias_data, const std::vector& fused_ops, - Tensor* out) { - std::vector fused_input = {bias_data}; - if (std::find(fused_ops.begin(), fused_ops.end(), "Add") != - fused_ops.end()) { - fused_input.push_back(input_data); - } - RunMklFusedConv2DOp(input_data, filter_data, fused_input, fused_ops, - out); - }; + const FusedGraphRunner run_fused = [this]( + const Tensor& input_data, const Tensor& filter_data, + const Tensor& bias_data, const std::vector& fused_ops, + Tensor* out) { + std::vector fused_input = {bias_data}; + if (std::find(fused_ops.begin(), fused_ops.end(), "Add") != + fused_ops.end()) { + fused_input.push_back(input_data); + } + RunMklFusedConv2DOp(input_data, filter_data, fused_input, fused_ops, out); + }; CommonTestUtilities::VerifyFusedTensorsClose( depth, image_width, image_height, image_batch_count, filter_size, @@ -578,6 +604,97 @@ TEST_F(FilterCacheTest, Conv2DFilterCacheTest) { Run(DT_FLOAT, image, filter, expected, true); } +// Testing fusion of MatMul and BiasAdd +template +class MklFusedMatMulOpTest : public OpsTestBase { + protected: + void VerifyFusedMatMul(const int kBatch, const int kInputChannel, + const int kOutputChannel) { + const FusedGraphRunner run_default = [this]( + const Tensor& input, const Tensor& weight, const Tensor& bias, + const std::vector& fused_ops, Tensor* output) { + auto root = tensorflow::Scope::NewRootScope(); + auto input_op = + ops::Const(root.WithOpName("input"), Input::Initializer(input)); + Output next_op = ops::MatMul( + root.WithOpName("matmul"), input_op, + ops::Const(root.WithOpName("weight"), Input::Initializer(weight))); + + string last_op = ""; + if (std::find(fused_ops.begin(), fused_ops.end(), "BiasAdd") != + fused_ops.end()) { + last_op = "with_bias"; + next_op = ops::BiasAdd( + root.WithOpName(last_op), next_op, + ops::Const(root.WithOpName("bias"), Input::Initializer(bias))); + } + + CommonTestUtilities::RunAndFetch(root, last_op, output); + }; + + const FusedGraphRunner run_fused = [this]( + const Tensor& input, const Tensor& weight, const Tensor& bias, + const std::vector& fused_ops, Tensor* output) { + DataType dtype = DataTypeToEnum::v(); + const int num_args = fused_ops.size(); + + TF_EXPECT_OK(NodeDefBuilder("MklFusedMatMul", "_MklFusedMatMul") + .Input(FakeInput(dtype)) + .Input(FakeInput(dtype)) + .Input(FakeInput(num_args, dtype)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(DT_UINT8)) + .Input(FakeInput(num_args, DT_UINT8)) + .Attr("T", dtype) + .Attr("transpose_a", false) + .Attr("transpose_b", false) + .Attr("num_args", num_args) + .Attr("fused_ops", fused_ops) + .Attr("epsilon", 0.0001) + .Attr("_kernel", "MklLayoutDependentOp") + .Finalize(node_def())); + + TF_EXPECT_OK(InitOp()); + + AddInputFromArray(input.shape(), input.flat()); + AddInputFromArray(weight.shape(), weight.flat()); + AddInputFromArray(bias.shape(), bias.flat()); + // Add MKL meta input for input, filter and bias. 
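+      // Each data input of a layout-dependent MKL kernel is paired with an
+      // extra uint8 "metadata" tensor describing its MKL layout. The inputs
+      // above are plain TF-layout tensors, so dummy metadata tensors are fed.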
+ AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + AddInputFromArray(dummy_shape, dummy_tensor); + + TF_ASSERT_OK(RunOpKernel()); + + const Tensor& output_tensor = *GetOutput(0); + const Tensor& output_meta_tensor = *GetOutput(1); + CommonTestUtilities test_util; + test_util.PerformConversion(dtype, output_tensor, output_meta_tensor, + output); + }; + + CommonTestUtilities::VerifyFusedMatrixClose(kInputChannel, kBatch, + kOutputChannel, {"BiasAdd"}, + run_default, run_fused); + } +}; + +TYPED_TEST_CASE_P(MklFusedMatMulOpTest); + +TYPED_TEST_P(MklFusedMatMulOpTest, BasicTest) { + const int batch = 3; + const int input_channel = 4; + const int output_channel = 5; + + this->VerifyFusedMatMul(batch, input_channel, output_channel); +} + +REGISTER_TYPED_TEST_CASE_P(MklFusedMatMulOpTest, BasicTest); + +using MklFusedMatMulDataTypes = ::testing::Types; +INSTANTIATE_TYPED_TEST_CASE_P(Test, MklFusedMatMulOpTest, + MklFusedMatMulDataTypes); + // Testing fusion of pad and fusedconv2d template class MklPadWithFusedConv2DOpTest : public OpsTestBase { @@ -597,19 +714,18 @@ class MklPadWithFusedConv2DOpTest : public OpsTestBase { int image_width = kImageWidth, int image_height = kImageHeight, int image_batch_count = kImageBatchCount) { - const BiasAddGraphRunner run_default = [this](const Tensor& input_data, - const Tensor& filter_data, - const Tensor& bias_data, - Tensor* out) { + const BiasAddGraphRunner run_default = [this]( + const Tensor& input_data, const Tensor& filter_data, + const Tensor& bias_data, Tensor* out) { RunMklPadWithFusedConv2DAndBias(input_data, filter_data, bias_data, out); }; - const BiasAddGraphRunner run_fused = - [this](const Tensor& input_data, const Tensor& filter_data, - const Tensor& bias_data, Tensor* out) { - RunMklFusedConv2DWithPadOp(input_data, filter_data, {bias_data}, - {"BiasAdd"}, out); - }; + const BiasAddGraphRunner run_fused = [this]( + const Tensor& input_data, const Tensor& filter_data, + const Tensor& bias_data, Tensor* out) { + RunMklFusedConv2DWithPadOp(input_data, filter_data, {bias_data}, + {"BiasAdd"}, out); + }; CommonTestUtilities::VerifyBiasAddTensorsClose( depth, image_width, image_height, image_batch_count, filter_size, @@ -622,19 +738,19 @@ class MklPadWithFusedConv2DOpTest : public OpsTestBase { int filter_size, int filter_count, int depth = kDepth, int image_width = kImageWidth, int image_height = kImageHeight, int image_batch_count = kImageBatchCount) { - const BiasAddGraphRunner run_default = - [this](const Tensor& input_data, const Tensor& filter_data, - const Tensor& bias_data, Tensor* out) { - RunMklPadWithFusedConv2DAndBiasRelu(input_data, filter_data, - bias_data, out); - }; + const BiasAddGraphRunner run_default = [this]( + const Tensor& input_data, const Tensor& filter_data, + const Tensor& bias_data, Tensor* out) { + RunMklPadWithFusedConv2DAndBiasRelu(input_data, filter_data, bias_data, + out); + }; - const BiasAddGraphRunner run_fused = - [this](const Tensor& input_data, const Tensor& filter_data, - const Tensor& bias_data, Tensor* out) { - RunMklFusedConv2DWithPadOp(input_data, filter_data, {bias_data}, - {"BiasAdd", "Relu"}, out); - }; + const BiasAddGraphRunner run_fused = [this]( + const Tensor& input_data, const Tensor& filter_data, + const Tensor& bias_data, Tensor* out) { + RunMklFusedConv2DWithPadOp(input_data, filter_data, {bias_data}, + {"BiasAdd", "Relu"}, out); + }; CommonTestUtilities::VerifyBiasAddTensorsClose( depth, image_width, image_height, image_batch_count, filter_size, 
diff --git a/tensorflow/core/kernels/mkl_matmul_op_fused.cc b/tensorflow/core/kernels/mkl_matmul_op_fused.cc new file mode 100644 index 00000000000..7baed255241 --- /dev/null +++ b/tensorflow/core/kernels/mkl_matmul_op_fused.cc @@ -0,0 +1,194 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/math_ops.cc. + +// This file uses MKL-DNN InnerProduct for acceleration of TF Matrix-Matrix +// Multiplication (MatMul) with bias (BiasAdd) operations. +#ifdef INTEL_MKL + +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/kernels/fill_functor.h" +#include "tensorflow/core/kernels/mkl_matmul_ops_common.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +// Fuse Operation +template +class MklFusedMatMulOp : public MklDnnMatMulOpBase { + public: + explicit MklFusedMatMulOp(OpKernelConstruction* ctx) + : MklDnnMatMulOpBase(ctx) { + std::vector fused_ops; + + OP_REQUIRES_OK(ctx, ctx->GetAttr("fused_ops", &fused_ops)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_)); + + OP_REQUIRES(ctx, fused_ops.size() == 1, + errors::InvalidArgument( + "MklFusedMatMul must have only one argument: bias.")); + OP_REQUIRES( + ctx, transpose_a_ == false, + errors::InvalidArgument("In[0] of MklMatMul can't be transposed.")); + } + + void Compute(OpKernelContext* ctx) override { + // FusedMatMul has 3 inputs: src, weights, bias + const Tensor& src_tensor = ctx->input(this->kInputIndexSrc); + const Tensor& weight_tensor = ctx->input(this->kInputIndexWeight); + const Tensor& bias_tensor = MklGetInput(ctx, this->kInputIndexBias); + + MklDnnShape src_dnn_shape; + MklDnnShape weight_dnn_shape; + GetMklShape(ctx, this->kInputIndexSrc, &src_dnn_shape); + GetMklShape(ctx, this->kInputIndexWeight, &weight_dnn_shape); + + // Get shapes of input tensors + auto src_tf_shape = src_dnn_shape.IsMklTensor() ? src_dnn_shape.GetTfShape() + : src_tensor.shape(); + auto weight_tf_shape = weight_dnn_shape.IsMklTensor() + ? weight_dnn_shape.GetTfShape() + : weight_tensor.shape(); + + // Check the constraint of input matrix and bias + OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(src_tf_shape), + errors::InvalidArgument("In[0] is not a matrix")); + OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(weight_tf_shape), + errors::InvalidArgument("In[1] is not a matrix")); + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(bias_tensor.shape()), + errors::InvalidArgument("Biases must be 1D")); + + // Expression: [batch, k] * [k, channel] + [channel] = [batch, channel] + // + // Get dimension size of each matrix, dim_pair[] is the location of k + // in the inputs, we have constraint that k of the two inputs are + // the same + const int dim_pair[] = {transpose_a_ ? 0 : 1, transpose_b_ ? 
1 : 0}; + const int batch = src_tf_shape.dim_size(1 - dim_pair[0]); + const int k = src_tf_shape.dim_size(dim_pair[0]); + const int channel = weight_tf_shape.dim_size(1 - dim_pair[1]); + + OP_REQUIRES(ctx, k == weight_tf_shape.dim_size(dim_pair[1]), + errors::InvalidArgument("Matrix size-incompatible: In[0]: ", + src_tf_shape.DebugString(), ", In[1]: ", + weight_tf_shape.DebugString())); + OP_REQUIRES(ctx, bias_tensor.shape().dim_size(0) == channel, + errors::InvalidArgument( + "Must provide as many biases as the channel size: ", + bias_tensor.shape().DebugString(), " vs. ", channel)); + + // Create primitive for InnerProduct, it has such desc: + // s[batch, k] * w[channel, k] + b[channel] = dst[batch, channel] + // [n, c] * [oc, ic] + [x] = [n, c] + // + // For weights, dimensions need to be specified as [channel*k]. + memory::dims src_dims = memory::dims({batch, k}); + // In order to satisfy the primitive, reverse the dims of weights + // from [k, channel] to [channel, k]. + memory::dims weight_dims = memory::dims({channel, k}); + memory::dims bias_dims = memory::dims({channel}); + memory::dims dst_dims = memory::dims({batch, channel}); + memory::format weight_format = + transpose_b_ ? memory::format::oi : memory::format::io; + + MklDnnMatMulFwdParams matmul_params(src_dims, weight_dims, bias_dims, + dst_dims, weight_format); + MklDnnMatMulFwdPrimitive* matmul_prim = + MklDnnMatMulFwdPrimitiveFactory::Get(matmul_params, 0); + + // Allocate output tensor. + Tensor* dst_tensor = nullptr; + std::shared_ptr matmul_pd = + matmul_prim->GetPrimitiveDesc(); + + if (src_dnn_shape.IsMklTensor() && weight_dnn_shape.IsMklTensor()) { + this->AllocateOutputTensor(ctx, *matmul_pd, dst_dims, memory::format::nc, + &dst_tensor); + } else { + TensorShape dst_tensor_shape({batch, channel}); + MklDnnShape dst_dnn_shape; + dst_dnn_shape.SetMklTensor(false); + AllocateOutputSetMklShape(ctx, 0, &dst_tensor, dst_tensor_shape, + dst_dnn_shape); + } + + // if there's nothing to compute, just return. + if (batch == 0 || channel == 0) { + return; + } + + try { + // Prepare the input and output for primitive. + T* src_data = const_cast(src_tensor.flat().data()); + T* weight_data = const_cast(weight_tensor.flat().data()); + T* bias_data = const_cast(bias_tensor.flat().data()); + T* dst_data = const_cast(dst_tensor->flat().data()); + + // Any input is MKL format, reorder it if necessary. 
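+      // That is: if src or weight carries an MKL blocked layout other than
+      // the plain nc (src) or oi/io (weight) layout, reorder it into the
+      // layout requested by the primitive descriptor before execution;
+      // inputs already in TF layout are used directly.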
+ MklDnnData src_dnn(&(this->cpu_engine_)); + MklDnnData weight_dnn(&(this->cpu_engine_)); + + if (src_dnn_shape.IsMklTensor()) { + memory::desc input_md = src_dnn_shape.GetMklLayout(); + + if (input_md.data.format != memory::format::nc) { + src_dnn.SetUsrMem(input_md, src_data); + src_dnn.CheckReorderToOpMem(matmul_pd.get()->src_primitive_desc()); + src_data = reinterpret_cast(src_dnn.GetOpMem().get_data_handle()); + } + } + + if (weight_dnn_shape.IsMklTensor()) { + memory::desc input_md = weight_dnn_shape.GetMklLayout(); + + if (input_md.data.format != weight_format) { + weight_dnn.SetUsrMem(input_md, weight_data); + weight_dnn.CheckReorderToOpMem( + matmul_pd.get()->weights_primitive_desc()); + weight_data = + reinterpret_cast(weight_dnn.GetOpMem().get_data_handle()); + } + } + + matmul_prim->Execute(src_data, weight_data, bias_data, dst_data); + } catch (mkldnn::error& e) { + string error_msg = "Status: " + std::to_string(e.status) + ", message: " + + string(e.message) + ", in file " + string(__FILE__) + + ":" + std::to_string(__LINE__); + OP_REQUIRES_OK( + ctx, errors::Aborted("Operation received an exception:", error_msg)); + } + } + + private: + bool transpose_a_; + bool transpose_b_; +}; + +// register dnn kernels for supported operations and supported types +#define REGISTER_FUSEDMATMUL_MKL_SUPPORTED_KERNELS_TYPES(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("_MklFusedMatMul") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T") \ + .Label(mkl_op_registry::kMklLayoutDependentOpLabel), \ + MklFusedMatMulOp); +TF_CALL_float(REGISTER_FUSEDMATMUL_MKL_SUPPORTED_KERNELS_TYPES); + +} // namespace tensorflow + +#endif // INTEL_MKL diff --git a/tensorflow/core/kernels/mkl_matmul_ops_common.h b/tensorflow/core/kernels/mkl_matmul_ops_common.h new file mode 100644 index 00000000000..decaac8aad7 --- /dev/null +++ b/tensorflow/core/kernels/mkl_matmul_ops_common.h @@ -0,0 +1,371 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_MKL_MATMUL_OPS_COMMON_H_ +#define TENSORFLOW_CORE_KERNELS_MKL_MATMUL_OPS_COMMON_H_ + +#ifdef INTEL_MKL +#include +#include +#include + +#include "mkldnn.hpp" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/util/mkl_util.h" + +using mkldnn::inner_product_forward; +using mkldnn::prop_kind; +using mkldnn::stream; + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; + +// This structure aggregates multiple inputs to MklDnnMatMul* methods. 
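+// Its fields also make up the cache key used by
+// MklDnnMatMulFwdPrimitiveFactory (see CreateKey below), so forward
+// primitives are reused across calls with identical shapes, data types and
+// post-ops.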
+struct MklDnnMatMulFwdParams { + memory::dims src_dims; + memory::dims weight_dims; + memory::dims bias_dims; + memory::dims dst_dims; + memory::format weight_fmt; + string dtypes = string(""); + struct PostOpParam { + string name; + std::vector param; + }; + std::vector post_op_params; + + MklDnnMatMulFwdParams(memory::dims src_dims, memory::dims weight_dims, + memory::dims bias_dims, memory::dims dst_dims, + memory::format weight_fmt = memory::format::any) + : src_dims(src_dims), + weight_dims(weight_dims), + bias_dims(bias_dims), + dst_dims(dst_dims), + weight_fmt(weight_fmt) {} +}; + +// With quantization, input, weight, bias, and output can have different types. +// So we use different template parameters for each type. +// TODO(intel-tf): The template type "T" is currently used to match the +// templatized class MklPrimitiveFactory (tensorflow/core/util/mkl_util.h). +// In future, with the removal of "T" from MklPrimitiveFactory, this class +// needs to drop "T". +template +class MklDnnMatMulFwdPrimitive : public MklPrimitive { + public: + explicit MklDnnMatMulFwdPrimitive( + const MklDnnMatMulFwdParams& matmulFwdParams) + : cpu_engine_(engine::cpu, 0) { + context_.fwd_stream.reset(new stream(stream::kind::eager)); + // Create matmul primitive + if (context_.matmul_fwd == nullptr) { + Setup(matmulFwdParams); + } + } + + ~MklDnnMatMulFwdPrimitive() {} + + // Inner-product forward execute with bias: + // - src_data: input data buffer of src + // - weight_data: input data buffer of weight + // - bias_data: input data buffer of bias + // - dst_data: output data buffer of dst + void Execute(const Tinput* src_data, const Tweight* weight_data, + const Tbias* bias_data, Toutput* dst_data) { + context_.src_mem->set_data_handle( + static_cast(const_cast(src_data))); + context_.weight_mem->set_data_handle( + static_cast(const_cast(weight_data))); + context_.bias_mem->set_data_handle( + static_cast(const_cast(bias_data))); + context_.dst_mem->set_data_handle(static_cast(dst_data)); + context_.fwd_stream->submit(context_.fwd_primitives); + + // After execution, set data handle back + context_.src_mem->set_data_handle(DummyData); + context_.weight_mem->set_data_handle(DummyData); + context_.bias_mem->set_data_handle(DummyData); + context_.dst_mem->set_data_handle(DummyData); + } + + memory::format GetSrcMemoryFormat() const { return context_.src_fmt; } + memory::format GetweightMemoryFormat() const { return context_.weight_fmt; } + std::shared_ptr + GetPrimitiveDesc() const { + return context_.fwd_pd; + } + + private: + // Primitive reuse context for inner-product Fwd op + struct MklDnnMatMulFwdContext { + // Expected memory format for this primitive instance + memory::format src_fmt; + memory::format weight_fmt; + + // MKL-DNN memory + std::shared_ptr src_mem; + std::shared_ptr weight_mem; + std::shared_ptr bias_mem; + std::shared_ptr dst_mem; + + // Descriptor and primitive-descriptor for forward inner-product + std::shared_ptr fwd_desc; + std::shared_ptr fwd_pd; + + // Memory descriptors + std::shared_ptr src_md; + std::shared_ptr weight_md; + std::shared_ptr bias_md; + std::shared_ptr dst_md; + + // Inner-product primitive + std::shared_ptr matmul_fwd; + std::shared_ptr fwd_stream; + std::vector fwd_primitives; + + MklDnnMatMulFwdContext() + : src_fmt(memory::format::any), + weight_fmt(memory::format::any), + src_mem(nullptr), + weight_mem(nullptr), + bias_mem(nullptr), + dst_mem(nullptr), + fwd_desc(nullptr), + src_md(nullptr), + weight_md(nullptr), + bias_md(nullptr), + fwd_pd(nullptr), + 
matmul_fwd(nullptr), + fwd_stream(nullptr) {} + }; + + void Setup(const MklDnnMatMulFwdParams& matmul_fwd_params) { + // Create memory descriptors for inner-product data with no specified format + context_.src_md.reset(new memory::desc({matmul_fwd_params.src_dims}, + MklDnnType(), + memory::format::any)); + + context_.weight_md.reset(new memory::desc({matmul_fwd_params.weight_dims}, + MklDnnType(), + matmul_fwd_params.weight_fmt)); + + context_.dst_md.reset(new memory::desc({matmul_fwd_params.dst_dims}, + MklDnnType(), + memory::format::any)); + + context_.bias_md.reset(new memory::desc({matmul_fwd_params.bias_dims}, + MklDnnType(), + memory::format::any)); + // Create an inner-product + context_.fwd_desc.reset(new inner_product_forward::desc( + prop_kind::forward_inference, *context_.src_md, *context_.weight_md, + *context_.bias_md, *context_.dst_md)); + context_.fwd_pd.reset(new inner_product_forward::primitive_desc( + *context_.fwd_desc, cpu_engine_)); + // Check if there is any fusion as post-ops + auto const& post_op_params = matmul_fwd_params.post_op_params; + mkldnn::primitive_attr post_ops_attr; + mkldnn::post_ops post_ops; + if (!post_op_params.empty()) { + for (auto const& post_op_param : post_op_params) { + if (post_op_param.name == "relu") { + DCHECK_EQ(post_op_param.param.size(), 3); + float op_scale = post_op_param.param[0]; + float op_alpha = post_op_param.param[1]; + float op_beta = post_op_param.param[2]; + post_ops.append_eltwise(op_scale, mkldnn::eltwise_relu, op_alpha, + op_beta); + } else if (post_op_param.name == "output_scale") { + DCHECK_EQ(post_op_param.param.size(), 1); + std::vector scales; + scales.push_back(post_op_param.param[0]); + post_ops_attr.set_output_scales(0, scales); + } else { + DCHECK((post_op_param.name == "relu") || + (post_op_param.name == "output_scale")); + } + } + post_ops_attr.set_post_ops(post_ops); + context_.fwd_pd.reset(new inner_product_forward::primitive_desc( + *context_.fwd_desc, post_ops_attr, cpu_engine_)); + } else { + context_.fwd_pd.reset(new inner_product_forward::primitive_desc( + *context_.fwd_desc, cpu_engine_)); + } + + // Store the expected memory format + context_.src_fmt = static_cast( + context_.fwd_pd.get()->src_primitive_desc().desc().data.format); + + context_.weight_fmt = static_cast( + context_.fwd_pd.get()->weights_primitive_desc().desc().data.format); + + // Create memory primitive based on dummy data + context_.src_mem.reset( + new memory(context_.fwd_pd.get()->src_primitive_desc(), DummyData)); + context_.weight_mem.reset( + new memory(context_.fwd_pd.get()->weights_primitive_desc(), DummyData)); + context_.dst_mem.reset( + new memory(context_.fwd_pd.get()->dst_primitive_desc(), DummyData)); + context_.bias_mem.reset(new memory({{{matmul_fwd_params.bias_dims}, + MklDnnType(), + memory::format::x}, + cpu_engine_}, + DummyData)); + + // Create inner-product primitive + context_.matmul_fwd.reset(new inner_product_forward( + *context_.fwd_pd, *context_.src_mem, *context_.weight_mem, + *context_.bias_mem, *context_.dst_mem)); + + context_.fwd_primitives.push_back(*context_.matmul_fwd); + return; + } + + struct MklDnnMatMulFwdContext context_; + engine cpu_engine_; +}; + +template +class MklDnnMatMulFwdPrimitiveFactory : public MklPrimitiveFactory { + public: + static MklDnnMatMulFwdPrimitive* Get( + const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims, bool do_not_cache) { + MklDnnMatMulFwdPrimitive* matmul_fwd = + nullptr; + + if (do_not_cache) { + // Always create new primitive + matmul_fwd = + new 
MklDnnMatMulFwdPrimitive( + mkldnn_matmul_fwd_dims); + } else { + // try to find a suitable one in pool + matmul_fwd = dynamic_cast< + MklDnnMatMulFwdPrimitive*>( + MklDnnMatMulFwdPrimitiveFactory::GetInstance() + .GetMklDnnMatMulFwd(mkldnn_matmul_fwd_dims)); + if (matmul_fwd == nullptr) { + matmul_fwd = + new MklDnnMatMulFwdPrimitive( + mkldnn_matmul_fwd_dims); + MklDnnMatMulFwdPrimitiveFactory::GetInstance() + .SetMklDnnMatMulFwd(mkldnn_matmul_fwd_dims, matmul_fwd); + } + } + return matmul_fwd; + } + + private: + MklDnnMatMulFwdPrimitiveFactory() {} + ~MklDnnMatMulFwdPrimitiveFactory() {} + + static MklDnnMatMulFwdPrimitiveFactory& GetInstance() { + static MklDnnMatMulFwdPrimitiveFactory instance_; + return instance_; + } + + static string CreateKey(const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims) { + string prefix = "matmul_fwd_"; + FactoryKeyCreator key_creator; + key_creator.AddAsKey(prefix); + key_creator.AddAsKey(mkldnn_matmul_fwd_dims.src_dims); + key_creator.AddAsKey(mkldnn_matmul_fwd_dims.weight_dims); + key_creator.AddAsKey(mkldnn_matmul_fwd_dims.bias_dims); + key_creator.AddAsKey(mkldnn_matmul_fwd_dims.dst_dims); + key_creator.AddAsKey(mkldnn_matmul_fwd_dims.dtypes); + + // Generate keys for post-ops + for (auto const& post_op_param : mkldnn_matmul_fwd_dims.post_op_params) { + if (post_op_param.name == "relu") { + DCHECK_EQ(post_op_param.param.size(), 3); + key_creator.AddAsKey(post_op_param.name); + key_creator.AddAsKey(post_op_param.param[0]); + key_creator.AddAsKey(post_op_param.param[1]); + key_creator.AddAsKey(post_op_param.param[2]); + } else if (post_op_param.name == "output_scale") { + DCHECK_EQ(post_op_param.param.size(), 1); + key_creator.AddAsKey(post_op_param.name); + key_creator.AddAsKey(post_op_param.param[0]); + } else { + return string("not_a_key"); + } + } + return key_creator.GetKey(); + } + + MklPrimitive* GetMklDnnMatMulFwd( + const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims) { + string key = CreateKey(mkldnn_matmul_fwd_dims); + return this->GetOp(key); + } + + void SetMklDnnMatMulFwd(const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims, + MklPrimitive* op) { + string key = CreateKey(mkldnn_matmul_fwd_dims); + this->SetOp(key, op); + } +}; + +template +class MklDnnMatMulOpBase : public OpKernel { + public: + explicit MklDnnMatMulOpBase(OpKernelConstruction* context) + : OpKernel(context) {} + void Compute(OpKernelContext* context) override = 0; + + // Allocate output tensor. 
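+  // The tensor is allocated in the MKL (blocked) layout chosen by the
+  // primitive descriptor; the accompanying MklDnnShape records both that
+  // layout and the TF dimension order so downstream ops can convert back.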
+ virtual void AllocateOutputTensor( + OpKernelContext* context, + const inner_product_forward::primitive_desc& mkldnn_matmul_prim_desc, + const memory::dims& output_dims_mkl_order, + memory::format output_tf_format, Tensor** output_tensor) { + DCHECK(output_tensor); + auto dst_pd = mkldnn_matmul_prim_desc.dst_primitive_desc(); + + MklDnnShape output_mkl_shape; + output_mkl_shape.SetMklTensor(true); + output_mkl_shape.SetMklLayout(&dst_pd); + output_mkl_shape.SetElemType(MklDnnType()); + output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(), + output_dims_mkl_order, output_tf_format); + + TensorShape output_tf_shape; + output_tf_shape.AddDim((dst_pd.get_size() / sizeof(Toutput))); + + // Allocate Output Tensor + AllocateOutputSetMklShape(context, kOutputIndexDst, output_tensor, + output_tf_shape, output_mkl_shape); + } + + engine cpu_engine_ = engine(engine::cpu, 0); + + protected: + const int kInputIndexSrc = 0; + const int kInputIndexWeight = 1; + const int kInputIndexBias = 2; + const int kOutputIndexDst = 0; +}; + +} // namespace tensorflow + +#endif // INTEL_MKL +#endif // TENSORFLOW_CORE_KERNELS_MKL_MATMUL_OPS_COMMON_H_ diff --git a/tensorflow/core/kernels/mkl_qmatmul_op.cc b/tensorflow/core/kernels/mkl_qmatmul_op.cc index 4aff02ac827..6dce2766955 100644 --- a/tensorflow/core/kernels/mkl_qmatmul_op.cc +++ b/tensorflow/core/kernels/mkl_qmatmul_op.cc @@ -90,19 +90,12 @@ limitations under the License. // https://software.intel.com/en-us/articles/lower-numerical-precision-deep-learning-inference-and-training #ifdef INTEL_MKL -#include "mkldnn.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/kernels/fill_functor.h" +#include "tensorflow/core/kernels/mkl_matmul_ops_common.h" #include "tensorflow/core/kernels/mkl_quantized_conv_ops.h" #include "tensorflow/core/kernels/no_op.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/util/mkl_util.h" - -using mkldnn::inner_product_forward; -using mkldnn::prop_kind; -using mkldnn::stream; namespace { enum { @@ -113,299 +106,9 @@ enum { namespace tensorflow { -// This structure aggregates multiple inputs to MklDnnMatMul* methods. -struct MklDnnMatMulFwdParams { - memory::dims src_dims; - memory::dims weight_dims; - memory::dims bias_dims; - memory::dims dst_dims; - string dtypes = string(""); - struct PostOpParam { - string name; - std::vector param; - }; - std::vector post_op_params; - - MklDnnMatMulFwdParams(memory::dims src_dims, memory::dims weight_dims, - memory::dims bias_dims, memory::dims dst_dims) - : src_dims(src_dims), - weight_dims(weight_dims), - bias_dims(bias_dims), - dst_dims(dst_dims) {} -}; - -// With quantization, input, weight, bias, and output can have different types. -// So we use different template parameters for each type. -// TODO(intel-tf): The template type "T" is currently used to match the -// templatized class MklPrimitiveFactory (tensorflow/core/util/mkl_util.h). -// In future, with the removal of "T" from MklPrimitiveFactory, this class -// needs to drop "T". 
-template -class MklDnnMatMulFwdPrimitive : public MklPrimitive { - public: - explicit MklDnnMatMulFwdPrimitive( - const MklDnnMatMulFwdParams& matmulFwdParams) - : cpu_engine_(engine::cpu, 0) { - context_.fwd_stream.reset(new stream(stream::kind::eager)); - // Create matmul primitive - if (context_.matmul_fwd == nullptr) { - Setup(matmulFwdParams); - } - } - - ~MklDnnMatMulFwdPrimitive() {} - - // Inner-product forward execute with bias: - // - src_data: input data buffer of src - // - weight_data: input data buffer of weight - // - bias_data: input data buffer of bias - // - dst_data: output data buffer of dst - void Execute(const Tinput* src_data, const Tweight* weight_data, - const Tbias* bias_data, Toutput* dst_data) { - context_.src_mem->set_data_handle( - static_cast(const_cast(src_data))); - context_.weight_mem->set_data_handle( - static_cast(const_cast(weight_data))); - context_.bias_mem->set_data_handle( - static_cast(const_cast(bias_data))); - context_.dst_mem->set_data_handle(static_cast(dst_data)); - context_.fwd_stream->submit(context_.fwd_primitives); - - // After execution, set data handle back - context_.src_mem->set_data_handle(DummyData); - context_.weight_mem->set_data_handle(DummyData); - context_.bias_mem->set_data_handle(DummyData); - context_.dst_mem->set_data_handle(DummyData); - } - - memory::format GetSrcMemoryFormat() const { return context_.src_fmt; } - memory::format GetweightMemoryFormat() const { return context_.weight_fmt; } - std::shared_ptr - GetPrimitiveDesc() const { - return context_.fwd_pd; - } - - private: - // Primitive reuse context for inner-product Fwd op - struct MklDnnMatMulFwdContext { - // Expected memory format for this primitive instance - memory::format src_fmt; - memory::format weight_fmt; - - // MKL-DNN memory - std::shared_ptr src_mem; - std::shared_ptr weight_mem; - std::shared_ptr bias_mem; - std::shared_ptr dst_mem; - - // Descriptor and primitive-descriptor for forward inner-product - std::shared_ptr fwd_desc; - std::shared_ptr fwd_pd; - - // Memory descriptors - std::shared_ptr src_md; - std::shared_ptr weight_md; - std::shared_ptr bias_md; - std::shared_ptr dst_md; - - // Inner-product primitive - std::shared_ptr matmul_fwd; - std::shared_ptr fwd_stream; - std::vector fwd_primitives; - - MklDnnMatMulFwdContext() - : src_fmt(memory::format::any), - weight_fmt(memory::format::any), - src_mem(nullptr), - weight_mem(nullptr), - bias_mem(nullptr), - dst_mem(nullptr), - fwd_desc(nullptr), - fwd_pd(nullptr), - src_md(nullptr), - weight_md(nullptr), - bias_md(nullptr), - matmul_fwd(nullptr), - fwd_stream(nullptr) {} - }; - - void Setup(const MklDnnMatMulFwdParams& matmul_fwd_params) { - // Create memory descriptors for inner-product data with no specified format - context_.src_md.reset(new memory::desc({matmul_fwd_params.src_dims}, - MklDnnType(), - memory::format::any)); - - context_.weight_md.reset(new memory::desc({matmul_fwd_params.weight_dims}, - MklDnnType(), - memory::format::any)); - - context_.dst_md.reset(new memory::desc({matmul_fwd_params.dst_dims}, - MklDnnType(), - memory::format::any)); - - context_.bias_md.reset(new memory::desc({matmul_fwd_params.bias_dims}, - MklDnnType(), - memory::format::any)); - // Create an inner-product - context_.fwd_desc.reset(new inner_product_forward::desc( - prop_kind::forward_inference, *context_.src_md, *context_.weight_md, - *context_.bias_md, *context_.dst_md)); - context_.fwd_pd.reset(new inner_product_forward::primitive_desc( - *context_.fwd_desc, cpu_engine_)); - // Check if there is 
any fusion as post-ops - auto const& post_op_params = matmul_fwd_params.post_op_params; - mkldnn::primitive_attr post_ops_attr; - mkldnn::post_ops post_ops; - if (!post_op_params.empty()) { - for (auto const& post_op_param : post_op_params) { - if (post_op_param.name == "relu") { - DCHECK_EQ(post_op_param.param.size(), 3); - float op_scale = post_op_param.param[0]; - float op_alpha = post_op_param.param[1]; - float op_beta = post_op_param.param[2]; - post_ops.append_eltwise(op_scale, mkldnn::eltwise_relu, op_alpha, - op_beta); - } else if (post_op_param.name == "output_scale") { - DCHECK_EQ(post_op_param.param.size(), 1); - std::vector scales; - scales.push_back(post_op_param.param[0]); - post_ops_attr.set_output_scales(0, scales); - } else { - DCHECK((post_op_param.name == "relu") || - (post_op_param.name == "output_scale")); - } - } - post_ops_attr.set_post_ops(post_ops); - context_.fwd_pd.reset(new inner_product_forward::primitive_desc( - *context_.fwd_desc, post_ops_attr, cpu_engine_)); - } else { - context_.fwd_pd.reset(new inner_product_forward::primitive_desc( - *context_.fwd_desc, cpu_engine_)); - } - - // Store the expected memory format - context_.src_fmt = static_cast( - context_.fwd_pd.get()->src_primitive_desc().desc().data.format); - - context_.weight_fmt = static_cast( - context_.fwd_pd.get()->weights_primitive_desc().desc().data.format); - - // Create memory primitive based on dummy data - context_.src_mem.reset( - new memory(context_.fwd_pd.get()->src_primitive_desc(), DummyData)); - context_.weight_mem.reset( - new memory(context_.fwd_pd.get()->weights_primitive_desc(), DummyData)); - context_.dst_mem.reset( - new memory(context_.fwd_pd.get()->dst_primitive_desc(), DummyData)); - context_.bias_mem.reset(new memory({{{matmul_fwd_params.bias_dims}, - MklDnnType(), - memory::format::x}, - cpu_engine_}, - DummyData)); - - // Create inner-product primitive - context_.matmul_fwd.reset(new inner_product_forward( - *context_.fwd_pd, *context_.src_mem, *context_.weight_mem, - *context_.bias_mem, *context_.dst_mem)); - - context_.fwd_primitives.push_back(*context_.matmul_fwd); - return; - } - - struct MklDnnMatMulFwdContext context_; - engine cpu_engine_; -}; - -template -class MklDnnMatMulFwdPrimitiveFactory : public MklPrimitiveFactory { - public: - static MklDnnMatMulFwdPrimitive* Get( - const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims, bool do_not_cache) { - MklDnnMatMulFwdPrimitive* matmul_fwd = - nullptr; - - if (do_not_cache) { - // Always create new primitive - matmul_fwd = - new MklDnnMatMulFwdPrimitive( - mkldnn_matmul_fwd_dims); - } else { - // try to find a suitable one in pool - matmul_fwd = dynamic_cast< - MklDnnMatMulFwdPrimitive*>( - MklDnnMatMulFwdPrimitiveFactory::GetInstance() - .GetMklDnnMatMulFwd(mkldnn_matmul_fwd_dims)); - if (matmul_fwd == nullptr) { - matmul_fwd = - new MklDnnMatMulFwdPrimitive( - mkldnn_matmul_fwd_dims); - MklDnnMatMulFwdPrimitiveFactory::GetInstance() - .SetMklDnnMatMulFwd(mkldnn_matmul_fwd_dims, matmul_fwd); - } - } - return matmul_fwd; - } - - private: - MklDnnMatMulFwdPrimitiveFactory() {} - ~MklDnnMatMulFwdPrimitiveFactory() {} - - static MklDnnMatMulFwdPrimitiveFactory& GetInstance() { - static MklDnnMatMulFwdPrimitiveFactory instance_; - return instance_; - } - - static string CreateKey(const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims) { - string prefix = "matmul_fwd_"; - FactoryKeyCreator key_creator; - key_creator.AddAsKey(prefix); - key_creator.AddAsKey(mkldnn_matmul_fwd_dims.src_dims); - 
key_creator.AddAsKey(mkldnn_matmul_fwd_dims.weight_dims); - key_creator.AddAsKey(mkldnn_matmul_fwd_dims.bias_dims); - key_creator.AddAsKey(mkldnn_matmul_fwd_dims.dst_dims); - key_creator.AddAsKey(mkldnn_matmul_fwd_dims.dtypes); - - // Generate keys for post-ops - for (auto const& post_op_param : mkldnn_matmul_fwd_dims.post_op_params) { - if (post_op_param.name == "relu") { - DCHECK_EQ(post_op_param.param.size(), 3); - key_creator.AddAsKey(post_op_param.name); - key_creator.AddAsKey(post_op_param.param[0]); - key_creator.AddAsKey(post_op_param.param[1]); - key_creator.AddAsKey(post_op_param.param[2]); - } else if (post_op_param.name == "output_scale") { - DCHECK_EQ(post_op_param.param.size(), 1); - key_creator.AddAsKey(post_op_param.name); - key_creator.AddAsKey(post_op_param.param[0]); - } else { - return string("not_a_key"); - } - } - return key_creator.GetKey(); - } - - MklPrimitive* GetMklDnnMatMulFwd( - const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims) { - string key = CreateKey(mkldnn_matmul_fwd_dims); - return this->GetOp(key); - } - - void SetMklDnnMatMulFwd(const MklDnnMatMulFwdParams& mkldnn_matmul_fwd_dims, - MklPrimitive* op) { - string key = CreateKey(mkldnn_matmul_fwd_dims); - this->SetOp(key, op); - } -}; - -typedef Eigen::ThreadPoolDevice CPUDevice; - template -class MklDnnQuantizedMatMulOp : public OpKernel { +class MklDnnQuantizedMatMulOp : public MklDnnMatMulOpBase { public: virtual ~MklDnnQuantizedMatMulOp() { if (this->input_bias_ != nullptr) { @@ -430,7 +133,7 @@ class MklDnnQuantizedMatMulOp : public OpKernel { } explicit MklDnnQuantizedMatMulOp(OpKernelConstruction* context) - : OpKernel(context) { + : MklDnnMatMulOpBase(context) { string mode_string; OP_REQUIRES_OK(context, context->GetAttr("input_quant_mode", &mode_string)); if (mode_string == "MIN_FIRST") { @@ -447,19 +150,20 @@ class MklDnnQuantizedMatMulOp : public OpKernel { void Compute(OpKernelContext* context) override { try { // Input tensors - const Tensor& src_tensor = MklGetInput(context, kInputIndexSrc); - const Tensor& weight_tensor = MklGetInput(context, kInputIndexWeight); - const Tensor& bias_tensor = MklGetInput(context, kInputIndexBias); + const Tensor& src_tensor = MklGetInput(context, this->kInputIndexSrc); + const Tensor& weight_tensor = + MklGetInput(context, this->kInputIndexWeight); + const Tensor& bias_tensor = MklGetInput(context, this->kInputIndexBias); MklDnnShape src_mkl_shape, weight_mkl_shape; - GetMklShape(context, kInputIndexSrc, &src_mkl_shape); - GetMklShape(context, kInputIndexWeight, &weight_mkl_shape); + GetMklShape(context, this->kInputIndexSrc, &src_mkl_shape); + GetMklShape(context, this->kInputIndexWeight, &weight_mkl_shape); OP_REQUIRES(context, !weight_mkl_shape.IsMklTensor(), errors::InvalidArgument("Weight should not be in " "MKL Layout")); - MklDnnData src(&cpu_engine_); - MklDnnData weight(&cpu_engine_); + MklDnnData src(&(this->cpu_engine_)); + MklDnnData weight(&(this->cpu_engine_)); memory::dims src_dims, weight_dims; memory::dims dst_dims_tf_order, dst_dims_mkl_order; @@ -524,8 +228,8 @@ class MklDnnQuantizedMatMulOp : public OpKernel { // Allocate output Tensor. 
std::shared_ptr matmul_fwd_pd = matmul_fwd->GetPrimitiveDesc(); - AllocateOutputTensor(context, *matmul_fwd_pd, dst_dims_mkl_order, - input_output_fmt, &dst_tensor); + this->AllocateOutputTensor(context, *matmul_fwd_pd, dst_dims_mkl_order, + input_output_fmt, &dst_tensor); Toutput* dst_data = reinterpret_cast(dst_tensor->flat().data()); @@ -724,32 +428,6 @@ class MklDnnQuantizedMatMulOp : public OpKernel { } } - // Allocate output tensor. - virtual void AllocateOutputTensor( - OpKernelContext* context, - const inner_product_forward::primitive_desc& mkldnn_matmul_prim_desc, - const memory::dims& output_dims_mkl_order, - memory::format output_tf_format, Tensor** output_tensor) { - DCHECK(output_tensor); - auto dst_pd = mkldnn_matmul_prim_desc.dst_primitive_desc(); - - MklDnnShape output_mkl_shape; - output_mkl_shape.SetMklTensor(true); - output_mkl_shape.SetMklLayout(&dst_pd); - output_mkl_shape.SetElemType(MklDnnType()); - output_mkl_shape.SetTfLayout(output_dims_mkl_order.size(), - output_dims_mkl_order, output_tf_format, true); - - TensorShape output_tf_shape; - output_tf_shape.AddDim((dst_pd.get_size() / sizeof(Toutput))); - - // Allocate Output Tensor - AllocateOutputSetMklShape(context, kOutputIndexDst, output_tensor, - output_tf_shape, output_mkl_shape); - } - - engine cpu_engine_ = engine(engine::cpu, 0); - private: memory* input_bias_ = nullptr; memory* scaled_bias_ = nullptr; @@ -757,11 +435,6 @@ class MklDnnQuantizedMatMulOp : public OpKernel { // Buffer to save the compensated bias float* comp_bias_ = nullptr; - const int kInputIndexSrc = 0; - const int kInputIndexWeight = 1; - const int kInputIndexBias = 2; - const int kOutputIndexDst = 0; - int mode_; }; diff --git a/tensorflow/core/ops/mkl_nn_ops.cc b/tensorflow/core/ops/mkl_nn_ops.cc index 31c822a7704..9ea3fb61af0 100644 --- a/tensorflow/core/ops/mkl_nn_ops.cc +++ b/tensorflow/core/ops/mkl_nn_ops.cc @@ -60,6 +60,32 @@ REGISTER_OP("_MklFusedConv2D") is expected to create these operators. )doc"); +REGISTER_OP("_MklFusedMatMul") + .Input("a: T") + .Input("b: T") + .Input("args: num_args * T") + .Input("mkl_a: uint8") + .Input("mkl_b: uint8") + .Input("mkl_args: num_args * uint8") + .Output("product: T") + .Output("mkl_product: uint8") + .Attr("transpose_a: bool = false") + .Attr("transpose_b: bool = false") + .Attr("T: {float}") + .Attr("num_args: int >= 0") + .Attr("fused_ops: list(string) = []") + // Attributes for the FusedBatchNorm ----------- // + .Attr("epsilon: float = 0.0001") + // --------------------------------------------- // + .SetShapeFn(shape_inference::MatMulShape) + .Doc(R"doc( +MKL version of FusedMatMul operator. Uses MKL DNN APIs to implement MatMul +operator. + +NOTE Do not invoke this operator directly in Python. Graph rewrite pass is +expected to invoke these operators. +)doc"); + REGISTER_OP("__MklDummyPadWithFusedConv2D") .Input("input: T") .Input("filter: T") diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h index ff218f24008..d4aebaf6341 100644 --- a/tensorflow/core/util/mkl_util.h +++ b/tensorflow/core/util/mkl_util.h @@ -468,7 +468,7 @@ class MklDnnShape { /// We use lazy evaluation and create it only when needed. Input format can /// also be Blocked format. 
inline void SetTfLayout(size_t dims, const memory::dims& sizes, - MKL_TENSOR_FORMAT format, bool is_2d = false) { + MKL_TENSOR_FORMAT format) { DCHECK_EQ(dims, sizes.size()) << "SetTfLayout: Number of dimensions does not" "match with dimension array"; @@ -478,7 +478,7 @@ class MklDnnShape { } data_.tf_data_format_ = format; if (format != MKL_TENSOR_FORMAT_BLOCKED) { - if (is_2d) { + if (dims == 2) { data_.map_[0] = MklDnnDims::Dim_N; data_.map_[1] = MklDnnDims::Dim_C; } else { @@ -688,9 +688,9 @@ inline Tensor ConvertMklToTF(OpKernelContext* context, const Tensor& mkl_tensor, CHECK(output_tensor.CopyFrom(mkl_tensor, output_shape)); } } catch (mkldnn::error& e) { - string error_msg = "Status: " + std::to_string(e.status) + - ", message: " + string(e.message) + ", in file " + - string(__FILE__) + ":" + std::to_string(__LINE__); + string error_msg = "Status: " + std::to_string(e.status) + ", message: " + + string(e.message) + ", in file " + string(__FILE__) + + ":" + std::to_string(__LINE__); LOG(FATAL) << "Operation received an exception: " << error_msg; } return output_tensor; From c1c17655d68cd6d43d523d0fd83750a458adbff1 Mon Sep 17 00:00:00 2001 From: TengLu Date: Thu, 5 Sep 2019 17:45:42 +0800 Subject: [PATCH 2/4] Refine MKL-DNN FusedMatMul code according to review. --- tensorflow/core/graph/mkl_layout_pass.cc | 4 +- .../core/kernels/mkl_matmul_op_fused.cc | 64 +++++++++---------- .../core/kernels/mkl_matmul_ops_common.h | 7 +- tensorflow/core/ops/mkl_nn_ops.cc | 2 +- 4 files changed, 38 insertions(+), 39 deletions(-) diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc index 1a49cc5d3d3..981cde7337d 100644 --- a/tensorflow/core/graph/mkl_layout_pass.cc +++ b/tensorflow/core/graph/mkl_layout_pass.cc @@ -1454,7 +1454,7 @@ rinfo_.push_back({csinfo_.tanh_grad, return false; } - // Rewrite relu for _FusedMatMul. + // Rewrite rule for _FusedMatMul. // @return - true (no transpose attribute for input 1 and only has 1 post op); // false otherwise. static bool FusedMatMulRewrite(const Node* n) { @@ -1467,7 +1467,7 @@ rinfo_.push_back({csinfo_.tanh_grad, // Do not rewrite with more than 1 post op because MKL-DNN doesn't support. TF_CHECK_OK(GetNodeAttr(n->def(), "fused_ops", &fused_ops)); - return !trans_a && fused_ops.size() == 1; + return (!trans_a) && (fused_ops.size() == 1); } // Check if we are performing pooling on depth or batch. If it is, then we diff --git a/tensorflow/core/kernels/mkl_matmul_op_fused.cc b/tensorflow/core/kernels/mkl_matmul_op_fused.cc index 7baed255241..76df936418b 100644 --- a/tensorflow/core/kernels/mkl_matmul_op_fused.cc +++ b/tensorflow/core/kernels/mkl_matmul_op_fused.cc @@ -52,16 +52,16 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase { const Tensor& weight_tensor = ctx->input(this->kInputIndexWeight); const Tensor& bias_tensor = MklGetInput(ctx, this->kInputIndexBias); - MklDnnShape src_dnn_shape; - MklDnnShape weight_dnn_shape; - GetMklShape(ctx, this->kInputIndexSrc, &src_dnn_shape); - GetMklShape(ctx, this->kInputIndexWeight, &weight_dnn_shape); + MklDnnShape src_mkl_shape; + MklDnnShape weight_mkl_shape; + GetMklShape(ctx, this->kInputIndexSrc, &src_mkl_shape); + GetMklShape(ctx, this->kInputIndexWeight, &weight_mkl_shape); // Get shapes of input tensors - auto src_tf_shape = src_dnn_shape.IsMklTensor() ? src_dnn_shape.GetTfShape() + auto src_tf_shape = src_mkl_shape.IsMklTensor() ? src_mkl_shape.GetTfShape() : src_tensor.shape(); - auto weight_tf_shape = weight_dnn_shape.IsMklTensor() - ? 
weight_dnn_shape.GetTfShape() + auto weight_tf_shape = weight_mkl_shape.IsMklTensor() + ? weight_mkl_shape.GetTfShape() : weight_tensor.shape(); // Check the constraint of input matrix and bias @@ -77,13 +77,13 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase { // Get dimension size of each matrix, dim_pair[] is the location of k // in the inputs, we have constraint that k of the two inputs are // the same - const int dim_pair[] = {transpose_a_ ? 0 : 1, transpose_b_ ? 1 : 0}; + const int dim_pair[] = {1, transpose_b_ ? 1 : 0}; const int batch = src_tf_shape.dim_size(1 - dim_pair[0]); const int k = src_tf_shape.dim_size(dim_pair[0]); const int channel = weight_tf_shape.dim_size(1 - dim_pair[1]); OP_REQUIRES(ctx, k == weight_tf_shape.dim_size(dim_pair[1]), - errors::InvalidArgument("Matrix size-incompatible: In[0]: ", + errors::InvalidArgument("Matrix size are incompatible: In[0]: ", src_tf_shape.DebugString(), ", In[1]: ", weight_tf_shape.DebugString())); OP_REQUIRES(ctx, bias_tensor.shape().dim_size(0) == channel, @@ -91,14 +91,12 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase { "Must provide as many biases as the channel size: ", bias_tensor.shape().DebugString(), " vs. ", channel)); - // Create primitive for InnerProduct, it has such desc: - // s[batch, k] * w[channel, k] + b[channel] = dst[batch, channel] - // [n, c] * [oc, ic] + [x] = [n, c] - // - // For weights, dimensions need to be specified as [channel*k]. + // For inputs s[batch, k], w[k, channel] and b[channel], the primitive + // dims should be described like this: + // s[batch, k] * w^T[channel, k] + b[channel] = dst[batch, channel] + // [n, ic] * [oc, ic] + [oc] = [n, oc] memory::dims src_dims = memory::dims({batch, k}); - // In order to satisfy the primitive, reverse the dims of weights - // from [k, channel] to [channel, k]. + // Reverse the weights dims from [k, channel] to [channel, k]. memory::dims weight_dims = memory::dims({channel, k}); memory::dims bias_dims = memory::dims({channel}); memory::dims dst_dims = memory::dims({batch, channel}); @@ -115,15 +113,15 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase { std::shared_ptr matmul_pd = matmul_prim->GetPrimitiveDesc(); - if (src_dnn_shape.IsMklTensor() && weight_dnn_shape.IsMklTensor()) { + if (src_mkl_shape.IsMklTensor() && weight_mkl_shape.IsMklTensor()) { this->AllocateOutputTensor(ctx, *matmul_pd, dst_dims, memory::format::nc, &dst_tensor); } else { TensorShape dst_tensor_shape({batch, channel}); - MklDnnShape dst_dnn_shape; - dst_dnn_shape.SetMklTensor(false); + MklDnnShape dst_mkl_shape; + dst_mkl_shape.SetMklTensor(false); AllocateOutputSetMklShape(ctx, 0, &dst_tensor, dst_tensor_shape, - dst_dnn_shape); + dst_mkl_shape); } // if there's nothing to compute, just return. @@ -139,28 +137,28 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase { T* dst_data = const_cast(dst_tensor->flat().data()); // Any input is MKL format, reorder it if necessary. 
- MklDnnData src_dnn(&(this->cpu_engine_)); - MklDnnData weight_dnn(&(this->cpu_engine_)); + MklDnnData src_mkl(&(this->cpu_engine_)); + MklDnnData weight_mkl(&(this->cpu_engine_)); - if (src_dnn_shape.IsMklTensor()) { - memory::desc input_md = src_dnn_shape.GetMklLayout(); + if (src_mkl_shape.IsMklTensor()) { + memory::desc input_md = src_mkl_shape.GetMklLayout(); if (input_md.data.format != memory::format::nc) { - src_dnn.SetUsrMem(input_md, src_data); - src_dnn.CheckReorderToOpMem(matmul_pd.get()->src_primitive_desc()); - src_data = reinterpret_cast(src_dnn.GetOpMem().get_data_handle()); + src_mkl.SetUsrMem(input_md, src_data); + src_mkl.CheckReorderToOpMem(matmul_pd.get()->src_primitive_desc()); + src_data = reinterpret_cast(src_mkl.GetOpMem().get_data_handle()); } } - if (weight_dnn_shape.IsMklTensor()) { - memory::desc input_md = weight_dnn_shape.GetMklLayout(); + if (weight_mkl_shape.IsMklTensor()) { + memory::desc input_md = weight_mkl_shape.GetMklLayout(); if (input_md.data.format != weight_format) { - weight_dnn.SetUsrMem(input_md, weight_data); - weight_dnn.CheckReorderToOpMem( + weight_mkl.SetUsrMem(input_md, weight_data); + weight_mkl.CheckReorderToOpMem( matmul_pd.get()->weights_primitive_desc()); weight_data = - reinterpret_cast(weight_dnn.GetOpMem().get_data_handle()); + reinterpret_cast(weight_mkl.GetOpMem().get_data_handle()); } } @@ -179,7 +177,7 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase { bool transpose_b_; }; -// register dnn kernels for supported operations and supported types +// Register mkl kernels for supported operations and types. #define REGISTER_FUSEDMATMUL_MKL_SUPPORTED_KERNELS_TYPES(type) \ REGISTER_KERNEL_BUILDER( \ Name("_MklFusedMatMul") \ diff --git a/tensorflow/core/kernels/mkl_matmul_ops_common.h b/tensorflow/core/kernels/mkl_matmul_ops_common.h index decaac8aad7..fbea4335f56 100644 --- a/tensorflow/core/kernels/mkl_matmul_ops_common.h +++ b/tensorflow/core/kernels/mkl_matmul_ops_common.h @@ -62,7 +62,7 @@ struct MklDnnMatMulFwdParams { // So we use different template parameters for each type. // TODO(intel-tf): The template type "T" is currently used to match the // templatized class MklPrimitiveFactory (tensorflow/core/util/mkl_util.h). -// In future, with the removal of "T" from MklPrimitiveFactory, this class +// In the future, with the removal of "T" from MklPrimitiveFactory, this class // needs to drop "T". template @@ -146,10 +146,11 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive { bias_mem(nullptr), dst_mem(nullptr), fwd_desc(nullptr), + fwd_pd(nullptr), src_md(nullptr), weight_md(nullptr), bias_md(nullptr), - fwd_pd(nullptr), + dst_md(nullptr), matmul_fwd(nullptr), fwd_stream(nullptr) {} }; @@ -256,7 +257,7 @@ class MklDnnMatMulFwdPrimitiveFactory : public MklPrimitiveFactory { new MklDnnMatMulFwdPrimitive( mkldnn_matmul_fwd_dims); } else { - // try to find a suitable one in pool + // Try to find a suitable one in pool matmul_fwd = dynamic_cast< MklDnnMatMulFwdPrimitive*>( MklDnnMatMulFwdPrimitiveFactory Date: Wed, 11 Sep 2019 23:07:05 +0800 Subject: [PATCH 3/4] Fix Ubuntu Sanity errors for MKL-DNN FusedMatMul. 
---
 tensorflow/core/kernels/BUILD | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index c7e6fd2dbc6..ab4cba803cc 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -3816,7 +3816,7 @@ tf_mkl_kernel_library(
     name = "mkl_matmul_op",
     srcs = [
         "mkl_matmul_op.cc",
-        "mkl_matmul_op_fused.cc"
+        "mkl_matmul_op_fused.cc",
     ],
     hdrs = ["mkl_matmul_ops_common.h"],
     deps = MATH_DEPS + mkl_deps(),

From 94ae1e0352b795072a23089ca1dcc11203b8fd2d Mon Sep 17 00:00:00 2001
From: TengLu
Date: Mon, 23 Sep 2019 11:34:44 +0800
Subject: [PATCH 4/4] Fix UT failures for fp32 MklFusedMatMul.

---
 tensorflow/core/kernels/mkl_matmul_op_fused.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorflow/core/kernels/mkl_matmul_op_fused.cc b/tensorflow/core/kernels/mkl_matmul_op_fused.cc
index 76df936418b..f6fca9d0f8d 100644
--- a/tensorflow/core/kernels/mkl_matmul_op_fused.cc
+++ b/tensorflow/core/kernels/mkl_matmul_op_fused.cc
@@ -83,7 +83,7 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase {
     const int channel = weight_tf_shape.dim_size(1 - dim_pair[1]);
 
     OP_REQUIRES(ctx, k == weight_tf_shape.dim_size(dim_pair[1]),
-                errors::InvalidArgument("Matrix size are incompatible: In[0]: ",
+                errors::InvalidArgument("Matrix size-incompatible: In[0]: ",
                                         src_tf_shape.DebugString(), ", In[1]: ",
                                         weight_tf_shape.DebugString()));
     OP_REQUIRES(ctx, bias_tensor.shape().dim_size(0) == channel,
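
Usage note: the kernels in this series all drive the primitive cache defined in
mkl_matmul_ops_common.h. The sketch below mirrors what MklFusedMatMulOp::Compute
does with that API. The angle-bracketed template arguments did not survive in the
diff text above, so the all-float spellings and the helper name here are
assumptions rather than verbatim code from the patch; it also assumes MKL-DNN 0.x
and an include of mkl_matmul_ops_common.h.

// Sketch: compute dst[batch, channel] = src[batch, k] * weight[k, channel]
//         + bias[channel] via the cached MKL-DNN inner-product primitive.
static void RunFusedMatMulSketch(const float* src, const float* weight,
                                 const float* bias, float* dst, int batch,
                                 int k, int channel, bool transpose_b) {
  memory::dims src_dims = {batch, k};
  // The inner-product primitive wants weights described as [out_ch, in_ch].
  memory::dims weight_dims = {channel, k};
  memory::dims bias_dims = {channel};
  memory::dims dst_dims = {batch, channel};
  memory::format weight_fmt =
      transpose_b ? memory::format::oi : memory::format::io;

  MklDnnMatMulFwdParams params(src_dims, weight_dims, bias_dims, dst_dims,
                               weight_fmt);
  // Setup() also understands "relu" and "output_scale" post-op params, but
  // the FP32 rewrite in this series fuses only BiasAdd.

  auto* matmul_prim =
      MklDnnMatMulFwdPrimitiveFactory<float, float, float, float, float>::Get(
          params, /*do_not_cache=*/false);
  matmul_prim->Execute(src, weight, bias, dst);
}

Keying the cache on shapes, data types and post-ops is what lets repeated
matmuls of the same configuration skip primitive re-creation.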