Merge pull request #45358 from Intel-tensorflow:yanzhang/matmul_biasadd_add_fusion
PiperOrigin-RevId: 347408532 Change-Id: I6e8b12dfef056f095449fd70ce387b79d3e8b4d7
commit cd052fa5f0
@@ -446,6 +446,84 @@ TEST_F(MklRemapperTest, FuseBatchNormWithRelu) {
     }
   }
 }
 
+TEST_F(MklRemapperTest, FuseMatMulWithBiasAddAndAdd) {
+  using ::tensorflow::ops::Placeholder;
+
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+  auto input_shape = ops::Placeholder::Shape({4, 32});
+  auto input_shape_add = ops::Placeholder::Shape({4, 8});
+  auto filter_shape = ops::Placeholder::Shape({32, 8});
+  auto bias_shape = ops::Placeholder::Shape({8});
+
+  auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
+  auto input_add =
+      Placeholder(s.WithOpName("input_add"), DT_FLOAT, input_shape_add);
+  auto filter = Placeholder(s.WithOpName("filter"), DT_FLOAT, filter_shape);
+  auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, bias_shape);
+
+  auto matmul = ops::MatMul(s.WithOpName("matmul"), input, filter);
+  auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), matmul, bias);
+
+  auto fetch = s.WithOpName("fetch");
+  auto add = ops::Add(s.WithOpName("add"), bias_add, input_add);
+
+  ops::Identity(fetch, add);
+
+  auto input_tensor = GenerateRandomTensor<DT_FLOAT>(
+      TensorShape(input_shape.shape_.dim_sizes()));
+  auto input_add_tensor = GenerateRandomTensor<DT_FLOAT>(
+      TensorShape(input_shape_add.shape_.dim_sizes()));
+  auto filter_tensor = GenerateRandomTensor<DT_FLOAT>(
+      TensorShape(filter_shape.shape_.dim_sizes()));
+  auto bias_tensor = GenerateRandomTensor<DT_FLOAT>(
+      TensorShape(bias_shape.shape_.dim_sizes()));
+
+  GrapplerItem item;
+  item.fetch = {"fetch"};
+  item.feed = {{"input", input_tensor},
+               {"filter", filter_tensor},
+               {"bias", bias_tensor},
+               {"input_add", input_add_tensor}};
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+  // Place all nodes on CPU.
+  for (int i = 0; i < item.graph.node_size(); ++i) {
+    item.graph.mutable_node(i)->set_device("/device:CPU:0");
+  }
+
+  Remapper optimizer(RewriterConfig::AGGRESSIVE);
+  GraphDef output;
+  TF_CHECK_OK(optimizer.Optimize(nullptr, item, &output));
+
+  int found = 0;
+  for (const NodeDef& node : output.node()) {
+    auto fetch_node_name = "add";
+    if (node.name() == fetch_node_name) {
+      EXPECT_EQ("_FusedMatMul", node.op());
+      EXPECT_EQ("input", node.input(0));
+      EXPECT_EQ("filter", node.input(1));
+
+      EXPECT_EQ(2, node.attr().at("num_args").i());
+      EXPECT_EQ("bias", node.input(2));
+      EXPECT_EQ("input_add", node.input(3));
+
+      const auto fused_ops = node.attr().at("fused_ops").list().s();
+      EXPECT_EQ(2, fused_ops.size());
+      EXPECT_EQ("BiasAdd", fused_ops[0]);
+      EXPECT_EQ("Add", fused_ops[1]);
+      found++;
+    }
+  }
+  EXPECT_EQ(1, found);
+
+  auto tensors_expected = EvaluateNodes(item.graph, item.fetch, item.feed);
+  auto tensors = EvaluateNodes(output, item.fetch, item.feed);
+  EXPECT_EQ(1, tensors_expected.size());
+  EXPECT_EQ(1, tensors.size());
+  test::ExpectClose(tensors_expected[0], tensors[0], 0, 1e-6);
+}
 #endif  // ENABLE_MKLDNN_V1
 
 }  // namespace grappler
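For orientation, the rewrite exercised by this test can be pictured as a before/after graph. Node names match the test; the remapper reuses the Add node's name ("add") for the fused node:

    Before remapping:               After remapping:

    input   filter                  input  filter  bias  input_add
        \   /                           \    |      |     /
       MatMul    bias                      _FusedMatMul ("add")
           \     /                     fused_ops = {"BiasAdd", "Add"}
           BiasAdd    input_add        num_args = 2
               \      /
                Add ("add")
                 |
            Identity ("fetch")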
@@ -1283,28 +1283,36 @@ Status AddFusedContractionNode(RemapperContext* ctx,
   const NodeDef& contraction = graph->node(matched.contraction);
   const NodeDef& bias_add = graph->node(matched.bias_add);
 
-  // MKL version only support fusion for Conv2D
-  DCHECK(IsConv2D(contraction));
+  // The MKL version only supports fusion for Conv2D and MatMul.
+  DCHECK(IsConv2D(contraction) || IsMatMul(contraction));
 
-  NodeDef fused_conv2d;
+  NodeDef contraction_node;
   const NodeDef& add = graph->node(matched.add);
-  fused_conv2d.set_name(add.name());
-  fused_conv2d.set_op(kFusedConv2D);
-  fused_conv2d.set_device(contraction.device());
-  fused_conv2d.add_input(contraction.input(0));  // 0: input
-  fused_conv2d.add_input(contraction.input(1));  // 1: filter
-  fused_conv2d.add_input(bias_add.input(1));     // 2: bias
+  contraction_node.set_name(add.name());
+  contraction_node.set_device(contraction.device());
+  contraction_node.add_input(
+      contraction.input(0));  // 0: input (conv) / a (matmul)
+  contraction_node.add_input(
+      contraction.input(1));  // 1: filter (conv) / b (matmul)
+  contraction_node.add_input(bias_add.input(1));  // 2: bias
 
-  // Add OP has two inputs, one is conv+bias pattern matched previously,
-  // the other input to add is fused here.
-  fused_conv2d.add_input(add.input(1 - matched.port_id));
+  // The Add op has two inputs: one is the conv+bias/matmul+bias pattern
+  // matched previously; the other input to Add is fused here.
+  contraction_node.add_input(add.input(1 - matched.port_id));
 
-  CopyConv2DAttributes(contraction, &fused_conv2d);
-  SetFusedOpAttributes(&fused_conv2d, {"BiasAdd", "Add"}, 2);
+  if (IsConv2D(contraction)) {
+    contraction_node.set_op(kFusedConv2D);
+    CopyConv2DAttributes(contraction, &contraction_node);
+  } else if (IsMatMul(contraction)) {
+    contraction_node.set_op(kFusedMatMul);
+    CopyMatMulAttributes(contraction, &contraction_node);
+  }
+
+  SetFusedOpAttributes(&contraction_node, {"BiasAdd", "Add"}, 2);
 
   utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
   Status status;
-  mutation->AddNode(std::move(fused_conv2d), &status);
+  mutation->AddNode(std::move(contraction_node), &status);
   TF_RETURN_IF_ERROR(status);
   TF_RETURN_IF_ERROR(mutation->Apply());
 
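SetFusedOpAttributes is a helper defined elsewhere in remapper.cc. As a rough sketch of what it is expected to record on the new node (an assumption inferred from the attributes the remapper test above asserts on, not the actual helper body):

    // Hypothetical sketch only; the real SetFusedOpAttributes may set more.
    void SetFusedOpAttributesSketch(NodeDef* fused,
                                    const std::vector<string>& fused_ops,
                                    int num_args) {
      auto* attr = fused->mutable_attr();
      SetAttrValue(fused_ops, &(*attr)["fused_ops"]);  // e.g. {"BiasAdd", "Add"}
      SetAttrValue(num_args, &(*attr)["num_args"]);    // extra inputs after a/b
    }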
@@ -1621,19 +1629,25 @@ Status AddBatchNormNodes(RemapperContext* ctx, const FusedBatchNorm& matched) {
 }
 
 #ifdef INTEL_MKL
-bool IsConv2DWithAdd(const RemapperContext& ctx, int node_index) {
+bool IsConv2DOrMatMul(const NodeDef& node) {
+  return IsConv2D(node) || IsMatMul(node);
+}
+
+bool IsContractionWithAdd(const RemapperContext& ctx, int node_index) {
   const auto* node_view = ctx.graph_view.GetNode(node_index);
   const auto* node_def = node_view->node();
 
   // Candidate for Conv2D + Add or Conv2D + BiasAdd + Add fusion.
+  //               MatMul + Add or MatMul + BiasAdd + Add fusion.
   auto is_supported_add_input = [](const auto* node_view) -> bool {
-    if (IsConv2D(*node_view->node())) return true;
+    // Currently only Conv2D and MatMul are supported.
+    if (IsConv2DOrMatMul(*node_view->node())) return true;
     if (IsBiasAdd(*node_view->node())) {
       if (node_view->NumRegularFanins() < 2) return false;
       const auto& bias_add_fanin_0 = node_view->GetRegularFanin(0);
       const auto& bias_add_fanin_1 = node_view->GetRegularFanin(1);
-      return IsConv2D(*bias_add_fanin_0.node_view()->node()) ||
-             IsConv2D(*bias_add_fanin_1.node_view()->node());
+      return IsConv2DOrMatMul(*bias_add_fanin_0.node_view()->node()) ||
+             IsConv2DOrMatMul(*bias_add_fanin_1.node_view()->node());
     }
     return false;
   };
@@ -1739,7 +1753,7 @@ bool RequiresInferredShapes(const RemapperContext& ctx, int node_index) {
 
 #ifdef INTEL_MKL
   return is_batch_norm_candidate() || is_batch_norm_fusion_candidate() ||
-         IsConv2DWithAdd(ctx, node_index);
+         IsContractionWithAdd(ctx, node_index);
 #else
   return is_relu_biasadd_conv2d_candidate() || is_batch_norm_candidate() ||
          is_batch_norm_fusion_candidate();
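Restated, is_supported_add_input accepts an Add whose matched input takes one of four shapes:

    Add(Conv2D(...), other)                 -> Conv2D + Add
    Add(MatMul(...), other)                 -> MatMul + Add
    Add(BiasAdd(Conv2D(...), bias), other)  -> Conv2D + BiasAdd + Add
    Add(BiasAdd(MatMul(...), bias), other)  -> MatMul + BiasAdd + Add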
@@ -955,6 +955,71 @@ TEST_F(FilterCacheTest, Conv2DFilterCacheTest) {
 // Testing fusion of MatMul and BiasAdd
 template <typename T>
 class MklFusedMatMulOpTest : public OpsTestBase {
+ private:
+  void RunMklFusedMatMulOp(const Tensor& input, const Tensor& weight,
+                           const std::vector<Tensor>& args,
+                           const std::vector<string>& fused_ops,
+                           Tensor* output) {
+    DataType dtype = DataTypeToEnum<T>::v();
+    const int num_args = args.size();
+    if (!NativeFormatEnabled()) {
+      TF_EXPECT_OK(NodeDefBuilder("MklFusedMatMul", "_MklFusedMatMul")
+                       .Input(FakeInput(dtype))
+                       .Input(FakeInput(dtype))
+                       .Input(FakeInput(num_args, dtype))
+                       .Input(FakeInput(DT_UINT8))
+                       .Input(FakeInput(DT_UINT8))
+                       .Input(FakeInput(num_args, DT_UINT8))
+                       .Attr("T", dtype)
+                       .Attr("transpose_a", false)
+                       .Attr("transpose_b", false)
+                       .Attr("num_args", num_args)
+                       .Attr("fused_ops", fused_ops)
+                       .Attr("epsilon", 0.0001)
+                       .Attr("_kernel", "MklLayoutDependentOp")
+                       .Finalize(node_def()));
+    } else {
+      TF_EXPECT_OK(NodeDefBuilder("MklFusedMatMul", "_MklNativeFusedMatMul")
+                       .Input(FakeInput(dtype))
+                       .Input(FakeInput(dtype))
+                       .Input(FakeInput(num_args, dtype))
+                       .Attr("T", dtype)
+                       .Attr("transpose_a", false)
+                       .Attr("transpose_b", false)
+                       .Attr("num_args", num_args)
+                       .Attr("fused_ops", fused_ops)
+                       .Attr("epsilon", 0.0001)
+                       .Attr("_kernel", "MklNameChangeOp")
+                       .Finalize(node_def()));
+    }
+
+    TF_EXPECT_OK(InitOp());
+
+    AddInputFromArray<T>(input.shape(), input.flat<T>());
+    AddInputFromArray<T>(weight.shape(), weight.flat<T>());
+    for (const Tensor& arg : args)
+      AddInputFromArray<T>(arg.shape(), arg.flat<T>());
+    if (!NativeFormatEnabled()) {
+      // Add an MKL meta input for the input, the filter, and each extra
+      // argument.
+      AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+      AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+      for (const Tensor& arg : args)
+        AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
+    }
+
+    TF_ASSERT_OK(RunOpKernel());
+
+    const Tensor& output_tensor = *GetOutput(0);
+    if (!NativeFormatEnabled()) {
+      const Tensor& output_meta_tensor = *GetOutput(1);
+      CommonTestUtilities<T> test_util;
+      test_util.PerformConversion(dtype, output_tensor, output_meta_tensor,
+                                  output);
+    } else {
+      *output = output_tensor;
+    }
+  }
+
  protected:
   void VerifyFusedMatMul(const int kBatch, const int kInputChannel,
                          const int kOutputChannel,
@@ -1002,70 +1067,24 @@ class MklFusedMatMulOpTest : public OpsTestBase {
             next_op = ops::Tanh(root.WithOpName(last_op), next_op);
           }
 
+          if (std::find(fused_ops.begin(), fused_ops.end(), "Add") !=
+              fused_ops.end()) {
+            last_op = "with_add";
+            next_op = ops::Add(root.WithOpName("with_add"), next_op, input_op);
+          }
+
           CommonTestUtilities<T>::RunAndFetch(root, last_op, output);
         };
 
     const FusedGraphRunner run_fused =
         [this](const Tensor& input, const Tensor& weight, const Tensor& bias,
                const std::vector<string>& fused_ops, Tensor* output) {
-          DataType dtype = DataTypeToEnum<T>::v();
-          const int num_args = 1;
-
-          if (!NativeFormatEnabled()) {
-            TF_EXPECT_OK(NodeDefBuilder("MklFusedMatMul", "_MklFusedMatMul")
-                             .Input(FakeInput(dtype))
-                             .Input(FakeInput(dtype))
-                             .Input(FakeInput(num_args, dtype))
-                             .Input(FakeInput(DT_UINT8))
-                             .Input(FakeInput(DT_UINT8))
-                             .Input(FakeInput(num_args, DT_UINT8))
-                             .Attr("T", dtype)
-                             .Attr("transpose_a", false)
-                             .Attr("transpose_b", false)
-                             .Attr("num_args", num_args)
-                             .Attr("fused_ops", fused_ops)
-                             .Attr("epsilon", 0.0001)
-                             .Attr("_kernel", "MklLayoutDependentOp")
-                             .Finalize(node_def()));
-          } else {
-            TF_EXPECT_OK(
-                NodeDefBuilder("MklFusedMatMul", "_MklNativeFusedMatMul")
-                    .Input(FakeInput(dtype))
-                    .Input(FakeInput(dtype))
-                    .Input(FakeInput(num_args, dtype))
-                    .Attr("T", dtype)
-                    .Attr("transpose_a", false)
-                    .Attr("transpose_b", false)
-                    .Attr("num_args", num_args)
-                    .Attr("fused_ops", fused_ops)
-                    .Attr("epsilon", 0.0001)
-                    .Attr("_kernel", "MklNameChangeOp")
-                    .Finalize(node_def()));
-          }
-
-          TF_EXPECT_OK(InitOp());
-
-          AddInputFromArray<T>(input.shape(), input.flat<T>());
-          AddInputFromArray<T>(weight.shape(), weight.flat<T>());
-          AddInputFromArray<T>(bias.shape(), bias.flat<T>());
-          if (!NativeFormatEnabled()) {
-            // Add MKL meta input for input, filter and bias.
-            AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
-            AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
-            AddInputFromArray<uint8>(dummy_shape, dummy_tensor);
-          }
-
-          TF_ASSERT_OK(RunOpKernel());
-
-          const Tensor& output_tensor = *GetOutput(0);
-          if (!NativeFormatEnabled()) {
-            const Tensor& output_meta_tensor = *GetOutput(1);
-            CommonTestUtilities<T> test_util;
-            test_util.PerformConversion(dtype, output_tensor,
-                                        output_meta_tensor, output);
-          } else {
-            *output = output_tensor;
-          }
+          std::vector<Tensor> fused_input = {bias};
+          if (std::find(fused_ops.begin(), fused_ops.end(), "Add") !=
+              fused_ops.end()) {
+            fused_input.push_back(input);
+          }
+          RunMklFusedMatMulOp(input, weight, fused_input, fused_ops, output);
         };
 
     CommonTestUtilities<T>::VerifyFusedMatrixClose(kInputChannel, kBatch,
@@ -1120,12 +1139,22 @@ TYPED_TEST_P(MklFusedMatMulOpTest, WithBiasAndTanh) {
                           {"BiasAdd", "Tanh"});
 }
 
+TYPED_TEST_P(MklFusedMatMulOpTest, WithBiasAndAdd) {
+  const int batch = 3;
+  const int input_channel = 4;
+  const int output_channel = 4;
+
+  this->VerifyFusedMatMul(batch, input_channel, output_channel,
+                          {"BiasAdd", "Add"});
+}
+
 REGISTER_TYPED_TEST_SUITE_P(MklFusedMatMulOpTest,  //
                             WithBias,              //
                             WithBiasAndRelu,       //
                             WithBiasAndRelu6,      //
                             WithBiasAndElu,        //
-                            WithBiasAndTanh);
+                            WithBiasAndTanh,       //
+                            WithBiasAndAdd);
 
 using MklFusedMatMulDataTypes = ::testing::Types<float>;
 INSTANTIATE_TYPED_TEST_SUITE_P(Test, MklFusedMatMulOpTest,
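Combined with the remapper test's input(2)/input(3) checks, the {"BiasAdd", "Add"} case builds a node with num_args = 2 and extra inputs {bias, add input}. A minimal sketch of the resulting native-format NodeDef, as the refactored RunMklFusedMatMulOp above would construct it (the node name is illustrative):

    TF_EXPECT_OK(NodeDefBuilder("fused_matmul", "_MklNativeFusedMatMul")
                     .Input(FakeInput(DT_FLOAT))     // a
                     .Input(FakeInput(DT_FLOAT))     // b
                     .Input(FakeInput(2, DT_FLOAT))  // args = {bias, add input}
                     .Attr("T", DT_FLOAT)
                     .Attr("transpose_a", false)
                     .Attr("transpose_b", false)
                     .Attr("num_args", 2)
                     .Attr("fused_ops", std::vector<string>{"BiasAdd", "Add"})
                     .Attr("epsilon", 0.0001)
                     .Attr("_kernel", "MklNameChangeOp")
                     .Finalize(node_def()));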
@@ -45,6 +45,7 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
         ctx, fused_ops_[0] == "BiasAdd",
         errors::InvalidArgument(
             "The 1st post-argument of MklFusedMatMul must be BiasAdd."));
+    if (fused_ops_.size() > 1 && fused_ops_[1] == "Add") fuse_add_ = true;
     OP_REQUIRES(
         ctx, transpose_a_ == false,
         errors::InvalidArgument("In[0] of MklMatMul can't be transposed."));
@@ -114,7 +115,8 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
     //   2. var, keep the original format to avoid reordering.
     MklDnnMatMulFwdParams matmul_params(
         src_dims, weight_dims, bias_dims, dst_dims, src_format,
-        (this->is_weight_const_) ? MEMORY_FORMAT::any : weight_format);
+        (this->is_weight_const_) ? MEMORY_FORMAT::any : weight_format,
+        MEMORY_FORMAT::nc);
 
     // Extend the basic parameters for data types and fusions.
     ExtendMklDnnMatMulFwdParams(ctx, matmul_params);
@@ -126,15 +128,70 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
     std::shared_ptr<mkldnn::inner_product_forward::primitive_desc> matmul_pd =
         matmul_prim->GetPrimitiveDesc();
 
-    if (src_mkl_shape.IsMklTensor()) {
-      this->AllocateOutputTensor(ctx, *matmul_pd, dst_dims,
-                                 MKL_TENSOR_FORMAT_NC, &dst_tensor);
+    // The output shape of MatMul is the same for the MKL and TF versions:
+    // always NC format, regardless of the input format. The shape of the
+    // Add input likewise matches the output shape.
+    auto dst_pd = matmul_pd->PRIMITIVE_DESC_DST;
+
+    MklDnnShape output_mkl_shape;
+    output_mkl_shape.SetMklTensor(false);
+
+    TensorShape output_tf_shape({batch, channel});
+
+    if (fuse_add_) {
+      const Tensor& add_tensor = MklGetInput(ctx, kInputIndex_Add);
+      MklDnnShape add_mkl_shape;
+      GetMklShape(ctx, kInputIndex_Add, &add_mkl_shape, native_format);
+
+      // For native format, no metadata needs to be set.
+      if (native_format && ctx->forward_input_to_output_with_shape(
+                               kInputIndex_Add, kOutputIndex_Dst,
+                               output_tf_shape, &dst_tensor)) {
+        ;  // Nothing more to do for native format.
+      } else if (!native_format && ForwardMklTensorInToOutWithMklShape(
+                                       ctx, kInputIndex_Add, kOutputIndex_Dst,
+                                       &dst_tensor, output_mkl_shape, false)) {
+        ;  // Non-native format: the tensor is forwarded with its metadata set.
+      } else {
+        // If forwarding fails, use a reorder to copy the Add tensor into the
+        // dst tensor.
+        AllocateOutputSetMklShape(ctx, kOutputIndex_Dst, &dst_tensor,
+                                  output_tf_shape, output_mkl_shape,
+                                  native_format);
+        auto output_format_tag =
+            MklTensorFormatToMklDnnDataFormat(MKL_TENSOR_FORMAT_NC);
+        auto add_md =
+            add_mkl_shape.IsMklTensor()
+                ? add_mkl_shape.GetMklLayout()
+                : memory::desc(dst_dims, MklDnnType<T>(), output_format_tag);
+        auto dst_md =
+            memory::desc(dst_dims, MklDnnType<T>(), output_format_tag);
+
+        void* add_buf =
+            static_cast<void*>(const_cast<T*>(add_tensor.flat<T>().data()));
+        void* dst_buf = static_cast<void*>((dst_tensor)->flat<T>().data());
+
+        if (native_format) {
+          // We are simply deep-copying add_tensor into dst_tensor without
+          // changing the memory layout, hence the shared memory descriptor.
+          add_md = dst_md =
+              memory::desc({add_tensor.NumElements()}, MklDnnType<T>(),
+                           mkldnn::memory::format_tag::x);
+        }
+
+        auto fuse_add_src_ =
+            MEMORY_CONSTRUCTOR(ADD_MD, this->cpu_engine_, add_buf);
+        auto fuse_add_dst_ =
+            MEMORY_CONSTRUCTOR(DST_MD, this->cpu_engine_, dst_buf);
+        auto reorder_desc =
+            REORDER_PD_CONSTRUCTOR(ADD_MD, DST_MD, this->cpu_engine_);
+
+        CreateAndExecuteReorder(reorder_desc, fuse_add_src_, fuse_add_dst_,
+                                this->cpu_engine_, ctx);
+      }
     } else {
-      TensorShape dst_tensor_shape({batch, channel});
-      MklDnnShape dst_mkl_shape;
-      dst_mkl_shape.SetMklTensor(false);
-      AllocateOutputSetMklShape(ctx, 0, &dst_tensor, dst_tensor_shape,
-                                dst_mkl_shape, native_format);
+      AllocateOutputSetMklShape(ctx, 0, &dst_tensor, output_tf_shape,
+                                output_mkl_shape, native_format);
     }
 
     // if there's nothing to compute, just return.
@@ -228,6 +285,8 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
         params.post_op_params.push_back({"elu", {1.0, 1.0, 0.0}});
       } else if (post_op == "Tanh") {
         params.post_op_params.push_back({"tanh", {1.0, 0.0, 0.0}});
+      } else if (post_op == "Add") {
+        params.post_op_params.push_back({"sum", {1.0}});
       } else {
         OP_REQUIRES_OK(
             ctx, errors::InvalidArgument(
@@ -237,10 +296,13 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
   }
 
  private:
+  bool fuse_add_ = false;
   bool transpose_a_;
   bool transpose_b_;
   std::vector<string> fused_ops_;
-};
+  const int kInputIndex_Add = 3;
+  const int kOutputIndex_Dst = 0;
+};
 
 // Register mkl kernels for supported operations and types.
 #define REGISTER_FUSEDMATMUL_MKL_SUPPORTED_KERNELS_TYPES(type)                \
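The "Add" fusion rides on oneDNN's sum post-op, which accumulates the matmul result into whatever already sits in the destination buffer. That is why the kernel above first forwards (or reorders) the Add input into dst_tensor before executing the primitive. In plain mkldnn terms (a sketch, not TensorFlow code):

    // A sum post-op makes the primitive compute
    //   dst = scale * dst + inner_product(src, weights) + bias,
    // so dst must be pre-filled with the tensor being added.
    mkldnn::post_ops po;
    po.append_sum(1.0f);  // scale 1.0, matching {"sum", {1.0}} above
    mkldnn::primitive_attr attr;
    attr.set_post_ops(po);
    // attr is then passed when creating the
    // inner_product_forward::primitive_desc, mirroring the "sum" branch
    // added to MklDnnMatMulFwdPrimitive below.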
@@ -48,8 +48,9 @@ struct MklDnnMatMulFwdParams {
   memory::dims weight_dims;
   memory::dims bias_dims;
   memory::dims dst_dims;
-  memory::format_tag src_format;
-  memory::format_tag weight_format;
+  MEMORY_FORMAT src_format;
+  MEMORY_FORMAT weight_format;
+  MEMORY_FORMAT dst_format;
   string dtypes = string("");
   struct PostOpParam {
     string name;
@@ -57,17 +58,18 @@ struct MklDnnMatMulFwdParams {
   };
   std::vector<PostOpParam> post_op_params;
 
-  MklDnnMatMulFwdParams(
-      memory::dims src_dims, memory::dims weight_dims, memory::dims bias_dims,
-      memory::dims dst_dims,
-      memory::format_tag src_format = memory::format_tag::any,
-      memory::format_tag weight_format = memory::format_tag::any)
+  MklDnnMatMulFwdParams(memory::dims src_dims, memory::dims weight_dims,
+                        memory::dims bias_dims, memory::dims dst_dims,
+                        MEMORY_FORMAT src_format = MEMORY_FORMAT::any,
+                        MEMORY_FORMAT weight_format = MEMORY_FORMAT::any,
+                        MEMORY_FORMAT dst_format = MEMORY_FORMAT::any)
       : src_dims(src_dims),
         weight_dims(weight_dims),
        bias_dims(bias_dims),
        dst_dims(dst_dims),
        src_format(src_format),
-        weight_format(weight_format) {}
+        weight_format(weight_format),
+        dst_format(dst_format) {}
 };
 
 // With quantization, input, weight, bias, and output can have different types.
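A usage sketch of the widened constructor (the dimensions are made up; the call site in the fused-matmul kernel above passes MEMORY_FORMAT::nc as dst_format so the output stays in plain NC layout for the Add fusion):

    // Hypothetical [4, 32] x [32, 8] matmul; oneDNN inner-product weights
    // are laid out as {out_channels, in_channels}.
    MklDnnMatMulFwdParams params(
        /*src_dims=*/{4, 32}, /*weight_dims=*/{8, 32},
        /*bias_dims=*/{8}, /*dst_dims=*/{4, 8},
        /*src_format=*/MEMORY_FORMAT::nc,
        /*weight_format=*/MEMORY_FORMAT::any,
        /*dst_format=*/MEMORY_FORMAT::nc);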
@@ -184,7 +186,7 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
 
     context_.dst_md.reset(new memory::desc({matmul_fwd_params.dst_dims},
                                            MklDnnType<Toutput>(),
-                                           memory::format_tag::any));
+                                           matmul_fwd_params.dst_format));
 
     context_.bias_md.reset(new memory::desc({matmul_fwd_params.bias_dims},
                                             MklDnnType<Tbias>(),
@@ -236,11 +238,17 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
           std::vector<float> scales;
           scales.push_back(post_op_param.param[0]);
           post_ops_attr.set_output_scales(0, scales);
+        } else if (post_op_param.name == "sum") {
+          DCHECK_EQ(post_op_param.param.size(), 1);
+          float op_scale = post_op_param.param[0];
+          post_ops.append_sum(op_scale);
+
         } else {
           DCHECK((post_op_param.name == "relu") ||
                  (post_op_param.name == "relu6") ||
                  (post_op_param.name == "elu") ||
                  (post_op_param.name == "tanh") ||
+                 (post_op_param.name == "sum") ||
                  (post_op_param.name == "output_scale"));
         }
       }
@@ -340,6 +348,10 @@ class MklDnnMatMulFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
         key_creator.AddAsKey(post_op_param.param[0]);
         key_creator.AddAsKey(post_op_param.param[1]);
         key_creator.AddAsKey(post_op_param.param[2]);
+      } else if (post_op_param.name == "sum") {
+        DCHECK_EQ(post_op_param.param.size(), 1);
+        key_creator.AddAsKey(post_op_param.name);
+        key_creator.AddAsKey(post_op_param.param[0]);
       } else if (post_op_param.name == "output_scale") {
         DCHECK_EQ(post_op_param.param.size(), 1);
         key_creator.AddAsKey(post_op_param.name);
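Keying the cached primitive on both the post-op name and its scale keeps a fused matmul with the sum post-op from colliding in the factory cache with an otherwise-identical primitive that lacks it, or one that uses a different scale.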