Merge pull request from Intel-tensorflow:amin/grappler-pattern-matching

PiperOrigin-RevId: 350441156
Change-Id: I5dec478c1e752c2496272848c3920de65e743440
This commit is contained in:
TensorFlower Gardener 2021-01-06 15:51:08 -08:00
commit 5b82ec84a3
4 changed files with 853 additions and 0 deletions

View File

@ -392,6 +392,29 @@ tf_cc_test(
],
)
# Grammar-based subgraph pattern matcher over grappler's graph_view
# (SubGraphMatcher, OpTypePattern). Public so optimizers such as the
# remapper can depend on it.
cc_library(
    name = "pattern_utils",
    srcs = ["pattern_utils.cc"],
    hdrs = ["pattern_utils.h"],
    visibility = ["//visibility:public"],
    deps = [
        ":graph_view",
    ],
)
# Unit tests for :pattern_utils (tree/DAG matching, external-dependent
# rejection, and graph mutation after a match).
tf_cc_test(
    name = "pattern_utils_test",
    srcs = ["pattern_utils_test.cc"],
    deps = [
        ":pattern_utils",
        "//tensorflow/cc:cc_ops",
        "//tensorflow/cc:cc_ops_internal",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core:testlib",
    ],
)
cc_library(
name = "transitive_fanin",
srcs = ["transitive_fanin.cc"],

View File

@ -0,0 +1,129 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/utils/pattern_utils.h"
namespace tensorflow {
namespace grappler {
namespace utils {
// A subgraph pattern syntax implicitly defines a DAG having a single root. We
// traverse the syntax DAG in DFS manner. This function finds a match for the
// current root of the pattern with the current node and recursively matches
// children subpatterns with the children of the current node.
//
// Returns true iff `node_view` (and, transitively, its regular fanins)
// satisfies `pattern`, including the label constraints described below.
// On success `match` holds a one-to-one mapping of pattern nodes to matched
// node views; on failure the partially filled `match` and the matcher's
// bookkeeping may contain intermediate state (cleared by GetMatchedNodes).
template <>
bool SubGraphMatcher<MatchingDirection::kFollowInputs>::DoesOpTypePatternMatch(
    const OpTypePattern& pattern, MutableNodeView* node_view,
    NodeViewMatch* match) {
  // Currently no control inputs and outputs are allowed.
  if (node_view->NumControllingFanins() > 0 ||
      node_view->NumControlledFanouts() > 0)
    return false;

  bool op_type_matched = false;
  if (pattern.op == "*") {
    // "*" is a wildcard that matches any op type.
    op_type_matched = true;
  } else {
    // The op field string of the current pattern might express an op among
    // multiple op types (mutually exclusive) separated by '|'.
    std::vector<string> op_list = str_util::Split(pattern.op, '|');
    for (const string& op : op_list) {
      if (node_view->node()->op() == op) {
        op_type_matched = true;
        break;
      }
    }
  }
  if (!op_type_matched) return false;

  // If the op type matches and the current node is visited the first time,
  // insert the current node into node_label_to_index_ with the current label
  // as the key. Multiple occurrences of the same label in the pattern syntax
  // indicate that the same node needs to be visited for each of such
  // occurrences. Hence subsequent visits should find the corresponding label
  // in the map as a key and the current node should be the value for that key.
  auto label_it = node_label_to_index_.find(pattern.label);
  if (label_it == node_label_to_index_.end()) {
    node_label_to_index_[pattern.label] = node_view->node_index();
    // Bookkeeping.
    matched_node_indices_.insert(node_view->node_index());
    if (pattern.node_status == NodeStatus::kRemove) {
      remove_node_indices_.insert(node_view->node_index());
    }
  } else if (label_it->second != node_view->node_index()) {
    return false;  // Label constraint could not be satisfied.
  }

  // Current root of the pattern syntax is matched with the current node.
  match->node_view = node_view;

  // Go for matching child subpatterns.
  if (!pattern.children.empty()) {
    // Currently only the direction toward inputs is implemented.
    auto node_view_children = node_view->GetRegularFanins();
    if (node_view_children.size() != pattern.children.size()) {
      return false;
    }
    const int num_children = static_cast<int>(pattern.children.size());
    for (int i = 0; i < num_children; ++i) {
      auto child_node_index = node_view_children[i].node_index();
      // GetNode is expected to return a valid pointer for an index obtained
      // from the graph view itself; still guard defensively against nullptr
      // (resolves the original TODO about this guarantee).
      MutableNodeView* child_node_view =
          graph_view_->GetNode(child_node_index);
      if (child_node_view == nullptr) return false;
      const OpTypePattern& child_pattern = pattern.children[i];
      match->children.push_back(NodeViewMatch());
      NodeViewMatch* child_match = &(match->children.back());
      if (!DoesOpTypePatternMatch(child_pattern, child_node_view,
                                  child_match)) {
        return false;
      }
    }
  }
  return true;
}
// Current implementation supports pattern matching toward a node's inputs
// only.
//
// Returns true iff `pattern` matches the subgraph rooted at `node_view` AND
// none of the matched nodes marked NodeStatus::kRemove has a fanout outside
// the matched subgraph. On success the label->node_index mapping is swapped
// into `matched_nodes_map` and the removal candidates into
// `remove_node_indices`. On failure all internal bookkeeping is cleared so a
// failed match cannot leak stale state into a subsequent call (the original
// code skipped this cleanup when the match was rejected only because of
// external dependents).
template <>
bool SubGraphMatcher<MatchingDirection::kFollowInputs>::GetMatchedNodes(
    const OpTypePattern& pattern, MutableNodeView* node_view,
    std::map<string, int>* matched_nodes_map,
    std::set<int>* remove_node_indices) {
  bool found_match = false;
  match_.reset(new NodeViewMatch());
  if (DoesOpTypePatternMatch(pattern, node_view, match_.get())) {
    if (!HasRemoveNodeExternalDependents()) {
      found_match = true;
      matched_nodes_map->swap(this->node_label_to_index_);
      remove_node_indices->swap(this->remove_node_indices_);
    }
  }
  if (!found_match) {
    // Clear all bookkeeping data, whether the structural match failed or the
    // match was rejected due to external dependents of a remove-candidate.
    match_->Clear();
    match_.reset(nullptr);
    node_label_to_index_.clear();
    matched_node_indices_.clear();
    remove_node_indices_.clear();
  }
  return found_match;
}
} // namespace utils
} // namespace grappler
} // namespace tensorflow

View File

@ -0,0 +1,227 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_PATTERN_HELPER_H_
#define TENSORFLOW_CORE_GRAPPLER_UTILS_PATTERN_HELPER_H_
#include "tensorflow/core/grappler/utils/graph_view.h"
namespace tensorflow {
namespace grappler {
namespace utils {
//------------------------------------------------------------------------------
// A pattern can be defined by the following grammar. Here, op_type is any valid
// op name in the TensorFlow.
//
// leaf_pattern ::= `{` op_type `}`
// pattern ::= leaf_pattern |
// `{` op_type `,` `{` pattern `,` ... `,` pattern `}` `}`
//
// (1) For example, the following pattern syntax describes a pattern for
// _FusedConv2D (Conv2D + BiasAdd + Relu). Note that "*" means any type of op.
//
// {"Relu",
// {
// "BiasAdd",
// {
// {"Conv2D"},
// {"*"}
// }
// }
// }
//
// The syntax above has a root ("Relu") and children (inputs), where each child
// is a sub-pattern. Graph pattern matcher finds a match for the given pattern
// syntax in a graph and returns a set of matched nodes.
//
// (2) In order to match a DAG with a given root, we extend pattern syntax with
// labels. For example, a frequently found pattern in Deep Learning models is a
// residual block like below.
//
// Placeholder Const
// | |
// +-----+-----+ |
// | | |
// | v v
// | Conv2D Const
// | | |
// | v v-----+
// | BiasAdd
// | |
// v v----------+
// AddV2
//
// As shown above, it is the same input node (Placeholder) consumed by both
// AddV2 and and Conv2D. This constrained can be put as labels in the following
// augmented pattern syntax.
//
// {"AddV2", "my_add",
// {
// {"*", "my_residual_input"},
// {"BiasAdd", "my_bias_add",
// {
// {"Conv2D", "my_conv",
// {
// {"*", "my_residual_input"},
// {"*", "my_filter"}
// }
// },
// {"*", my_bias"}
// }
// }
// }
// }
//
// Note that the same label "my_residual_input" is used to tell that it is a
// child of both "AddV2" and "Conv2D". Labels are arbitrary strings to associate
// with the nodes to be matched as well as to uniquely identify those nodes.
//
// (3) The motivation for grammar-based pattern matching in grappler is to
// make it easy to find fusion patterns in the remapper. A subgraph that
// matches a given pattern, however, is not fusable if any of the matched node,
// that will be removed as a part of fusion, has a consumer outside the matched
// subgraph. In order to check for such type of external dependencies, we
// further extend pattern syntax by prospective action (NodeStatus) on the
// matched nodes as shown below. This helps cross checking the nodes to be
// removed with the nodes matched initially.
//
// {"AddV2", "my_add", NodeStatus::kReplace,
// {
// {"*", "my_residual_input", NodeStatus::kRemain},
// {"BiasAdd", "my_bias_add", NodeStatus::kRemove,
// {
// {"Conv2D", "my_conv", NodeStatus::kRemove,
// {
// {"*", "my_residual_input", NodeStatus::kRemain},
// {"*", "my_filter", NodeStatus::Remain}
// }
// },
// {"*", my_bias", NodeStatus::kRemain}
// }
// }
// }
// }
//------------------------------------------------------------------------------
// Pattern matcher recursively matches child subpatterns. The direction
// for children could be toward node's input (fanins) or outputs (fanouts).
// NOTE: only kFollowInputs is implemented in pattern_utils.cc.
enum class MatchingDirection { kFollowInputs, kFollowOutputs };
// Action for each node in the set of matched nodes for a given pattern.
// The matcher itself only tracks kRemove nodes (for the external-dependents
// check); kRemain/kReplace are presumably acted on by the caller — confirm
// against remapper usage.
enum class NodeStatus { kRemain, kRemove, kReplace };
// TODO (intel-tf): Support multiple roots by making them children of a single
// virtual root.
// One node of a pattern-syntax DAG: an op-type specifier (possibly
// '|'-separated alternatives or the wildcard "*"), a label that uniquely
// identifies the matched node (repeated labels force the same node to be
// matched), the prospective action on the matched node, and the subpatterns
// for the node's inputs.
struct OpTypePattern {
  string op;
  string label;
  NodeStatus node_status;
  std::vector<OpTypePattern> children;

  // Renders the pattern (op, label, and children, recursively) for debugging.
  string DebugString() const {
    string out;
    out += "{(op: ";
    out += op;
    out += ", label: ";
    out += label;
    out += "), {";
    for (std::size_t i = 0; i < children.size(); ++i) {
      out += children[i].DebugString();
      out += ",";
    }
    out += "}}";
    return out;
  }
};
// A recursive structure that keeps a one-to-one mapping of pattern syntax to
// the matched nodes. Call DebugString to see what has been matched so far and
// where the failing point is. node_view is a non-owning pointer into the
// graph view.
struct NodeViewMatch {
  MutableNodeView* node_view = nullptr;
  std::vector<NodeViewMatch> children;

  // Renders the matched node (or "Non-Matched-Node" when nothing was bound)
  // and its children, recursively.
  string DebugString() const {
    string out = "{";
    if (node_view == nullptr) {
      out += "Non-Matched-Node}";
      return out;
    }
    out += node_view->node()->DebugString();
    out += ", {";
    for (const NodeViewMatch& child : children) {
      out += child.DebugString();
      out += ",";
    }
    out += "}}";
    return out;
  }

  // Recursively resets the match tree to its empty state.
  void Clear() {
    for (NodeViewMatch& child : children) {
      child.Clear();
    }
    children.clear();
    node_view = nullptr;
  }
};
// Matches an OpTypePattern against a subgraph of the given graph view.
// DIRECTION selects whether child subpatterns follow fanins or fanouts;
// only kFollowInputs is currently implemented (see pattern_utils.cc).
template <MatchingDirection DIRECTION = MatchingDirection::kFollowInputs>
class SubGraphMatcher {
 public:
  // Does not take ownership of graph_view; it must outlive this matcher.
  explicit SubGraphMatcher(MutableGraphView* graph_view)
      : graph_view_(graph_view) {}

  // If the given pattern is matched, this function returns true and populates
  // the matched-node and remove-node info; otherwise returns false.
  bool GetMatchedNodes(const OpTypePattern& pattern, MutableNodeView* node_view,
                       std::map<string, int>* matched_nodes_map,
                       std::set<int>* remove_node_indices);

 private:
  MutableGraphView* graph_view_;  // Not owned.
  // Pattern label -> index of the node matched under that label.
  std::map<string, int> node_label_to_index_;
  // Indices of all nodes matched so far.
  std::set<int> matched_node_indices_;
  // Indices of matched nodes whose pattern status is NodeStatus::kRemove.
  std::set<int> remove_node_indices_;
  std::unique_ptr<NodeViewMatch> match_ = nullptr;

  bool DoesOpTypePatternMatch(const OpTypePattern& pattern,
                              MutableNodeView* node_view, NodeViewMatch* match);

  // This function should be called after the pattern matcher has found
  // potential matched nodes (i.e. when DoesOpTypePatternMatch returns "true").
  // It performs a sanity check that the candidate nodes for removal in
  // subgraph fusion are indeed safe to remove: returns true iff some
  // remove-candidate has a regular fanout outside the matched subgraph.
  bool HasRemoveNodeExternalDependents() {
    for (const auto& node_idx : remove_node_indices_) {
      auto node_view = graph_view_->GetNode(node_idx);
      // Traverse all the regular fanouts. Fanouts are stored per output port
      // as std::vector<std::vector<MutableFaninView>>.
      auto fanouts_by_ports = node_view->GetRegularFanouts();
      for (const auto& fanouts : fanouts_by_ports) {
        for (const auto& fanout : fanouts) {
          if (!matched_node_indices_.count(fanout.node_index())) {
            return true;
          }
        }
      }
    }
    return false;
  }
};
} // namespace utils
} // namespace grappler
} // namespace tensorflow
#endif // TENSORFLOW_CORE_GRAPPLER_UTILS_PATTERN_HELPER_H_

View File

@ -0,0 +1,474 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/utils/pattern_utils.h"
#include "tensorflow/cc/ops/nn_ops_internal.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
namespace tensorflow {
namespace grappler {
namespace utils {
namespace {
using ::tensorflow::ops::Placeholder;
// Builds a graph computing MatMul -> BiasAdd followed by GELU expressed with
// small ops: gelu(x) = (0.5 * (erf(x * 0.707) + 1)) * x, where
// x = bias_add and 0.707 ~= 1/sqrt(2). When add_external_dependent is true,
// an extra Identity consumer of bias_add is added so that bias_add has a
// fanout outside the GELU subgraph (used to test match rejection).
void GetMatMulBiasAddGeluGraph(GraphDef* graph,
                               bool add_external_dependent = false) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  auto input_shape = ops::Placeholder::Shape({8, 32});
  auto weight_shape = ops::Placeholder::Shape({32, 64});
  auto bias_shape = ops::Placeholder::Shape({64});

  auto input = Placeholder(s.WithOpName("input"), DT_FLOAT, input_shape);
  auto weight = Placeholder(s.WithOpName("weight"), DT_FLOAT, weight_shape);
  auto bias = Placeholder(s.WithOpName("bias"), DT_FLOAT, bias_shape);

  auto matmul = ops::MatMul(s.WithOpName("matmul"), input, weight);
  auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), matmul, bias);
  if (add_external_dependent) {
    // The local variable is unused; constructing the op registers it in the
    // scope's graph.
    auto external_dependent =
        ops::Identity(s.WithOpName("external_dependent"), bias_add);
  }
  // Gelu with smaller ops.
  auto one_over_square_root_two =
      ops::Const(s.WithOpName("one_over_square_root_two"), {0.707f}, {});
  auto bias_add_times_const = ops::Mul(s.WithOpName("bias_add_times_const"),
                                       bias_add, one_over_square_root_two);
  auto erf = ops::Erf(s.WithOpName("erf"), bias_add_times_const);
  auto one = ops::Const(s.WithOpName("one"), {1.0f}, {});
  auto erf_plus_one = ops::AddV2(s.WithOpName("erf_plus_one"), erf, one);
  auto one_half = ops::Const(s.WithOpName("one_half"), {0.5f}, {});
  auto one_half_times_erf_plus_one = ops::Mul(
      s.WithOpName("one_half_times_erf_plus_one"), one_half, erf_plus_one);
  auto gelu =
      ops::Mul(s.WithOpName("gelu"), one_half_times_erf_plus_one, bias_add);
  auto fetch = ops::Identity(s.WithOpName("fetch"), gelu);

  TF_ASSERT_OK(s.ToGraphDef(graph));
}
// Returns the pattern syntax matching the small-op GELU subgraph produced by
// GetMatMulBiasAddGeluGraph, rooted at the final Mul ("gelu"). Note that
// "my_bias_add" appears twice: the same BiasAdd node feeds both the erf chain
// and the final Mul.
OpTypePattern GetMatMulBiasAddGeluPattern() {
  // Although labels are arbitrary, for the convenience of check they are
  // prefixed with "my_" to the original node names in the global graph.
  // clang-format off
  OpTypePattern pattern_syntax{"Mul", "my_gelu", NodeStatus::kReplace,
    {
      {"Mul", "my_one_half_times_erf_plus_one", NodeStatus::kRemove,
        {
          {"Const", "my_one_half", NodeStatus::kRemain},
          {"AddV2", "my_erf_plus_one", NodeStatus::kRemove,
            {
              {"Erf", "my_erf", NodeStatus::kRemove,
                {
                  {"Mul", "my_bias_add_times_const", NodeStatus::kRemove,
                    {
                      {"BiasAdd", "my_bias_add", NodeStatus::kRemove},
                      {"Const", "my_one_over_square_root_two", NodeStatus::kRemain}
                    }
                  }
                }
              },
              {"Const", "my_one", NodeStatus::kRemain}
            }
          }
        }
      },
      {"BiasAdd", "my_bias_add", NodeStatus::kRemove,
        {
          {"MatMul", "my_matmul", NodeStatus::kRemove},
          {"*", "my_bias", NodeStatus::kRemain}
        }
      }
    }
  };  // clang-format on

  return pattern_syntax;
}
class PatternMatcherTest : public ::testing::Test {
 protected:
  // Lightweight description of one node: its name, op type, and the names of
  // its input nodes.
  struct NodeConfig {
    NodeConfig(string name, string op, std::vector<string> inputs)
        : name(std::move(name)), op(std::move(op)), inputs(std::move(inputs)) {}

    string name;
    string op;
    std::vector<string> inputs;
  };

  // Builds a GraphDef with one NodeDef per NodeConfig, wired up by input
  // names. No attributes, devices, or shapes are set — sufficient for
  // structural pattern-matching tests.
  static GraphDef CreateGraph(const std::vector<NodeConfig>& nodes) {
    GraphDef graph;

    for (const NodeConfig& node : nodes) {
      NodeDef node_def;
      node_def.set_name(node.name);
      node_def.set_op(node.op);
      for (const string& input : node.inputs) {
        node_def.add_input(input);
      }
      *graph.add_node() = std::move(node_def);
    }

    return graph;
  }
};
TEST_F(PatternMatcherTest, Tree) {
  // A data flow graph. Data flows from top to bottom. Here A, B, C, D, and E
  // are ops.
  //
  //   Input graph          Subgraph for pattern matcher
  //
  //       A                    C   D
  //       |                     \ /
  //       B                      E
  //      /
  //     C   D
  //      \ /
  //       E
  //
  // E is the root of the pattern syntax, shown below, that the pattern
  // matcher would match.
  //  {"E", "my_e", NodeStatus::kReplace,
  //    {
  //      {"C", "my_c", NodeStatus::kRemove},
  //      {"D", "my_d", NodeStatus::kRemove}
  //    }
  //  }
  ::tensorflow::Status status;
  GraphDef graph = CreateGraph({{"e", "E", {"c", "d"}},
                                {"c", "C", {"b"}},
                                {"d", "D", {}},
                                {"b", "B", {"a"}},
                                {"a", "A", {}}});
  // clang-format off
  OpTypePattern pattern{"E", "my_e", NodeStatus::kReplace,
    {
      {"C", "my_c", NodeStatus::kRemove},
      {"D", "my_d", NodeStatus::kRemove}
    }
  };  // clang-format on

  MutableGraphView graph_view(&graph, &status);
  TF_ASSERT_OK(status);
  TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
  auto root_node_view = graph_view.GetNode("e");

  SubGraphMatcher<MatchingDirection::kFollowInputs> graph_matcher(&graph_view);
  std::map<string, int> matched_nodes_map;  // label to node index map
  std::set<int> remove_node_indices;
  bool found_match = graph_matcher.GetMatchedNodes(
      pattern, root_node_view, &matched_nodes_map, &remove_node_indices);

  EXPECT_TRUE(found_match);
  EXPECT_FALSE(matched_nodes_map.empty());
  EXPECT_FALSE(remove_node_indices.empty());

  // Verify that each matched label ("my_<name>") maps to the node index of
  // "<name>" in the graph. (Bug fix: the original loop compared the iterator
  // against begin() instead of end(), so it never executed.)
  bool all_indices_matched = true;
  for (const auto& matched_node : matched_nodes_map) {
    auto label = str_util::StripPrefix(matched_node.first, "my_");
    int matched_node_idx = matched_node.second;
    int expected_node_idx = graph_view.GetNode(label)->node_index();
    if (matched_node_idx != expected_node_idx) {
      all_indices_matched = false;
      break;
    }
  }
  EXPECT_TRUE(all_indices_matched);
}
TEST_F(PatternMatcherTest, DAG) {
  // A data flow graph. Data flows from top to bottom. Here A, B, C, D, and E
  // are ops.
  //
  //   Input graph          Subgraph for pattern matcher
  //
  //       A
  //       |                      B
  //       B                     / \
  //      / \                   C   D
  //     C   D                   \ /
  //      \ /                     E
  //       E
  //
  // E is the root of the pattern syntax, shown below, that the pattern
  // matcher would match. Note B carries the same label under both C and D,
  // constraining both to the same node.
  //  {"E", "my_e", NodeStatus::kReplace,
  //    {
  //      {"C", "my_c", NodeStatus::kRemove,
  //        {
  //          {"B", "my_b", NodeStatus::kRemove}
  //        }
  //      },
  //      {"D", "my_d", NodeStatus::kRemove,
  //        {
  //          {"B", "my_b", NodeStatus::kRemove}
  //        }
  //      }
  //    }
  //  }
  ::tensorflow::Status status;
  GraphDef graph = CreateGraph({{"e", "E", {"c", "d"}},
                                {"c", "C", {"b"}},
                                {"d", "D", {"b"}},
                                {"b", "B", {"a"}},
                                {"a", "A", {}}});
  // clang-format off
  OpTypePattern pattern{"E", "my_e", NodeStatus::kReplace,
    {
      {"C", "my_c", NodeStatus::kRemove,
        {
          {"B", "my_b", NodeStatus::kRemove}
        }
      },
      {"D", "my_d", NodeStatus::kRemove,
        {
          {"B", "my_b", NodeStatus::kRemove}
        }
      }
    }
  };  // clang-format on

  MutableGraphView graph_view(&graph, &status);
  TF_ASSERT_OK(status);
  TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
  auto root_node_view = graph_view.GetNode("e");

  SubGraphMatcher<MatchingDirection::kFollowInputs> graph_matcher(&graph_view);
  std::map<string, int> matched_nodes_map;  // label to node index map
  std::set<int> remove_node_indices;
  bool found_match = graph_matcher.GetMatchedNodes(
      pattern, root_node_view, &matched_nodes_map, &remove_node_indices);

  EXPECT_TRUE(found_match);
  EXPECT_FALSE(matched_nodes_map.empty());
  EXPECT_FALSE(remove_node_indices.empty());

  // Verify that each matched label ("my_<name>") maps to the node index of
  // "<name>" in the graph. (Bug fix: the original loop compared the iterator
  // against begin() instead of end(), so it never executed.)
  bool all_indices_matched = true;
  for (const auto& matched_node : matched_nodes_map) {
    auto label = str_util::StripPrefix(matched_node.first, "my_");
    int matched_node_idx = matched_node.second;
    int expected_node_idx = graph_view.GetNode(label)->node_index();
    if (matched_node_idx != expected_node_idx) {
      all_indices_matched = false;
      break;
    }
  }
  EXPECT_TRUE(all_indices_matched);
}
// Pattern should not be matched if any of the candidate remove nodes has an
// external dependent.
TEST_F(PatternMatcherTest, DAGExternalDependent) {
  // A data flow graph. Data flows from top to bottom. Here A, B, C, D, E, and
  // F are ops.
  //
  //   Input graph          Subgraph for pattern matcher
  //
  //       A
  //       |                      B
  //       B                     / \
  //      / \                   C   D
  //     C   D                   \ /
  //      \ / \                   E
  //       E   F
  //
  // E is the root of the pattern syntax, shown below, that the pattern
  // matcher would match. Note D is a candidate for a remove node as mentioned
  // in the syntax, but F (outside the subgraph) consumes D. So the pattern
  // matcher should not find a match.
  //  {"E", "my_e", NodeStatus::kReplace,
  //    {
  //      {"C", "my_c", NodeStatus::kRemove,
  //        {
  //          {"B", "my_b", NodeStatus::kRemove}
  //        }
  //      },
  //      {"D", "my_d", NodeStatus::kRemove,
  //        {
  //          {"B", "my_b", NodeStatus::kRemove}
  //        }
  //      }
  //    }
  //  }
  ::tensorflow::Status status;
  GraphDef graph = CreateGraph({{"f", "F", {"d"}},
                                {"e", "E", {"c", "d"}},
                                {"c", "C", {"b"}},
                                {"d", "D", {"b"}},
                                {"b", "B", {"a"}},
                                {"a", "A", {}}});
  // clang-format off
  OpTypePattern pattern{"E", "my_e", NodeStatus::kReplace,
    {
      {"C", "my_c", NodeStatus::kRemove,
        {
          {"B", "my_b", NodeStatus::kRemove}
        }
      },
      {"D", "my_d", NodeStatus::kRemove,
        {
          {"B", "my_b", NodeStatus::kRemove}
        }
      }
    }
  };  // clang-format on

  MutableGraphView graph_view(&graph, &status);
  TF_ASSERT_OK(status);
  TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
  auto root_node_view = graph_view.GetNode("e");

  SubGraphMatcher<MatchingDirection::kFollowInputs> graph_matcher(&graph_view);
  std::map<string, int> matched_nodes_map;  // label to node index map
  std::set<int> remove_node_indices;
  bool found_match = graph_matcher.GetMatchedNodes(
      pattern, root_node_view, &matched_nodes_map, &remove_node_indices);

  // No match, and the output containers must remain empty.
  EXPECT_FALSE(found_match);
  EXPECT_TRUE(matched_nodes_map.empty());
  EXPECT_TRUE(remove_node_indices.empty());
}
// Matches the full small-op GELU pattern against the MatMul+BiasAdd+Gelu
// graph and verifies every label is bound to the expected node.
TEST_F(PatternMatcherTest, MatMulBiasAddGelu) {
  ::tensorflow::Status status;
  GraphDef graph;
  GetMatMulBiasAddGeluGraph(&graph);
  OpTypePattern pattern = GetMatMulBiasAddGeluPattern();

  MutableGraphView graph_view(&graph, &status);
  TF_ASSERT_OK(status);
  TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
  auto root_node_view = graph_view.GetNode("gelu");

  SubGraphMatcher<MatchingDirection::kFollowInputs> graph_matcher(&graph_view);
  std::map<string, int> matched_nodes_map;  // label to node index map
  std::set<int> remove_node_indices;
  bool found_match = graph_matcher.GetMatchedNodes(
      pattern, root_node_view, &matched_nodes_map, &remove_node_indices);

  EXPECT_TRUE(found_match);
  EXPECT_FALSE(matched_nodes_map.empty());
  EXPECT_FALSE(remove_node_indices.empty());

  // Verify that each matched label ("my_<name>") maps to the node index of
  // "<name>" in the graph. (Bug fix: the original loop compared the iterator
  // against begin() instead of end(), so it never executed.)
  bool all_indices_matched = true;
  for (const auto& matched_node : matched_nodes_map) {
    auto label = str_util::StripPrefix(matched_node.first, "my_");
    int matched_node_idx = matched_node.second;
    int expected_node_idx = graph_view.GetNode(label)->node_index();
    if (matched_node_idx != expected_node_idx) {
      all_indices_matched = false;
      break;
    }
  }
  EXPECT_TRUE(all_indices_matched);
}
// Pattern should not be matched if any of the candidate remove nodes has an
// external dependent. Here the extra Identity consumer of "bias_add" (a
// kRemove candidate) must cause the match to be rejected.
TEST_F(PatternMatcherTest, MatMulBiasAddGeluExternalDependent) {
  ::tensorflow::Status status;
  GraphDef graph;
  GetMatMulBiasAddGeluGraph(&graph, /*add_external_dependent=*/true);
  OpTypePattern pattern = GetMatMulBiasAddGeluPattern();

  MutableGraphView graph_view(&graph, &status);
  TF_ASSERT_OK(status);
  TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
  auto root_node_view = graph_view.GetNode("gelu");

  SubGraphMatcher<MatchingDirection::kFollowInputs> graph_matcher(&graph_view);
  std::map<string, int> matched_nodes_map;  // label to node index map
  std::set<int> remove_node_indices;
  bool found_match = graph_matcher.GetMatchedNodes(
      pattern, root_node_view, &matched_nodes_map, &remove_node_indices);

  // No match, and the output containers must remain empty.
  EXPECT_FALSE(found_match);
  EXPECT_TRUE(matched_nodes_map.empty());
  EXPECT_TRUE(remove_node_indices.empty());
}
// End-to-end fusion flow: match the GELU pattern, then mutate the graph by
// replacing the root with a fused op and deleting the kRemove candidates.
TEST_F(PatternMatcherTest, MatMulBiasAddGeluMutation) {
  ::tensorflow::Status status;
  GraphDef graph;
  GetMatMulBiasAddGeluGraph(&graph);
  OpTypePattern pattern = GetMatMulBiasAddGeluPattern();

  MutableGraphView graph_view(&graph, &status);
  TF_ASSERT_OK(status);
  TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
  auto root_node_view = graph_view.GetNode("gelu");

  SubGraphMatcher<MatchingDirection::kFollowInputs> graph_matcher(&graph_view);
  std::map<string, int> matched_nodes_map;  // label to node index map
  std::set<int> remove_node_indices;
  bool found_match = graph_matcher.GetMatchedNodes(
      pattern, root_node_view, &matched_nodes_map, &remove_node_indices);

  EXPECT_TRUE(found_match);
  EXPECT_FALSE(matched_nodes_map.empty());
  EXPECT_FALSE(remove_node_indices.empty());

  // Number of nodes before mutation.
  int num_nodes_before = graph_view.NumNodes();
  // Names of the remove-candidate nodes, captured before mutation since the
  // indices are invalidated by Apply().
  std::vector<string> remove_node_names;
  for (auto const& node_idx : remove_node_indices) {
    remove_node_names.push_back(graph_view.GetNode(node_idx)->GetName());
  }

  Mutation* mutation = graph_view.GetMutationBuilder();
  // Replace the root with the fused op. Reusing the name "gelu" replaces the
  // existing node, so the node count is unchanged by this step.
  NodeDef fused_node;
  fused_node.set_name("gelu");
  fused_node.set_op("_FusedMatMul");
  fused_node.add_input(graph_view.GetNode("matmul")->node()->input(0));
  fused_node.add_input(graph_view.GetNode("matmul")->node()->input(1));
  fused_node.add_input(graph_view.GetNode("bias_add")->node()->input(1));
  mutation->AddNode(std::move(fused_node), &status);
  TF_ASSERT_OK(status);
  TF_EXPECT_OK(mutation->Apply());
  // Remove nodes that are marked as NodeStatus::kRemove.
  for (auto const& node_idx : remove_node_indices) {
    mutation->RemoveNode(graph_view.GetNode(node_idx));
  }
  TF_EXPECT_OK(mutation->Apply());

  // Number of nodes after mutation: only the removals change the count.
  int num_nodes_after = graph_view.NumNodes();
  EXPECT_EQ(num_nodes_before - remove_node_indices.size(), num_nodes_after);

  // All remove candidates must be gone...
  bool remove_nodes_deleted = true;
  for (auto const& node_name : remove_node_names) {
    if (graph_view.GetNode(node_name) != nullptr) {
      remove_nodes_deleted = false;
      break;
    }
  }
  EXPECT_TRUE(remove_nodes_deleted);

  // ...while the replacement node must still exist. (Simplified from the
  // redundant `HasNode(...) ? true : false`.)
  bool replace_node_exist = graph_view.HasNode("gelu");
  EXPECT_TRUE(replace_node_exist);
}
} // namespace
} // namespace utils
} // namespace grappler
} // namespace tensorflow