[Grappler] Set simplified node name for new node created when fusing transposes into MatMul.

PiperOrigin-RevId: 242541256
This commit is contained in:
Andy Ly 2019-04-08 14:37:54 -07:00 committed by TensorFlower Gardener
parent 401ab92175
commit 73ae2a64c7
2 changed files with 13 additions and 4 deletions

View File

@ -2231,6 +2231,7 @@ class FoldTransposeIntoMatMul : public ArithmeticOptimizerStage {
if (a_is_foldable) deps_to_forward.push_back(a);
if (b_is_foldable) deps_to_forward.push_back(b);
ForwardControlDependencies(new_op, deps_to_forward);
*simplified_node_name = new_op->name();
return Status::OK();
}

View File

@ -805,15 +805,18 @@ TEST_F(ArithmeticOptimizerTest, FoldTransposeIntoMatMul) {
Output trans_a = ops::Transpose(s.WithOpName("trans_a"), a, perm);
Output trans_b = ops::Transpose(s.WithOpName("trans_b"), b, perm);
Output matmul;
auto matmul_op = s.WithOpName("matmul");
if (matmul_type == "MatMul") {
Output matmul = ops::MatMul(matmul_op, trans_a, trans_b);
matmul = ops::MatMul(matmul_op, trans_a, trans_b);
} else if (matmul_type == "SparseMatMul") {
Output matmul = ops::SparseMatMul(matmul_op, trans_a, trans_b);
matmul = ops::SparseMatMul(matmul_op, trans_a, trans_b);
} else if (matmul_type == "BatchMatMul") {
Output matmul = ops::BatchMatMul(matmul_op, trans_a, trans_b);
matmul = ops::BatchMatMul(matmul_op, trans_a, trans_b);
}
auto identity = ops::Identity(s.WithOpName("identity"), matmul);
GrapplerItem item;
item.fetch = {"matmul"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
@ -827,7 +830,7 @@ TEST_F(ArithmeticOptimizerTest, FoldTransposeIntoMatMul) {
OptimizeTwice(&optimizer, &item, &output);
NodeMap node_map(&output);
EXPECT_EQ(7, output.node_size());
EXPECT_EQ(8, output.node_size());
const string p = "ArithmeticOptimizer/FoldTransposeIntoMatMul";
const string optimized_name = strings::StrCat(p, "_", "matmul");
@ -845,6 +848,11 @@ TEST_F(ArithmeticOptimizerTest, FoldTransposeIntoMatMul) {
EXPECT_TRUE(matmul_fused_node->attr().at("transpose_b").b());
}
const NodeDef* identity_node = node_map.GetNode("identity");
ASSERT_NE(identity_node, nullptr);
ASSERT_EQ(identity_node->input_size(), 1);
EXPECT_EQ(identity_node->input(0), optimized_name);
auto tensors = EvaluateNodes(output, item.fetch);
EXPECT_EQ(1, tensors.size());
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);