Merge pull request #45850 from Intel-tensorflow:amin/grappler-pattern-matching
PiperOrigin-RevId: 350441156 Change-Id: I5dec478c1e752c2496272848c3920de65e743440
This commit is contained in:
commit
5b82ec84a3
tensorflow/core/grappler/utils
@ -392,6 +392,29 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
# Library target for the Grappler subgraph pattern matcher
# (pattern_utils.h / pattern_utils.cc). Its only declared dependency is the
# sibling :graph_view target, whose header pattern_utils.h includes.
cc_library(
|
||||
name = "pattern_utils",
|
||||
srcs = ["pattern_utils.cc"],
|
||||
hdrs = ["pattern_utils.h"],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":graph_view",
|
||||
],
|
||||
)
|
||||
|
||||
# Unit tests for :pattern_utils. The cc_ops targets are needed because the
# tests build a MatMul + BiasAdd + Gelu graph with the C++ ops API.
tf_cc_test(
|
||||
name = "pattern_utils_test",
|
||||
srcs = ["pattern_utils_test.cc"],
|
||||
deps = [
|
||||
":pattern_utils",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/cc:cc_ops_internal",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "transitive_fanin",
|
||||
srcs = ["transitive_fanin.cc"],
|
||||
|
129
tensorflow/core/grappler/utils/pattern_utils.cc
Normal file
129
tensorflow/core/grappler/utils/pattern_utils.cc
Normal file
@ -0,0 +1,129 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/utils/pattern_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
namespace utils {
|
||||
|
||||
// A subgraph pattern syntax implicitly defines a DAG having a single root. We
|
||||
// traverse the syntax DAG in DFS manner. This function finds a match for
|
||||
// current root of the pattern with the current node and recursively matches
|
||||
// children subpatterns with the children of current node.
|
||||
template <>
|
||||
bool SubGraphMatcher<MatchingDirection::kFollowInputs>::DoesOpTypePatternMatch(
|
||||
const OpTypePattern& pattern, MutableNodeView* node_view,
|
||||
NodeViewMatch* match) {
|
||||
// Currently no control inputs and outputs are allowed.
|
||||
if (node_view->NumControllingFanins() > 0 ||
|
||||
node_view->NumControlledFanouts() > 0)
|
||||
return false;
|
||||
|
||||
bool op_type_matched = false;
|
||||
if (pattern.op == "*") {
|
||||
op_type_matched = true;
|
||||
} else {
|
||||
// The op field string of current pattern might express an op among multiple
|
||||
// op types (mutually exclusive) separated by '|'.
|
||||
std::vector<string> op_list = str_util::Split(pattern.op, '|');
|
||||
for (const string& op : op_list) {
|
||||
if (node_view->node()->op() == op) {
|
||||
op_type_matched = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (op_type_matched) {
|
||||
// If op type matches and current node is visited first time, insert current
|
||||
// node to node_label_to_index_ map with the current label as the key.
|
||||
// Multiple occurances of same label in the pattern syntax indicates that
|
||||
// the same node needs to be visited for each of such occurances. Hence
|
||||
// subsequent visits should find the corresponding label in the map as a key
|
||||
// and the current node should be the value for that key.
|
||||
if (node_label_to_index_.find(pattern.label) ==
|
||||
node_label_to_index_.end()) {
|
||||
node_label_to_index_[pattern.label] = node_view->node_index();
|
||||
// Bookkeeping
|
||||
matched_node_indices_.insert(node_view->node_index());
|
||||
if (pattern.node_status == NodeStatus::kRemove) {
|
||||
remove_node_indices_.insert(node_view->node_index());
|
||||
}
|
||||
} else if (node_label_to_index_[pattern.label] != node_view->node_index()) {
|
||||
return false; // label constraint could not be satisfied.
|
||||
} else {
|
||||
DCHECK(node_label_to_index_[pattern.label] == node_view->node_index());
|
||||
}
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
// Current root of the pattern syntax is matched with the current node.
|
||||
match->node_view = node_view;
|
||||
|
||||
// Go for matching child subpattern.
|
||||
if (!pattern.children.empty()) {
|
||||
// Currently only direction toward inputs is implemented.
|
||||
auto node_view_children = node_view->GetRegularFanins();
|
||||
if (node_view_children.size() != pattern.children.size()) {
|
||||
return false;
|
||||
} else {
|
||||
for (int i = 0; i < pattern.children.size(); ++i) {
|
||||
auto child_node_index = node_view_children[i].node_index();
|
||||
// TODO (mdfaijul): Is it guaranted that GetNode will reuturn non null
|
||||
// pointer.
|
||||
MutableNodeView* child_node_view =
|
||||
graph_view_->GetNode(child_node_index);
|
||||
const OpTypePattern& child_pattern = pattern.children[i];
|
||||
match->children.push_back(NodeViewMatch());
|
||||
NodeViewMatch* child_match = &(match->children.back());
|
||||
if (!DoesOpTypePatternMatch(child_pattern, child_node_view,
|
||||
child_match)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Current implementation supports pattern maching toward node's inputs only.
|
||||
template <>
|
||||
bool SubGraphMatcher<MatchingDirection::kFollowInputs>::GetMatchedNodes(
|
||||
const OpTypePattern& pattern, MutableNodeView* node_view,
|
||||
std::map<string, int>* matched_nodes_map,
|
||||
std::set<int>* remove_node_indices) {
|
||||
bool found_match = false;
|
||||
match_.reset(new NodeViewMatch());
|
||||
if (DoesOpTypePatternMatch(pattern, node_view, match_.get())) {
|
||||
if (!HasRemoveNodeExternalDependents()) {
|
||||
found_match = true;
|
||||
matched_nodes_map->swap(this->node_label_to_index_);
|
||||
remove_node_indices->swap(this->remove_node_indices_);
|
||||
}
|
||||
} else {
|
||||
found_match = false;
|
||||
// Clear all bookkeeping data
|
||||
match_->Clear();
|
||||
match_.reset(nullptr);
|
||||
node_label_to_index_.clear();
|
||||
matched_node_indices_.clear();
|
||||
remove_node_indices_.clear();
|
||||
}
|
||||
return found_match;
|
||||
}
|
||||
|
||||
} // namespace utils
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
227
tensorflow/core/grappler/utils/pattern_utils.h
Normal file
227
tensorflow/core/grappler/utils/pattern_utils.h
Normal file
@ -0,0 +1,227 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_PATTERN_HELPER_H_
|
||||
#define TENSORFLOW_CORE_GRAPPLER_UTILS_PATTERN_HELPER_H_
|
||||
|
||||
#include "tensorflow/core/grappler/utils/graph_view.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
namespace utils {
|
||||
|
||||
//------------------------------------------------------------------------------
|
||||
// A pattern can be defined by the following grammar. Here, op_type is any valid
|
||||
// op name in TensorFlow.
|
||||
//
|
||||
// leaf_pattern ::= `{` op_type `}`
|
||||
// pattern ::= leaf_pattern |
|
||||
// `{` op_type `,` `{` pattern `,` ... `,` pattern `}` `}`
|
||||
//
|
||||
// (1) For example, the following pattern syntax describes a pattern for
|
||||
// _FusedConv2D (Conv2D + BiasAdd + Relu). Note that "*" means any type of op.
|
||||
//
|
||||
// {"Relu",
|
||||
// {
|
||||
// "BiasAdd",
|
||||
// {
|
||||
// {"Conv2D"},
|
||||
// {"*"}
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// The syntax above has a root ("Relu") and children (inputs), where each child
|
||||
// is a sub-pattern. Graph pattern matcher finds a match for the given pattern
|
||||
// syntax in a graph and returns a set of matched nodes.
|
||||
//
|
||||
// (2) In order to match a DAG with a given root, we extend pattern syntax with
|
||||
// labels. For example, a frequently found pattern in Deep Learning models is a
|
||||
// residual block like below.
|
||||
//
|
||||
// Placeholder Const
|
||||
// | |
|
||||
// +-----+-----+ |
|
||||
// | | |
|
||||
// | v v
|
||||
// | Conv2D Const
|
||||
// | | |
|
||||
// | v v-----+
|
||||
// | BiasAdd
|
||||
// | |
|
||||
// v v----------+
|
||||
// AddV2
|
||||
//
|
||||
// As shown above, it is the same input node (Placeholder) consumed by both
|
||||
// AddV2 and Conv2D. This constraint can be expressed with labels in the following
|
||||
// augmented pattern syntax.
|
||||
//
|
||||
// {"AddV2", "my_add",
|
||||
// {
|
||||
// {"*", "my_residual_input"},
|
||||
// {"BiasAdd", "my_bias_add",
|
||||
// {
|
||||
// {"Conv2D", "my_conv",
|
||||
// {
|
||||
// {"*", "my_residual_input"},
|
||||
// {"*", "my_filter"}
|
||||
// }
|
||||
// },
|
||||
// {"*", "my_bias"}
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// Note that the same label "my_residual_input" is used to tell that it is a
|
||||
// child of both "AddV2" and "Conv2D". Labels are arbitrary strings to associate
|
||||
// with the nodes to be matched as well as to uniquely identify those nodes.
|
||||
//
|
||||
// (3) The motivation for grammar-based pattern matching in grappler is to
|
||||
// make it easy to find fusion patterns in the remapper. A subgraph that
|
||||
// matches a given pattern, however, is not fusable if any of the matched nodes
|
||||
// that will be removed as a part of fusion, has a consumer outside the matched
|
||||
// subgraph. In order to check for such type of external dependencies, we
|
||||
// further extend pattern syntax by prospective action (NodeStatus) on the
|
||||
// matched nodes as shown below. This helps cross checking the nodes to be
|
||||
// removed with the nodes matched initially.
|
||||
//
|
||||
// {"AddV2", "my_add", NodeStatus::kReplace,
|
||||
// {
|
||||
// {"*", "my_residual_input", NodeStatus::kRemain},
|
||||
// {"BiasAdd", "my_bias_add", NodeStatus::kRemove,
|
||||
// {
|
||||
// {"Conv2D", "my_conv", NodeStatus::kRemove,
|
||||
// {
|
||||
// {"*", "my_residual_input", NodeStatus::kRemain},
|
||||
// {"*", "my_filter", NodeStatus::kRemain}
|
||||
// }
|
||||
// },
|
||||
// {"*", "my_bias", NodeStatus::kRemain}
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
//------------------------------------------------------------------------------
|
||||
|
||||
// Pattern matcher recursively matches child subpatterns. The direction
|
||||
// for children could be toward node's input (fanins) or outputs (fanouts).
|
||||
// kFollowInputs: child sub-patterns are matched against a node's regular
// fanins. kFollowOutputs: intended for matching toward fanouts; no
// implementation for it is visible alongside this declaration.
enum class MatchingDirection { kFollowInputs, kFollowOutputs };
|
||||
|
||||
// Action for each node in the set of matched nodes for a given pattern.
|
||||
// Only kRemove is acted upon by the matcher itself: such nodes are collected
// and checked for consumers outside the matched subgraph. kRemain and kReplace
// are carried in the pattern for the caller's use.
enum class NodeStatus { kRemain, kRemove, kReplace };
|
||||
|
||||
// TODO (intel-tf): Support multiple roots by making them children of a single
|
||||
// virtual root.
|
||||
struct OpTypePattern {
|
||||
string op;
|
||||
string label;
|
||||
NodeStatus node_status;
|
||||
std::vector<OpTypePattern> children;
|
||||
|
||||
string DebugString() const {
|
||||
string result = "{(op: " + op + ", " + "label: " + label + "), {";
|
||||
for (const OpTypePattern& child : children) {
|
||||
result += child.DebugString() + ",";
|
||||
}
|
||||
result += "}}";
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
// This is a helpful recursive structure that keeps one-to-one mapping of
|
||||
// pattern syntax to the matched nodes. User can call DebugString to see what
|
||||
// has been matched so far and where is the failing point.
|
||||
struct NodeViewMatch {
|
||||
MutableNodeView* node_view = nullptr;
|
||||
std::vector<NodeViewMatch> children;
|
||||
|
||||
string DebugString() const {
|
||||
string result = "{";
|
||||
if (node_view == nullptr) {
|
||||
result += "Non-Matched-Node}";
|
||||
return result;
|
||||
} else {
|
||||
result += node_view->node()->DebugString();
|
||||
result += ", {";
|
||||
for (const NodeViewMatch& child : children) {
|
||||
result += child.DebugString() + ",";
|
||||
}
|
||||
result += "}}";
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
void Clear() {
|
||||
for (NodeViewMatch& child : children) {
|
||||
child.Clear(); // child is an object.
|
||||
}
|
||||
children.clear(); // children is a vector.
|
||||
if (node_view != nullptr) {
|
||||
node_view = nullptr;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <MatchingDirection DIRECTION = MatchingDirection::kFollowInputs>
|
||||
class SubGraphMatcher {
|
||||
public:
|
||||
SubGraphMatcher(MutableGraphView* graph_view) : graph_view_(graph_view){};
|
||||
|
||||
// If a given pattern is matched, this function returns true as well as the
|
||||
// matched node and remove node info is populated.
|
||||
bool GetMatchedNodes(const OpTypePattern& pattern, MutableNodeView* node_view,
|
||||
std::map<string, int>* matched_nodes_map,
|
||||
std::set<int>* remove_node_indices);
|
||||
|
||||
private:
|
||||
MutableGraphView* graph_view_;
|
||||
std::map<string, int> node_label_to_index_;
|
||||
std::set<int> matched_node_indices_;
|
||||
std::set<int> remove_node_indices_;
|
||||
std::unique_ptr<NodeViewMatch> match_ = nullptr;
|
||||
|
||||
bool DoesOpTypePatternMatch(const OpTypePattern& pattern,
|
||||
MutableNodeView* node_view, NodeViewMatch* match);
|
||||
|
||||
// This function should be called after the pattern matcher has found
|
||||
// potential matched nodes (i.e. when DoesOpTypePatternMatch returns "true").
|
||||
// It performs a sanity check if the candidate nodes for removal in subgraph
|
||||
// fusion is indeed safe to remove.
|
||||
bool HasRemoveNodeExternalDependents() {
|
||||
for (const auto& node_idx : remove_node_indices_) {
|
||||
auto node_view = graph_view_->GetNode(node_idx);
|
||||
// Traverse all the Regular Fanouts. Fanouts are stored as vector of
|
||||
// vector, std::vector<std::vector<MutableFaninView>>. Note that
|
||||
// a MutableNodeView's fanouts are stored in a nested vector of
|
||||
// MutableFaninView type.
|
||||
auto fanouts_by_ports = node_view->GetRegularFanouts();
|
||||
for (const auto& fanouts : fanouts_by_ports) {
|
||||
for (const auto& fanout : fanouts) {
|
||||
if (!matched_node_indices_.count(fanout.node_index())) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace utils
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_GRAPPLER_UTILS_PATTERN_HELPER_H_
|
474
tensorflow/core/grappler/utils/pattern_utils_test.cc
Normal file
474
tensorflow/core/grappler/utils/pattern_utils_test.cc
Normal file
@ -0,0 +1,474 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/utils/pattern_utils.h"
|
||||
|
||||
#include "tensorflow/cc/ops/nn_ops_internal.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
namespace utils {
|
||||
namespace {
|
||||
|
||||
using ::tensorflow::ops::Placeholder;
|
||||
|
||||
// Builds into `graph` a MatMul -> BiasAdd -> Gelu computation, where Gelu is
// spelled out with elementary ops:
//   gelu = (0.5 * (erf(bias_add * 0.707) + 1)) * bias_add
// followed by an Identity node named "fetch". When `add_external_dependent`
// is true, an extra Identity consumer of "bias_add" is added, giving that node
// a user outside the Gelu subgraph.
void GetMatMulBiasAddGeluGraph(GraphDef* graph,
                               bool add_external_dependent = false) {
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  // Inputs.
  auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
                                ops::Placeholder::Shape({8, 32}));
  auto weight = ops::Placeholder(s.WithOpName("weight"), DT_FLOAT,
                                 ops::Placeholder::Shape({32, 64}));
  auto bias = ops::Placeholder(s.WithOpName("bias"), DT_FLOAT,
                               ops::Placeholder::Shape({64}));

  // MatMul + BiasAdd.
  auto matmul = ops::MatMul(s.WithOpName("matmul"), input, weight);
  auto bias_add = ops::BiasAdd(s.WithOpName("bias_add"), matmul, bias);
  if (add_external_dependent) {
    ops::Identity(s.WithOpName("external_dependent"), bias_add);
  }

  // Gelu with smaller ops.
  auto inv_sqrt_two =
      ops::Const(s.WithOpName("one_over_square_root_two"), {0.707f}, {});
  auto scaled =
      ops::Mul(s.WithOpName("bias_add_times_const"), bias_add, inv_sqrt_two);
  auto erf = ops::Erf(s.WithOpName("erf"), scaled);
  auto one = ops::Const(s.WithOpName("one"), {1.0f}, {});
  auto erf_plus_one = ops::AddV2(s.WithOpName("erf_plus_one"), erf, one);
  auto one_half = ops::Const(s.WithOpName("one_half"), {0.5f}, {});
  auto half_erf_plus_one = ops::Mul(
      s.WithOpName("one_half_times_erf_plus_one"), one_half, erf_plus_one);
  auto gelu = ops::Mul(s.WithOpName("gelu"), half_erf_plus_one, bias_add);
  ops::Identity(s.WithOpName("fetch"), gelu);

  TF_ASSERT_OK(s.ToGraphDef(graph));
}
|
||||
|
||||
// Returns the pattern matching the Gelu subgraph produced by
// GetMatMulBiasAddGeluGraph(), rooted at the final "gelu" Mul.
//
// Although labels are arbitrary, for the convenience of checking they are the
// original node names in the graph prefixed with "my_". The label
// "my_bias_add" appears twice because the same BiasAdd node feeds both the
// erf branch and the final Mul.
OpTypePattern GetMatMulBiasAddGeluPattern() {
  // The tree is assembled bottom-up, innermost sub-pattern first.
  const OpTypePattern bias_add_times_const{
      "Mul", "my_bias_add_times_const", NodeStatus::kRemove,
      {{"BiasAdd", "my_bias_add", NodeStatus::kRemove},
       {"Const", "my_one_over_square_root_two", NodeStatus::kRemain}}};

  const OpTypePattern erf{
      "Erf", "my_erf", NodeStatus::kRemove, {bias_add_times_const}};

  const OpTypePattern erf_plus_one{
      "AddV2", "my_erf_plus_one", NodeStatus::kRemove,
      {erf, {"Const", "my_one", NodeStatus::kRemain}}};

  const OpTypePattern one_half_times_erf_plus_one{
      "Mul", "my_one_half_times_erf_plus_one", NodeStatus::kRemove,
      {{"Const", "my_one_half", NodeStatus::kRemain}, erf_plus_one}};

  const OpTypePattern bias_add{
      "BiasAdd", "my_bias_add", NodeStatus::kRemove,
      {{"MatMul", "my_matmul", NodeStatus::kRemove},
       {"*", "my_bias", NodeStatus::kRemain}}};

  return OpTypePattern{"Mul", "my_gelu", NodeStatus::kReplace,
                       {one_half_times_erf_plus_one, bias_add}};
}
|
||||
|
||||
class PatternMatcherTest : public ::testing::Test {
|
||||
protected:
|
||||
struct NodeConfig {
|
||||
NodeConfig(string name, string op, std::vector<string> inputs)
|
||||
: name(std::move(name)), op(std::move(op)), inputs(std::move(inputs)) {}
|
||||
|
||||
string name;
|
||||
string op;
|
||||
std::vector<string> inputs;
|
||||
};
|
||||
|
||||
static GraphDef CreateGraph(const std::vector<NodeConfig>& nodes) {
|
||||
GraphDef graph;
|
||||
|
||||
for (const NodeConfig& node : nodes) {
|
||||
NodeDef node_def;
|
||||
node_def.set_name(node.name);
|
||||
node_def.set_op(node.op);
|
||||
for (const string& input : node.inputs) {
|
||||
node_def.add_input(input);
|
||||
}
|
||||
*graph.add_node() = std::move(node_def);
|
||||
}
|
||||
|
||||
return graph;
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(PatternMatcherTest, Tree) {
  // A data flow graph. Data flows from top to bottom. Here A, B, C, D, and E
  // are ops.
  //
  //     Input graph              Subgraph for pattern matcher
  //
  //         A                          C   D
  //         |                           + +
  //         B                            E
  //         |
  //         C   D
  //          + +
  //           E
  //
  // (edges: A->B, B->C, C->E, D->E; subgraph edges: C->E, D->E)
  //
  // E is the root of the pattern syntax shown below that the pattern matcher
  // would match.
  //  {"E", "my_e", NodeStatus::kReplace,
  //    {
  //      {"C", "my_c", NodeStatus::kRemove},
  //      {"D", "my_d", NodeStatus::kRemove}
  //    }
  //  }

  ::tensorflow::Status status;
  GraphDef graph = CreateGraph({{"e", "E", {"c", "d"}},
                                {"c", "C", {"b"}},
                                {"d", "D", {}},
                                {"b", "B", {"a"}},
                                {"a", "A", {}}});
  // clang-format off
  OpTypePattern pattern{"E", "my_e", NodeStatus::kReplace,
    {
      {"C", "my_c", NodeStatus::kRemove},
      {"D", "my_d", NodeStatus::kRemove}
    }
  };  // clang-format on

  MutableGraphView graph_view(&graph, &status);
  TF_ASSERT_OK(status);
  TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
  auto root_node_view = graph_view.GetNode("e");

  SubGraphMatcher<MatchingDirection::kFollowInputs> graph_matcher(&graph_view);
  std::map<string, int> matched_nodes_map;  // label to node index map
  std::set<int> remove_node_indices;
  bool found_match = graph_matcher.GetMatchedNodes(
      pattern, root_node_view, &matched_nodes_map, &remove_node_indices);

  EXPECT_TRUE(found_match);
  EXPECT_FALSE(matched_nodes_map.empty());
  EXPECT_FALSE(remove_node_indices.empty());

  // Each label is the graph node name prefixed with "my_", so stripping the
  // prefix names the node the label must be bound to. (The loop previously
  // compared begin() with begin() and therefore never ran.)
  for (const auto& label_and_index : matched_nodes_map) {
    const auto node_name = str_util::StripPrefix(label_and_index.first, "my_");
    auto* expected_node_view = graph_view.GetNode(node_name);
    ASSERT_TRUE(expected_node_view != nullptr) << label_and_index.first;
    EXPECT_EQ(label_and_index.second, expected_node_view->node_index())
        << label_and_index.first;
  }
}
|
||||
|
||||
TEST_F(PatternMatcherTest, DAG) {
  // A data flow graph. Here A, B, C, D, and E are ops.
  //
  //   Input graph edges:   A->B, B->C, B->D, C->E, D->E
  //   Subgraph to match:         B->C, B->D, C->E, D->E
  //
  // i.e. B feeds both C and D, which in turn both feed E. E is the root of the
  // pattern syntax shown below that the pattern matcher would match. The label
  // "my_b" is used twice to require that C and D share the same input node.
  //  {"E", "my_e", NodeStatus::kReplace,
  //    {
  //      {"C", "my_c", NodeStatus::kRemove,
  //        {
  //          {"B", "my_b", NodeStatus::kRemove}
  //        }
  //      },
  //      {"D", "my_d", NodeStatus::kRemove,
  //        {
  //          {"B", "my_b", NodeStatus::kRemove}
  //        }
  //      }
  //    }
  //  }

  ::tensorflow::Status status;
  GraphDef graph = CreateGraph({{"e", "E", {"c", "d"}},
                                {"c", "C", {"b"}},
                                {"d", "D", {"b"}},
                                {"b", "B", {"a"}},
                                {"a", "A", {}}});
  // clang-format off
  OpTypePattern pattern{"E", "my_e", NodeStatus::kReplace,
    {
      {"C", "my_c", NodeStatus::kRemove,
        {
          {"B", "my_b", NodeStatus::kRemove}
        }
      },
      {"D", "my_d", NodeStatus::kRemove,
        {
          {"B", "my_b", NodeStatus::kRemove}
        }
      }
    }
  };  // clang-format on

  MutableGraphView graph_view(&graph, &status);
  TF_ASSERT_OK(status);
  TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
  auto root_node_view = graph_view.GetNode("e");

  SubGraphMatcher<MatchingDirection::kFollowInputs> graph_matcher(&graph_view);
  std::map<string, int> matched_nodes_map;  // label to node index map
  std::set<int> remove_node_indices;
  bool found_match = graph_matcher.GetMatchedNodes(
      pattern, root_node_view, &matched_nodes_map, &remove_node_indices);

  EXPECT_TRUE(found_match);
  EXPECT_FALSE(matched_nodes_map.empty());
  EXPECT_FALSE(remove_node_indices.empty());

  // Each label is the graph node name prefixed with "my_", so stripping the
  // prefix names the node the label must be bound to. (The loop previously
  // compared begin() with begin() and therefore never ran.)
  for (const auto& label_and_index : matched_nodes_map) {
    const auto node_name = str_util::StripPrefix(label_and_index.first, "my_");
    auto* expected_node_view = graph_view.GetNode(node_name);
    ASSERT_TRUE(expected_node_view != nullptr) << label_and_index.first;
    EXPECT_EQ(label_and_index.second, expected_node_view->node_index())
        << label_and_index.first;
  }
}
|
||||
|
||||
// Pattern should not be matched if any of candidate remove nodes has external
|
||||
// dependent.
|
||||
TEST_F(PatternMatcherTest, DAGExternalDependent) {
|
||||
// A Data flow graph. Data flows from top to bottom. Here A, B, C, D, E, and F
|
||||
// are ops.
|
||||
//
|
||||
// Input graph Subgraph for pattern matcher
|
||||
//
|
||||
// A
|
||||
// | B
|
||||
// B / \
|
||||
// / \ C D
|
||||
// C D \ /
|
||||
// \ / \ E
|
||||
// E F
|
||||
//
|
||||
// E is the root of pattern syntax as shown below that the pattern matcher
|
||||
// would match. Note D is a candidate for remove node as mentioned in the
|
||||
// syntax. So Pattern matcher should not find a match.
|
||||
// {"E", "my_e", NodeStatus::Replace,
|
||||
// {
|
||||
// {"C", "my_c", NodeStatus::kRemove,
|
||||
// {
|
||||
// {"B", "my_b", NodeStatus::kRemove}
|
||||
// }
|
||||
// },
|
||||
// {"D", "my_d", NodeStatus::kRemove,
|
||||
// {
|
||||
// {"B", "my_b", NodeStatus::kRemove}
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
::tensorflow::Status status;
|
||||
GraphDef graph = CreateGraph({{"f", "F", {"d"}},
|
||||
{"e", "E", {"c", "d"}},
|
||||
{"c", "C", {"b"}},
|
||||
{"d", "D", {"b"}},
|
||||
{"b", "B", {"a"}},
|
||||
{"a", "A", {}}});
|
||||
// clang-format off
|
||||
OpTypePattern pattern{"E", "my_e", NodeStatus::kReplace,
|
||||
{
|
||||
{"C", "my_c", NodeStatus::kRemove,
|
||||
{
|
||||
{"B", "my_b", NodeStatus::kRemove}
|
||||
}
|
||||
},
|
||||
{"D", "my_d", NodeStatus::kRemove,
|
||||
{
|
||||
{"B", "my_b", NodeStatus::kRemove}
|
||||
}
|
||||
}
|
||||
}
|
||||
}; // clang-format on
|
||||
|
||||
MutableGraphView graph_view(&graph, &status);
|
||||
TF_ASSERT_OK(status);
|
||||
TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
|
||||
auto root_node_view = graph_view.GetNode("e");
|
||||
|
||||
SubGraphMatcher<MatchingDirection::kFollowInputs> graph_matcher(&graph_view);
|
||||
std::map<string, int> matched_nodes_map; // label to node index map
|
||||
std::set<int> remove_node_indices;
|
||||
bool found_match = graph_matcher.GetMatchedNodes(
|
||||
pattern, root_node_view, &matched_nodes_map, &remove_node_indices);
|
||||
|
||||
EXPECT_FALSE(found_match);
|
||||
EXPECT_TRUE(matched_nodes_map.empty());
|
||||
EXPECT_TRUE(remove_node_indices.empty());
|
||||
}
|
||||
|
||||
// The Gelu pattern must match the MatMul + BiasAdd + Gelu graph when rooted at
// the "gelu" node, and every label must be bound to the node of the same name.
TEST_F(PatternMatcherTest, MatMulBiasAddGelu) {
  ::tensorflow::Status status;
  GraphDef graph;
  GetMatMulBiasAddGeluGraph(&graph);
  OpTypePattern pattern = GetMatMulBiasAddGeluPattern();
  MutableGraphView graph_view(&graph, &status);
  TF_ASSERT_OK(status);
  TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
  auto root_node_view = graph_view.GetNode("gelu");

  SubGraphMatcher<MatchingDirection::kFollowInputs> graph_matcher(&graph_view);
  std::map<string, int> matched_nodes_map;  // label to node index map
  std::set<int> remove_node_indices;
  bool found_match = graph_matcher.GetMatchedNodes(
      pattern, root_node_view, &matched_nodes_map, &remove_node_indices);

  EXPECT_TRUE(found_match);
  EXPECT_FALSE(matched_nodes_map.empty());
  EXPECT_FALSE(remove_node_indices.empty());

  // Each label is the graph node name prefixed with "my_", so stripping the
  // prefix names the node the label must be bound to. (The loop previously
  // compared begin() with begin() and therefore never ran.)
  for (const auto& label_and_index : matched_nodes_map) {
    const auto node_name = str_util::StripPrefix(label_and_index.first, "my_");
    auto* expected_node_view = graph_view.GetNode(node_name);
    ASSERT_TRUE(expected_node_view != nullptr) << label_and_index.first;
    EXPECT_EQ(label_and_index.second, expected_node_view->node_index())
        << label_and_index.first;
  }
}
|
||||
|
||||
// Pattern should not be matched if any of candidate remove nodes has external
|
||||
// dependent.
|
||||
TEST_F(PatternMatcherTest, MatMulBiasAddGeluExternalDependent) {
|
||||
::tensorflow::Status status;
|
||||
GraphDef graph;
|
||||
GetMatMulBiasAddGeluGraph(&graph, /*add_external_dependent=*/true);
|
||||
OpTypePattern pattern = GetMatMulBiasAddGeluPattern();
|
||||
MutableGraphView graph_view(&graph, &status);
|
||||
TF_ASSERT_OK(status);
|
||||
TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
|
||||
auto root_node_view = graph_view.GetNode("gelu");
|
||||
|
||||
SubGraphMatcher<MatchingDirection::kFollowInputs> graph_matcher(&graph_view);
|
||||
std::map<string, int> matched_nodes_map; // label to node index map
|
||||
std::set<int> remove_node_indices;
|
||||
bool found_match = graph_matcher.GetMatchedNodes(
|
||||
pattern, root_node_view, &matched_nodes_map, &remove_node_indices);
|
||||
|
||||
EXPECT_FALSE(found_match);
|
||||
EXPECT_TRUE(matched_nodes_map.empty());
|
||||
EXPECT_TRUE(remove_node_indices.empty());
|
||||
}
|
||||
|
||||
// After a successful match, replaces the root "gelu" node with a fused node
// and removes every node reported for removal, then checks the resulting
// graph: node count, absence of the removed nodes, presence of "gelu".
TEST_F(PatternMatcherTest, MatMulBiasAddGeluMutation) {
  GraphDef graph;
  GetMatMulBiasAddGeluGraph(&graph);
  OpTypePattern pattern = GetMatMulBiasAddGeluPattern();

  ::tensorflow::Status status;
  MutableGraphView graph_view(&graph, &status);
  TF_ASSERT_OK(status);
  TF_EXPECT_OK(graph_view.SortTopologically(/*ignore_cycles=*/false, {}));
  auto root_node_view = graph_view.GetNode("gelu");

  SubGraphMatcher<MatchingDirection::kFollowInputs> graph_matcher(&graph_view);
  std::map<string, int> matched_nodes_map;  // label to node index map
  std::set<int> remove_node_indices;
  const bool found_match = graph_matcher.GetMatchedNodes(
      pattern, root_node_view, &matched_nodes_map, &remove_node_indices);
  EXPECT_TRUE(found_match);
  EXPECT_FALSE(matched_nodes_map.empty());
  EXPECT_FALSE(remove_node_indices.empty());

  // Snapshot taken before the mutation: the node count and the names of the
  // nodes that are candidates for removal.
  int num_nodes_before = graph_view.NumNodes();
  std::vector<string> remove_node_names;
  for (const auto& remove_idx : remove_node_indices) {
    remove_node_names.push_back(graph_view.GetNode(remove_idx)->GetName());
  }

  Mutation* mutation = graph_view.GetMutationBuilder();

  // Replace with fused op: a node named "gelu" taking the two MatMul inputs
  // and the bias input of BiasAdd.
  NodeDef fused_node;
  fused_node.set_name("gelu");
  fused_node.set_op("_FusedMatMul");
  fused_node.add_input(graph_view.GetNode("matmul")->node()->input(0));
  fused_node.add_input(graph_view.GetNode("matmul")->node()->input(1));
  fused_node.add_input(graph_view.GetNode("bias_add")->node()->input(1));
  mutation->AddNode(std::move(fused_node), &status);
  TF_ASSERT_OK(status);
  TF_EXPECT_OK(mutation->Apply());

  // Remove nodes that are marked as NodeStatus::kRemove.
  for (const auto& remove_idx : remove_node_indices) {
    mutation->RemoveNode(graph_view.GetNode(remove_idx));
  }
  TF_EXPECT_OK(mutation->Apply());

  // After mutation number of nodes.
  int num_nodes_after = graph_view.NumNodes();
  EXPECT_EQ(num_nodes_before - remove_node_indices.size(), num_nodes_after);

  // None of the removed nodes may still be reachable by name.
  bool removed_node_still_present = false;
  for (const auto& removed_name : remove_node_names) {
    if (graph_view.GetNode(removed_name) != nullptr) {
      removed_node_still_present = true;
      break;
    }
  }
  EXPECT_FALSE(removed_node_still_present);

  // The replacement node must exist.
  EXPECT_TRUE(graph_view.HasNode("gelu"));
}
|
||||
|
||||
} // namespace
|
||||
} // namespace utils
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
Loading…
Reference in New Issue
Block a user