[Grappler] Set simplified node name for new node created when fusing transposes into MatMul.
PiperOrigin-RevId: 242541256
This commit is contained in:
parent
401ab92175
commit
73ae2a64c7
@ -2231,6 +2231,7 @@ class FoldTransposeIntoMatMul : public ArithmeticOptimizerStage {
|
|||||||
if (a_is_foldable) deps_to_forward.push_back(a);
|
if (a_is_foldable) deps_to_forward.push_back(a);
|
||||||
if (b_is_foldable) deps_to_forward.push_back(b);
|
if (b_is_foldable) deps_to_forward.push_back(b);
|
||||||
ForwardControlDependencies(new_op, deps_to_forward);
|
ForwardControlDependencies(new_op, deps_to_forward);
|
||||||
|
*simplified_node_name = new_op->name();
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -805,15 +805,18 @@ TEST_F(ArithmeticOptimizerTest, FoldTransposeIntoMatMul) {
|
|||||||
Output trans_a = ops::Transpose(s.WithOpName("trans_a"), a, perm);
|
Output trans_a = ops::Transpose(s.WithOpName("trans_a"), a, perm);
|
||||||
Output trans_b = ops::Transpose(s.WithOpName("trans_b"), b, perm);
|
Output trans_b = ops::Transpose(s.WithOpName("trans_b"), b, perm);
|
||||||
|
|
||||||
|
Output matmul;
|
||||||
auto matmul_op = s.WithOpName("matmul");
|
auto matmul_op = s.WithOpName("matmul");
|
||||||
if (matmul_type == "MatMul") {
|
if (matmul_type == "MatMul") {
|
||||||
Output matmul = ops::MatMul(matmul_op, trans_a, trans_b);
|
matmul = ops::MatMul(matmul_op, trans_a, trans_b);
|
||||||
} else if (matmul_type == "SparseMatMul") {
|
} else if (matmul_type == "SparseMatMul") {
|
||||||
Output matmul = ops::SparseMatMul(matmul_op, trans_a, trans_b);
|
matmul = ops::SparseMatMul(matmul_op, trans_a, trans_b);
|
||||||
} else if (matmul_type == "BatchMatMul") {
|
} else if (matmul_type == "BatchMatMul") {
|
||||||
Output matmul = ops::BatchMatMul(matmul_op, trans_a, trans_b);
|
matmul = ops::BatchMatMul(matmul_op, trans_a, trans_b);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto identity = ops::Identity(s.WithOpName("identity"), matmul);
|
||||||
|
|
||||||
GrapplerItem item;
|
GrapplerItem item;
|
||||||
item.fetch = {"matmul"};
|
item.fetch = {"matmul"};
|
||||||
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||||
@ -827,7 +830,7 @@ TEST_F(ArithmeticOptimizerTest, FoldTransposeIntoMatMul) {
|
|||||||
OptimizeTwice(&optimizer, &item, &output);
|
OptimizeTwice(&optimizer, &item, &output);
|
||||||
NodeMap node_map(&output);
|
NodeMap node_map(&output);
|
||||||
|
|
||||||
EXPECT_EQ(7, output.node_size());
|
EXPECT_EQ(8, output.node_size());
|
||||||
|
|
||||||
const string p = "ArithmeticOptimizer/FoldTransposeIntoMatMul";
|
const string p = "ArithmeticOptimizer/FoldTransposeIntoMatMul";
|
||||||
const string optimized_name = strings::StrCat(p, "_", "matmul");
|
const string optimized_name = strings::StrCat(p, "_", "matmul");
|
||||||
@ -845,6 +848,11 @@ TEST_F(ArithmeticOptimizerTest, FoldTransposeIntoMatMul) {
|
|||||||
EXPECT_TRUE(matmul_fused_node->attr().at("transpose_b").b());
|
EXPECT_TRUE(matmul_fused_node->attr().at("transpose_b").b());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const NodeDef* identity_node = node_map.GetNode("identity");
|
||||||
|
ASSERT_NE(identity_node, nullptr);
|
||||||
|
ASSERT_EQ(identity_node->input_size(), 1);
|
||||||
|
EXPECT_EQ(identity_node->input(0), optimized_name);
|
||||||
|
|
||||||
auto tensors = EvaluateNodes(output, item.fetch);
|
auto tensors = EvaluateNodes(output, item.fetch);
|
||||||
EXPECT_EQ(1, tensors.size());
|
EXPECT_EQ(1, tensors.size());
|
||||||
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
|
test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
|
||||||
|
Loading…
Reference in New Issue
Block a user