[Grappler] Set simplified node name for new node created when fusing transposes into MatMul.

PiperOrigin-RevId: 242541256
This commit is contained in:
Andy Ly 2019-04-08 14:37:54 -07:00 committed by TensorFlower Gardener
parent 401ab92175
commit 73ae2a64c7
2 changed files with 13 additions and 4 deletions

View File

@ -2231,6 +2231,7 @@ class FoldTransposeIntoMatMul : public ArithmeticOptimizerStage {
if (a_is_foldable) deps_to_forward.push_back(a); if (a_is_foldable) deps_to_forward.push_back(a);
if (b_is_foldable) deps_to_forward.push_back(b); if (b_is_foldable) deps_to_forward.push_back(b);
ForwardControlDependencies(new_op, deps_to_forward); ForwardControlDependencies(new_op, deps_to_forward);
*simplified_node_name = new_op->name();
return Status::OK(); return Status::OK();
} }

View File

@ -805,15 +805,18 @@ TEST_F(ArithmeticOptimizerTest, FoldTransposeIntoMatMul) {
Output trans_a = ops::Transpose(s.WithOpName("trans_a"), a, perm); Output trans_a = ops::Transpose(s.WithOpName("trans_a"), a, perm);
Output trans_b = ops::Transpose(s.WithOpName("trans_b"), b, perm); Output trans_b = ops::Transpose(s.WithOpName("trans_b"), b, perm);
Output matmul;
auto matmul_op = s.WithOpName("matmul"); auto matmul_op = s.WithOpName("matmul");
if (matmul_type == "MatMul") { if (matmul_type == "MatMul") {
Output matmul = ops::MatMul(matmul_op, trans_a, trans_b); matmul = ops::MatMul(matmul_op, trans_a, trans_b);
} else if (matmul_type == "SparseMatMul") { } else if (matmul_type == "SparseMatMul") {
Output matmul = ops::SparseMatMul(matmul_op, trans_a, trans_b); matmul = ops::SparseMatMul(matmul_op, trans_a, trans_b);
} else if (matmul_type == "BatchMatMul") { } else if (matmul_type == "BatchMatMul") {
Output matmul = ops::BatchMatMul(matmul_op, trans_a, trans_b); matmul = ops::BatchMatMul(matmul_op, trans_a, trans_b);
} }
auto identity = ops::Identity(s.WithOpName("identity"), matmul);
GrapplerItem item; GrapplerItem item;
item.fetch = {"matmul"}; item.fetch = {"matmul"};
TF_CHECK_OK(s.ToGraphDef(&item.graph)); TF_CHECK_OK(s.ToGraphDef(&item.graph));
@ -827,7 +830,7 @@ TEST_F(ArithmeticOptimizerTest, FoldTransposeIntoMatMul) {
OptimizeTwice(&optimizer, &item, &output); OptimizeTwice(&optimizer, &item, &output);
NodeMap node_map(&output); NodeMap node_map(&output);
EXPECT_EQ(7, output.node_size()); EXPECT_EQ(8, output.node_size());
const string p = "ArithmeticOptimizer/FoldTransposeIntoMatMul"; const string p = "ArithmeticOptimizer/FoldTransposeIntoMatMul";
const string optimized_name = strings::StrCat(p, "_", "matmul"); const string optimized_name = strings::StrCat(p, "_", "matmul");
@ -845,6 +848,11 @@ TEST_F(ArithmeticOptimizerTest, FoldTransposeIntoMatMul) {
EXPECT_TRUE(matmul_fused_node->attr().at("transpose_b").b()); EXPECT_TRUE(matmul_fused_node->attr().at("transpose_b").b());
} }
const NodeDef* identity_node = node_map.GetNode("identity");
ASSERT_NE(identity_node, nullptr);
ASSERT_EQ(identity_node->input_size(), 1);
EXPECT_EQ(identity_node->input(0), optimized_name);
auto tensors = EvaluateNodes(output, item.fetch); auto tensors = EvaluateNodes(output, item.fetch);
EXPECT_EQ(1, tensors.size()); EXPECT_EQ(1, tensors.size());
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6); test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);