diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index ba716fc859d..7dc62fe54c1 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -2231,6 +2231,7 @@ class FoldTransposeIntoMatMul : public ArithmeticOptimizerStage { if (a_is_foldable) deps_to_forward.push_back(a); if (b_is_foldable) deps_to_forward.push_back(b); ForwardControlDependencies(new_op, deps_to_forward); + *simplified_node_name = new_op->name(); return Status::OK(); } diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index 532af54ff24..6a9d7557620 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -805,15 +805,18 @@ TEST_F(ArithmeticOptimizerTest, FoldTransposeIntoMatMul) { Output trans_a = ops::Transpose(s.WithOpName("trans_a"), a, perm); Output trans_b = ops::Transpose(s.WithOpName("trans_b"), b, perm); + Output matmul; auto matmul_op = s.WithOpName("matmul"); if (matmul_type == "MatMul") { - Output matmul = ops::MatMul(matmul_op, trans_a, trans_b); + matmul = ops::MatMul(matmul_op, trans_a, trans_b); } else if (matmul_type == "SparseMatMul") { - Output matmul = ops::SparseMatMul(matmul_op, trans_a, trans_b); + matmul = ops::SparseMatMul(matmul_op, trans_a, trans_b); } else if (matmul_type == "BatchMatMul") { - Output matmul = ops::BatchMatMul(matmul_op, trans_a, trans_b); + matmul = ops::BatchMatMul(matmul_op, trans_a, trans_b); } + auto identity = ops::Identity(s.WithOpName("identity"), matmul); + GrapplerItem item; item.fetch = {"matmul"}; TF_CHECK_OK(s.ToGraphDef(&item.graph)); @@ -827,7 +830,7 @@ TEST_F(ArithmeticOptimizerTest, FoldTransposeIntoMatMul) { OptimizeTwice(&optimizer, &item, &output); NodeMap node_map(&output); - EXPECT_EQ(7, output.node_size()); + EXPECT_EQ(8, output.node_size()); const string p = "ArithmeticOptimizer/FoldTransposeIntoMatMul"; const string optimized_name = strings::StrCat(p, "_", "matmul"); @@ -845,6 +848,11 @@ TEST_F(ArithmeticOptimizerTest, FoldTransposeIntoMatMul) { EXPECT_TRUE(matmul_fused_node->attr().at("transpose_b").b()); } + const NodeDef* identity_node = node_map.GetNode("identity"); + ASSERT_NE(identity_node, nullptr); + ASSERT_EQ(identity_node->input_size(), 1); + EXPECT_EQ(identity_node->input(0), optimized_name); + auto tensors = EvaluateNodes(output, item.fetch); EXPECT_EQ(1, tensors.size()); test::ExpectTensorNear(tensors_expected[0], tensors[0], 1e-6);