[Grappler] Split common subgraph elimination out into a separate Grappler optimizer. This allows us to run it before constant folding and function optimizer to reduce the size of the graph earlier and save time spent in grappler. In particular, this reduces the size of the graph for which we run static shape inference, which is very expensive.
Change NodeMap to use absl::node_hash_map instead of std::unordered_map. On a particular model with 215k nodes and 255k edges, I measure a 25% speedup of the Grappler MetaOptimizer overall. PiperOrigin-RevId: 299413265 Change-Id: I741e8f22fa8169044d3d51b81e08d20df301d506
This commit is contained in:
parent
5db7460fa8
commit
87dff7f00d
@ -40,6 +40,7 @@ cc_library(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/container:node_hash_map",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
|
@ -695,7 +695,7 @@ bool IsFreeOfSideEffect(const NodeDef& node) {
|
||||
|
||||
bool ModifiesInputsInPlace(const NodeDef& node) {
|
||||
// Some nodes do in-place updates on regular tensor inputs.
|
||||
string op_name = node.op();
|
||||
const string& op_name = node.op();
|
||||
|
||||
// Ops that modify resource variables effectively modify one of their inputs.
|
||||
if (op_name == "AssignVariableOp" || op_name == "AssignAddVariableOp" ||
|
||||
@ -706,8 +706,10 @@ bool ModifiesInputsInPlace(const NodeDef& node) {
|
||||
return false;
|
||||
}
|
||||
|
||||
std::transform(op_name.begin(), op_name.end(), op_name.begin(), ::tolower);
|
||||
if (absl::StrContains(op_name, "inplace")) {
|
||||
string lower_op_name = op_name;
|
||||
std::transform(lower_op_name.begin(), lower_op_name.end(),
|
||||
lower_op_name.begin(), ::tolower);
|
||||
if (absl::StrContains(lower_op_name, "inplace")) {
|
||||
return true;
|
||||
}
|
||||
return GetBoolAttr(node, "in_place") || GetBoolAttr(node, "inplace");
|
||||
@ -862,20 +864,25 @@ bool NeverForwardsInputs(const NodeDef& node) {
|
||||
(new gtl::FlatSet<string>{"ArgMax",
|
||||
"ArgMin",
|
||||
"AudioSpectrogram",
|
||||
"AvgPool",
|
||||
"BatchMatMul",
|
||||
"BatchMatMulV2",
|
||||
"BatchNormWithGlobalNormalization",
|
||||
"BatchToSpace",
|
||||
"BatchToSpaceND",
|
||||
"Bincount",
|
||||
"BroadcastArgs",
|
||||
"BroadcastGradientArgs",
|
||||
"Bucketize",
|
||||
"CTCBeamSearchDecoder",
|
||||
"CTCGreedyDecoder",
|
||||
"CTCLoss",
|
||||
"CompareAndBitpack",
|
||||
"ComplexAbs",
|
||||
"Concat",
|
||||
"ConcatOffset",
|
||||
"ConcatV2",
|
||||
"Conv2D",
|
||||
"Copy",
|
||||
"CopyHost",
|
||||
"Cross",
|
||||
@ -890,8 +897,8 @@ bool NeverForwardsInputs(const NodeDef& node) {
|
||||
"CudnnRNNParamsToCanonicalV2",
|
||||
"CudnnRNNV2",
|
||||
"CudnnRNNV3",
|
||||
"CumSum",
|
||||
"CumProd",
|
||||
"CumSum",
|
||||
"DebugNanCount",
|
||||
"DebugNumericSummary",
|
||||
"DecodeProtoV2",
|
||||
@ -920,15 +927,25 @@ bool NeverForwardsInputs(const NodeDef& node) {
|
||||
"LowerBound",
|
||||
"MatMul",
|
||||
"MatrixDiag",
|
||||
"MatrixDiagV2",
|
||||
"MatrixDiagPart",
|
||||
"MatrixDiagPartV2",
|
||||
"MatrixDiagV2",
|
||||
"Mfcc",
|
||||
"Multinomial",
|
||||
"OneHot",
|
||||
"Pack",
|
||||
"ParameterizedTruncatedNormal",
|
||||
"PopulationCount",
|
||||
"RandomGamma",
|
||||
"RandomPoisson",
|
||||
"RandomPoissonV2",
|
||||
"RandomStandardNormal",
|
||||
"RandomUniform",
|
||||
"RandomUniformInt",
|
||||
"Range",
|
||||
"Rank",
|
||||
"RequantizationRange",
|
||||
"Requantize",
|
||||
"ReverseSequence",
|
||||
"Shape",
|
||||
"ShapeN",
|
||||
@ -939,6 +956,7 @@ bool NeverForwardsInputs(const NodeDef& node) {
|
||||
"SparseMatMul",
|
||||
"Split",
|
||||
"SplitV",
|
||||
"TruncatedNormal",
|
||||
"Unique",
|
||||
"UniqueV2",
|
||||
"UniqueWithCounts",
|
||||
@ -946,23 +964,7 @@ bool NeverForwardsInputs(const NodeDef& node) {
|
||||
"Unpack",
|
||||
"UnravelIndex",
|
||||
"UpperBound",
|
||||
"Where",
|
||||
"CompareAndBitpack",
|
||||
"Requantize",
|
||||
"RequantizationRange",
|
||||
"Bucketize",
|
||||
"AvgPool",
|
||||
"BatchNormWithGlobalNormalization",
|
||||
"Conv2D",
|
||||
"RandomUniform",
|
||||
"RandomUniformInt",
|
||||
"RandomStandardNormal",
|
||||
"ParameterizedTruncatedNormal",
|
||||
"TruncatedNormal",
|
||||
"Multinomial",
|
||||
"RandomGamma",
|
||||
"RandomPoisson",
|
||||
"RandomPoissonV2"}));
|
||||
"Where"}));
|
||||
const string& op_name = node.op();
|
||||
return kNonForwardingOps->count(op_name) > 0 ||
|
||||
absl::StrContains(op_name, "Segment") ||
|
||||
|
@ -302,6 +302,7 @@ cc_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":arithmetic_optimizer",
|
||||
":common_subgraph_elimination",
|
||||
":constant_folding",
|
||||
":model_pruner",
|
||||
"//tensorflow/core:test",
|
||||
@ -334,6 +335,59 @@ tf_cuda_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "common_subgraph_elimination",
|
||||
srcs = ["common_subgraph_elimination.cc"],
|
||||
hdrs = [
|
||||
"common_subgraph_elimination.h",
|
||||
],
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":graph_optimizer",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/grappler:graph_topology_view",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core/grappler:op_types",
|
||||
"//tensorflow/core/grappler:utils",
|
||||
"//tensorflow/core/grappler/utils:canonicalizer",
|
||||
"//tensorflow/core/grappler/utils:topological_sort",
|
||||
"//tensorflow/core/grappler/utils:traversal",
|
||||
"//tensorflow/core/platform:hash",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_cc_test(
|
||||
name = "common_subgraph_elimination_test",
|
||||
size = "small",
|
||||
srcs = ["common_subgraph_elimination_test.cc"],
|
||||
deps = [
|
||||
":arithmetic_optimizer_test_utils",
|
||||
":common_subgraph_elimination",
|
||||
":model_pruner",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/cc:cc_ops_internal",
|
||||
"//tensorflow/core:all_kernels",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"//tensorflow/core/grappler:grappler_item",
|
||||
"//tensorflow/core/grappler:utils",
|
||||
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
|
||||
"//tensorflow/core/grappler/utils:grappler_test",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "dependency_optimizer",
|
||||
srcs = ["dependency_optimizer.cc"],
|
||||
@ -605,6 +659,7 @@ cc_library(
|
||||
":arithmetic_optimizer",
|
||||
":auto_mixed_precision",
|
||||
":auto_parallel",
|
||||
":common_subgraph_elimination",
|
||||
":constant_folding",
|
||||
":custom_graph_optimizer_registry",
|
||||
":debug_stripper",
|
||||
|
@ -3462,215 +3462,6 @@ class RemoveStackSliceSameAxis : public ArithmeticOptimizerStage {
|
||||
|
||||
} // namespace
|
||||
|
||||
class UniqueNodes {
|
||||
public:
|
||||
NodeDef* FindOrAddRepresentative(NodeDef* node) {
|
||||
uint64 sig = ComputeSignature(*node);
|
||||
std::vector<NodeDef*>& candidates = rep_[sig];
|
||||
for (auto& candidate : candidates) {
|
||||
if ((candidate == node) || SameNode(*candidate, *node)) {
|
||||
return candidate;
|
||||
}
|
||||
}
|
||||
candidates.push_back(node);
|
||||
return node;
|
||||
}
|
||||
|
||||
void RemoveRepresentative(NodeDef* node) {
|
||||
auto it = memoized_signatures_.find(node);
|
||||
if (it == memoized_signatures_.end()) return;
|
||||
|
||||
std::vector<NodeDef*>& candidates = rep_[it->second];
|
||||
for (int i = 0; i < candidates.size(); ++i) {
|
||||
if (candidates[i] == node) {
|
||||
std::swap(candidates[i], candidates[candidates.size() - 1]);
|
||||
candidates.resize(candidates.size() - 1);
|
||||
break;
|
||||
}
|
||||
}
|
||||
memoized_signatures_.erase(node);
|
||||
}
|
||||
|
||||
private:
|
||||
uint64 ComputeSignature(const NodeDef& node);
|
||||
bool SameNode(const NodeDef& node1, const NodeDef& node2) const;
|
||||
|
||||
absl::flat_hash_map<uint64, std::vector<NodeDef*>> rep_;
|
||||
absl::flat_hash_map<const NodeDef*, uint64> memoized_signatures_;
|
||||
};
|
||||
|
||||
uint64 UniqueNodes::ComputeSignature(const NodeDef& node) {
|
||||
auto it = memoized_signatures_.find(&node);
|
||||
if (it != memoized_signatures_.end()) return it->second;
|
||||
|
||||
uint64 h = Hash64(node.op());
|
||||
h = Hash64Combine(Hash64(node.device()), h);
|
||||
|
||||
for (const auto& input : node.input()) {
|
||||
const TensorId input_tensor = ParseTensorName(input);
|
||||
uint64 input_hash = Hash64Combine(
|
||||
Hash64(input_tensor.node().data(), input_tensor.node().size()),
|
||||
std::hash<int>()(input_tensor.index()));
|
||||
h = Hash64CombineUnordered(input_hash, h);
|
||||
}
|
||||
for (const auto& attr : node.attr()) {
|
||||
uint64 attr_hash =
|
||||
Hash64Combine(Hash64(attr.first), FastAttrValueHash(attr.second));
|
||||
h = Hash64CombineUnordered(attr_hash, h);
|
||||
}
|
||||
memoized_signatures_.emplace(&node, h);
|
||||
return h;
|
||||
}
|
||||
|
||||
// PRECONDITION:
|
||||
// Node input orders are assumed to be canonicalized, i.e. control inputs for
|
||||
// all nodes as well as regular inputs for commutative nodes must be sorted.
|
||||
bool UniqueNodes::SameNode(const NodeDef& node1, const NodeDef& node2) const {
|
||||
if (node1.op() != node2.op()) {
|
||||
return false;
|
||||
}
|
||||
if (node1.device() != node2.device()) {
|
||||
return false;
|
||||
}
|
||||
if (node1.input_size() != node2.input_size()) {
|
||||
return false;
|
||||
}
|
||||
if (node1.attr_size() != node2.attr_size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Compare inputs.
|
||||
auto it1 = node1.input().begin();
|
||||
auto it2 = node2.input().begin();
|
||||
for (; it1 != node1.input().end(); ++it1, ++it2) {
|
||||
if (*it1 != *it2) return false;
|
||||
}
|
||||
|
||||
// Compare attributes.
|
||||
for (const auto& attr1 : node1.attr()) {
|
||||
auto it = node2.attr().find(attr1.first);
|
||||
if (it == node2.attr().end()) return false;
|
||||
if (!FastAreAttrValuesEqual(attr1.second, it->second)) return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ArithmeticOptimizer::CanDedup(const NodeDef& node) const {
|
||||
if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
|
||||
return false;
|
||||
}
|
||||
if (IsEnter(node) || IsExit(node)) {
|
||||
return false;
|
||||
}
|
||||
if (node.device().find("SPU") != string::npos) {
|
||||
return false;
|
||||
}
|
||||
if (IsAssert(node) || IsPrint(node)) {
|
||||
return true;
|
||||
}
|
||||
return IsFreeOfSideEffect(node);
|
||||
}
|
||||
|
||||
void ArithmeticOptimizer::DedupComputations() {
|
||||
CanonicalizeGraph(optimized_graph_);
|
||||
|
||||
GraphTopologyView graph_view;
|
||||
if (!graph_view.InitializeFromGraph(*optimized_graph_).ok()) {
|
||||
LOG(WARNING) << "Failed to initialize GraphTopologyView.";
|
||||
return;
|
||||
}
|
||||
|
||||
// Populate feed_inplace_op;
|
||||
absl::flat_hash_set<const NodeDef*> feeds_inplace_op;
|
||||
for (const NodeDef& root : optimized_graph_->node()) {
|
||||
if (feeds_inplace_op.find(&root) != feeds_inplace_op.end()) continue;
|
||||
|
||||
if (ModifiesInputsInPlace(root)) {
|
||||
const auto is_continue_traversal = [&](const NodeDef* node) -> bool {
|
||||
return node->op() == root.op() || !NeverForwardsInputs(*node);
|
||||
};
|
||||
|
||||
DfsTraversal(graph_view, {&root}, TraversalDirection::kFollowInputs,
|
||||
DfsPredicates::Advance(is_continue_traversal),
|
||||
DfsCallbacks::PreOrder([&](const NodeDef* node) {
|
||||
feeds_inplace_op.insert(node);
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
bool stop = true;
|
||||
std::set<int> duplicates;
|
||||
UniqueNodes nodes;
|
||||
do {
|
||||
stop = true;
|
||||
for (int i = 0; i < optimized_graph_->node_size(); ++i) {
|
||||
if (duplicates.find(i) != duplicates.end()) {
|
||||
continue;
|
||||
}
|
||||
NodeDef* node = optimized_graph_->mutable_node(i);
|
||||
if (!CanDedup(*node) ||
|
||||
feeds_inplace_op.find(node) != feeds_inplace_op.end()) {
|
||||
continue;
|
||||
}
|
||||
NodeDef* rep = nodes.FindOrAddRepresentative(node);
|
||||
if (rep == node) {
|
||||
continue;
|
||||
}
|
||||
// If either node or rep feeds an inplace op, deduping them may cause data
|
||||
// races. For example: If we dedup nodes initializing two independent
|
||||
// inplace accumulations, they will write to the same buffer, clobbering
|
||||
// each other's results.
|
||||
if (feeds_inplace_op.find(rep) != feeds_inplace_op.end()) {
|
||||
continue;
|
||||
}
|
||||
const std::set<NodeDef*>& tmp = node_map_->GetOutputs(node->name());
|
||||
std::vector<NodeDef*> fanouts(tmp.begin(), tmp.end());
|
||||
for (NodeDef* fanout : fanouts) {
|
||||
// Update consumers of node.
|
||||
bool updated_fanout = false;
|
||||
for (int i = 0; i < fanout->input_size(); ++i) {
|
||||
string* fanout_input = fanout->mutable_input(i);
|
||||
|
||||
const int position =
|
||||
NodePositionIfSameNode(*fanout_input, node->name());
|
||||
// Update name in-place.
|
||||
if (position < -1) {
|
||||
continue;
|
||||
} else {
|
||||
if (!updated_fanout) {
|
||||
// The signature of the fanout node will change. Remove it from
|
||||
// nodes.
|
||||
nodes.RemoveRepresentative(fanout);
|
||||
}
|
||||
updated_fanout = true;
|
||||
if (position > 0) {
|
||||
*fanout_input = StrCat(rep->name(), ":", position);
|
||||
} else if (position == 0) {
|
||||
*fanout_input = rep->name();
|
||||
} else {
|
||||
*fanout_input = StrCat("^", rep->name());
|
||||
}
|
||||
}
|
||||
}
|
||||
if (updated_fanout) {
|
||||
node_map_->UpdateInput(fanout->name(), node->name(), rep->name());
|
||||
CanonicalizeNode(fanout);
|
||||
}
|
||||
}
|
||||
duplicates.insert(i);
|
||||
stop = false;
|
||||
}
|
||||
} while (!stop);
|
||||
|
||||
// Delete duplicates
|
||||
if (fetch_nodes_known_ && !duplicates.empty()) {
|
||||
EraseNodesFromGraph(duplicates, optimized_graph_);
|
||||
// Rebuild the NodeMap which was invalidated by the node swapping above.
|
||||
node_map_.reset(new NodeMap(optimized_graph_));
|
||||
}
|
||||
}
|
||||
|
||||
Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
|
||||
SetVector<NodeDef*> nodes_to_simplify;
|
||||
nodes_to_simplify.Reserve(optimized_graph_->node_size());
|
||||
@ -3818,11 +3609,6 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/,
|
||||
TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph_));
|
||||
GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
|
||||
|
||||
if (options_.dedup_computations) {
|
||||
DedupComputations();
|
||||
GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
|
||||
}
|
||||
|
||||
graph_properties_.reset(new GraphProperties(optimized_item));
|
||||
const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
|
||||
const Status status =
|
||||
|
@ -104,115 +104,6 @@ TEST_F(ArithmeticOptimizerTest, NoOp) {
|
||||
VerifyGraphsMatch(item.graph, output, __LINE__);
|
||||
}
|
||||
|
||||
TEST_F(ArithmeticOptimizerTest, OpDedupping) {
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output c1 = ops::Const(s.WithOpName("c1"), {3.14, 2.7}, {1, 2});
|
||||
Output c2 = ops::Const(s.WithOpName("c2"), {3.14, 2.7}, {1, 2});
|
||||
Output div = ops::Div(s.WithOpName("div"), c1, c2);
|
||||
GrapplerItem item;
|
||||
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||
item.fetch = {"div"};
|
||||
|
||||
auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
|
||||
ASSERT_EQ(tensors_expected.size(), 1);
|
||||
|
||||
ArithmeticOptimizer optimizer;
|
||||
GraphDef output;
|
||||
OptimizeTwice(&optimizer, &item, &output);
|
||||
NodeMap node_map(&output);
|
||||
EXPECT_EQ(output.node_size(), 2);
|
||||
const NodeDef* new_c1 = node_map.GetNode("c1");
|
||||
ASSERT_NE(new_c1, nullptr);
|
||||
|
||||
const NodeDef* new_div = node_map.GetNode("div");
|
||||
ASSERT_NE(new_div, nullptr);
|
||||
ASSERT_EQ(new_div->input_size(), 2);
|
||||
EXPECT_EQ(new_div->input(0), "c1");
|
||||
EXPECT_EQ(new_div->input(1), "c1");
|
||||
|
||||
auto tensors = EvaluateNodes(output, item.fetch);
|
||||
ASSERT_EQ(tensors.size(), 1);
|
||||
test::ExpectTensorNear<double>(tensors[0], tensors_expected[0], 1e-6);
|
||||
}
|
||||
|
||||
TEST_F(ArithmeticOptimizerTest, OpDeduppingAssertAndCheckNumerics) {
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output p = ops::Placeholder(s, DT_BOOL, ops::Placeholder::Shape({}));
|
||||
Output c = ops::Const(s.WithOpName("c"), {3.14, 2.7}, {1, 2});
|
||||
auto check1 = ops::CheckNumerics(s.WithOpName("check1"), c, "foo");
|
||||
auto check2 = ops::CheckNumerics(s.WithOpName("check2"), c, "foo");
|
||||
auto assert1 = ops::Assert(s.WithOpName("assert1"), p, {c});
|
||||
auto assert2 = ops::Assert(s.WithOpName("assert2"), p, {c});
|
||||
Output div = ops::Div(s.WithOpName("div").WithControlDependencies(
|
||||
{assert1.operation, assert2.operation}),
|
||||
check1, check2);
|
||||
GrapplerItem item;
|
||||
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||
item.fetch = {"div"};
|
||||
Tensor bool_t(DT_BOOL, TensorShape({}));
|
||||
bool_t.scalar<bool>().setConstant(true);
|
||||
auto tensors_expected =
|
||||
EvaluateNodes(item.graph, item.fetch, {{"Placeholder", bool_t}});
|
||||
ASSERT_EQ(tensors_expected.size(), 1);
|
||||
|
||||
ArithmeticOptimizer optimizer;
|
||||
GraphDef output;
|
||||
|
||||
OptimizeTwice(&optimizer, &item, &output);
|
||||
NodeMap node_map(&output);
|
||||
|
||||
EXPECT_EQ(output.node_size(), 6);
|
||||
const NodeDef* new_div = node_map.GetNode("div");
|
||||
ASSERT_NE(new_div, nullptr);
|
||||
ASSERT_EQ(new_div->input_size(), 3);
|
||||
EXPECT_EQ(new_div->input(0), "check1");
|
||||
EXPECT_EQ(new_div->input(1), "check2");
|
||||
EXPECT_EQ(new_div->input(2), "^assert1");
|
||||
|
||||
auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", bool_t}});
|
||||
EXPECT_EQ(tensors.size(), 1);
|
||||
test::ExpectTensorNear<double>(tensors[0], tensors_expected[0], 1e-6);
|
||||
}
|
||||
|
||||
TEST_F(ArithmeticOptimizerTest, OpDedupCommutative) {
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output c1 = ops::Const(s.WithOpName("c1"), {1.0f, 2.0f}, {1, 2});
|
||||
Output c2 = ops::Const(s.WithOpName("c2"), {3.0f, 4.0f}, {1, 2});
|
||||
Output mul1 = ops::Mul(s.WithOpName("mul1"), c1, c2);
|
||||
Output mul2 = ops::Mul(s.WithOpName("mul2"), c2, c1);
|
||||
Output div1 = ops::Div(s.WithOpName("div1"), mul1, mul2);
|
||||
GrapplerItem item;
|
||||
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||
item.fetch = {"div1"};
|
||||
auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
|
||||
ASSERT_EQ(tensors_expected.size(), 1);
|
||||
|
||||
ArithmeticOptimizer optimizer;
|
||||
GraphDef output;
|
||||
OptimizeTwice(&optimizer, &item, &output);
|
||||
NodeMap node_map(&output);
|
||||
|
||||
EXPECT_EQ(output.node_size(), 4);
|
||||
const NodeDef* new_c1 = node_map.GetNode("c1");
|
||||
ASSERT_NE(new_c1, nullptr);
|
||||
const NodeDef* new_c2 = node_map.GetNode("c2");
|
||||
ASSERT_NE(new_c2, nullptr);
|
||||
const NodeDef* new_mul1 = node_map.GetNode("mul1");
|
||||
ASSERT_NE(new_mul1, nullptr);
|
||||
ASSERT_EQ(new_mul1->input_size(), 2);
|
||||
EXPECT_EQ(new_mul1->input(0), "c1");
|
||||
EXPECT_EQ(new_mul1->input(1), "c2");
|
||||
const NodeDef* new_div1 = node_map.GetNode("div1");
|
||||
ASSERT_NE(new_div1, nullptr);
|
||||
ASSERT_EQ(new_div1->input_size(), 2);
|
||||
EXPECT_EQ(new_div1->input(0), "mul1");
|
||||
EXPECT_EQ(new_div1->input(1), "mul1");
|
||||
|
||||
auto tensors = EvaluateNodes(output, item.fetch);
|
||||
ASSERT_EQ(tensors.size(), 1);
|
||||
test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
|
||||
}
|
||||
|
||||
TEST_F(ArithmeticOptimizerTest, ReplaceMulWithSquare) {
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2});
|
||||
@ -474,6 +365,7 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) {
|
||||
Output id = ops::Identity(s.WithOpName("id"), add6);
|
||||
|
||||
GrapplerItem item;
|
||||
item.fetch = {"id"};
|
||||
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||
|
||||
const std::vector<string> devices{
|
||||
@ -488,16 +380,16 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) {
|
||||
DisableAddToAddNCombining(&optimizer);
|
||||
|
||||
GraphDef output;
|
||||
OptimizeTwice(&optimizer, &item, &output);
|
||||
DedupAndOptimizeTwiceAndPrune(&optimizer, &item, &output);
|
||||
|
||||
// We expect the following rewrite(s) to occur:
|
||||
//
|
||||
// Mul(p,
|
||||
// Add_6(Add_4(Const(2), Const(2)),
|
||||
// Add_5(Const(2), Const(2))))
|
||||
// Add_5(Const(2), Const(2)))
|
||||
NodeMap node_map(&output);
|
||||
|
||||
EXPECT_EQ(output.node_size(), 17);
|
||||
EXPECT_EQ(output.node_size(), 8);
|
||||
|
||||
const NodeDef* id_node = node_map.GetNode("id");
|
||||
ASSERT_NE(id_node, nullptr);
|
||||
@ -507,8 +399,8 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) {
|
||||
const NodeDef* mul_node = node_map.GetNode(HoistMulName("Add_6"));
|
||||
ASSERT_NE(mul_node, nullptr);
|
||||
ASSERT_EQ(mul_node->input_size(), 2);
|
||||
EXPECT_EQ(mul_node->input(0), HoistAddName("Add_6"));
|
||||
EXPECT_EQ(mul_node->input(1), "Placeholder");
|
||||
EXPECT_EQ(mul_node->input(0), "Placeholder");
|
||||
EXPECT_EQ(mul_node->input(1), HoistAddName("Add_6"));
|
||||
|
||||
const NodeDef* add_6_node = node_map.GetNode(HoistAddName("Add_6"));
|
||||
ASSERT_NE(add_6_node, nullptr);
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_ARITHMETIC_OPTIMIZER_TEST_UTILS_H_
|
||||
|
||||
#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
|
||||
#include "tensorflow/core/grappler/optimizers/common_subgraph_elimination.h"
|
||||
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
|
||||
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
|
||||
#include "tensorflow/core/grappler/utils/grappler_test.h"
|
||||
@ -27,9 +28,9 @@ namespace grappler {
|
||||
|
||||
class ArithmeticOptimizerTest : public GrapplerTest {
|
||||
protected:
|
||||
// Optimize a graph using ArithmeticOptimizer and prune all the nodes that no
|
||||
// Optimize a graph using optimizer and prune all the nodes that no
|
||||
// longer have any output consumers.
|
||||
void OptimizeAndPrune(ArithmeticOptimizer* optimizer, GrapplerItem* item,
|
||||
void OptimizeAndPrune(GraphOptimizer* optimizer, GrapplerItem* item,
|
||||
GraphDef* output) {
|
||||
TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
|
||||
item->graph.Swap(output);
|
||||
@ -37,8 +38,23 @@ class ArithmeticOptimizerTest : public GrapplerTest {
|
||||
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, *item, output));
|
||||
}
|
||||
|
||||
// Run ArithmeticOptimizer twice to make sure the rewrite is idempotent.
|
||||
void OptimizeTwice(ArithmeticOptimizer* optimizer, GrapplerItem* item,
|
||||
// Run optimizer twice to make sure the rewrite is idempotent.
|
||||
void DedupAndOptimizeTwiceAndPrune(GraphOptimizer* optimizer,
|
||||
GrapplerItem* item, GraphDef* output) {
|
||||
TF_EXPECT_OK(CommonSubgraphElimination().Optimize(nullptr, *item, output));
|
||||
item->graph.Swap(output);
|
||||
output->Clear();
|
||||
TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
|
||||
item->graph.Swap(output);
|
||||
output->Clear();
|
||||
TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
|
||||
item->graph.Swap(output);
|
||||
output->Clear();
|
||||
TF_EXPECT_OK(ModelPruner().Optimize(nullptr, *item, output));
|
||||
}
|
||||
|
||||
// Run optimizer twice to make sure the rewrite is idempotent.
|
||||
void OptimizeTwice(GraphOptimizer* optimizer, GrapplerItem* item,
|
||||
GraphDef* output) {
|
||||
TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
|
||||
item->graph.Swap(output);
|
||||
@ -46,9 +62,9 @@ class ArithmeticOptimizerTest : public GrapplerTest {
|
||||
TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
|
||||
}
|
||||
|
||||
// Run ArithmeticOptimizer twice to make sure the rewrite is idempotent.
|
||||
// Run optimizer twice to make sure the rewrite is idempotent.
|
||||
// Optionally run a constant folding pass before pruning.
|
||||
void OptimizeTwiceAndPrune(ArithmeticOptimizer* optimizer, GrapplerItem* item,
|
||||
void OptimizeTwiceAndPrune(GraphOptimizer* optimizer, GrapplerItem* item,
|
||||
GraphDef* output, bool const_folding = false) {
|
||||
TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
|
||||
|
||||
|
@ -0,0 +1,291 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/optimizers/common_subgraph_elimination.h"
|
||||
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "tensorflow/core/framework/attr_value_util.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/graph/tensor_id.h"
|
||||
#include "tensorflow/core/grappler/graph_topology_view.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/op_types.h"
|
||||
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/grappler/utils/canonicalizer.h"
|
||||
#include "tensorflow/core/grappler/utils/topological_sort.h"
|
||||
#include "tensorflow/core/grappler/utils/traversal.h"
|
||||
#include "tensorflow/core/lib/gtl/flatset.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/hash.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/strcat.h"
|
||||
#include "tensorflow/core/platform/stringpiece.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
class Cluster;
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
||||
|
||||
using tensorflow::strings::StrCat;
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
class UniqueNodes {
|
||||
public:
|
||||
NodeDef* FindOrAddRepresentative(NodeDef* node) {
|
||||
uint64 sig = ComputeSignature(*node);
|
||||
std::vector<NodeDef*>& candidates = rep_[sig];
|
||||
for (auto& candidate : candidates) {
|
||||
if ((candidate == node) || SameNode(*candidate, *node)) {
|
||||
return candidate;
|
||||
}
|
||||
}
|
||||
candidates.push_back(node);
|
||||
return node;
|
||||
}
|
||||
|
||||
void RemoveRepresentative(NodeDef* node) {
|
||||
auto it = memoized_signatures_.find(node);
|
||||
if (it == memoized_signatures_.end()) return;
|
||||
|
||||
std::vector<NodeDef*>& candidates = rep_[it->second];
|
||||
for (int i = 0; i < candidates.size(); ++i) {
|
||||
if (candidates[i] == node) {
|
||||
std::swap(candidates[i], candidates[candidates.size() - 1]);
|
||||
candidates.resize(candidates.size() - 1);
|
||||
break;
|
||||
}
|
||||
}
|
||||
memoized_signatures_.erase(node);
|
||||
}
|
||||
|
||||
private:
|
||||
uint64 ComputeSignature(const NodeDef& node);
|
||||
bool SameNode(const NodeDef& node1, const NodeDef& node2) const;
|
||||
|
||||
absl::flat_hash_map<uint64, std::vector<NodeDef*>> rep_;
|
||||
absl::flat_hash_map<const NodeDef*, uint64> memoized_signatures_;
|
||||
};
|
||||
|
||||
uint64 UniqueNodes::ComputeSignature(const NodeDef& node) {
|
||||
auto it = memoized_signatures_.find(&node);
|
||||
if (it != memoized_signatures_.end()) return it->second;
|
||||
|
||||
uint64 h = Hash64(node.op());
|
||||
h = Hash64Combine(Hash64(node.device()), h);
|
||||
|
||||
for (const auto& input : node.input()) {
|
||||
const TensorId input_tensor = ParseTensorName(input);
|
||||
uint64 input_hash = Hash64Combine(
|
||||
Hash64(input_tensor.node().data(), input_tensor.node().size()),
|
||||
std::hash<int>()(input_tensor.index()));
|
||||
h = Hash64CombineUnordered(input_hash, h);
|
||||
}
|
||||
for (const auto& attr : node.attr()) {
|
||||
uint64 attr_hash =
|
||||
Hash64Combine(Hash64(attr.first), FastAttrValueHash(attr.second));
|
||||
h = Hash64CombineUnordered(attr_hash, h);
|
||||
}
|
||||
memoized_signatures_.emplace(&node, h);
|
||||
return h;
|
||||
}
|
||||
|
||||
// PRECONDITION:
|
||||
// Node input orders are assumed to be canonicalized, i.e. control inputs for
|
||||
// all nodes as well as regular inputs for commutative nodes must be sorted.
|
||||
bool UniqueNodes::SameNode(const NodeDef& node1, const NodeDef& node2) const {
|
||||
if (node1.op() != node2.op()) {
|
||||
return false;
|
||||
}
|
||||
if (node1.device() != node2.device()) {
|
||||
return false;
|
||||
}
|
||||
if (node1.input_size() != node2.input_size()) {
|
||||
return false;
|
||||
}
|
||||
if (node1.attr_size() != node2.attr_size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Compare inputs.
|
||||
auto it1 = node1.input().begin();
|
||||
auto it2 = node2.input().begin();
|
||||
for (; it1 != node1.input().end(); ++it1, ++it2) {
|
||||
if (*it1 != *it2) return false;
|
||||
}
|
||||
|
||||
// Compare attributes.
|
||||
for (const auto& attr1 : node1.attr()) {
|
||||
auto it = node2.attr().find(attr1.first);
|
||||
if (it == node2.attr().end()) return false;
|
||||
if (!FastAreAttrValuesEqual(attr1.second, it->second)) return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool CommonSubgraphElimination::CanDedup(const NodeDef& node) const {
|
||||
if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) {
|
||||
return false;
|
||||
}
|
||||
if (IsEnter(node) || IsExit(node)) {
|
||||
return false;
|
||||
}
|
||||
if (node.device().find("SPU") != string::npos) {
|
||||
return false;
|
||||
}
|
||||
// Workaround for Assert and Print mistakenly being labeled as stateful.
|
||||
if (IsAssert(node) || IsPrint(node)) {
|
||||
return true;
|
||||
}
|
||||
return IsFreeOfSideEffect(node);
|
||||
}
|
||||
|
||||
// Repeatedly merges nodes with identical signatures (op, inputs, attributes)
// into a single representative until a fixed point is reached, then erases
// the duplicates (only when the fetch set is known). Nodes that feed in-place
// ops are never deduped, since merging them could introduce data races.
Status CommonSubgraphElimination::DedupComputations(GraphDef* optimized_graph) {
  CanonicalizeGraph(optimized_graph);

  GraphTopologyView graph_view;
  if (!graph_view.InitializeFromGraph(*optimized_graph).ok()) {
    LOG(WARNING) << "Failed to initialize GraphTopologyView.";
    return Status::OK();
  }

  // If either node or rep feeds an inplace op, deduping them may cause data
  // races. For example: If we dedup nodes initializing two independent
  // inplace accumulations, they will write to the same buffer, clobbering
  // each other's results.
  absl::flat_hash_set<const NodeDef*> feeds_inplace_op;
  for (int i = 0; i < optimized_graph->node_size(); ++i) {
    const NodeDef& root = optimized_graph->node(i);
    if (feeds_inplace_op.find(&root) != feeds_inplace_op.end()) continue;
    if (ModifiesInputsInPlace(root)) {
      // Walk the fanin of the in-place op and mark every node whose output
      // could be forwarded into it.
      const auto is_continue_traversal = [&](const NodeDef* node) -> bool {
        return node->op() == root.op() || !NeverForwardsInputs(*node);
      };

      DfsTraversal(graph_view, {&root}, TraversalDirection::kFollowInputs,
                   DfsPredicates::Advance(is_continue_traversal),
                   DfsCallbacks::PreOrder([&](const NodeDef* node) {
                     feeds_inplace_op.insert(node);
                   }));
    }
  }

  // Precompute which nodes are safe to dedup at all.
  std::vector<bool> can_dedup(optimized_graph->node_size());
  for (int i = 0; i < optimized_graph->node_size(); ++i) {
    const NodeDef& node = optimized_graph->node(i);
    can_dedup[i] = (feeds_inplace_op.find(&node) == feeds_inplace_op.end()) &&
                   CanDedup(node);
  }

  bool stop = true;
  std::set<int> duplicates;
  UniqueNodes nodes;
  NodeMap node_map(optimized_graph);
  do {
    stop = true;
    for (int i = 0; i < optimized_graph->node_size(); ++i) {
      if (!can_dedup[i] || duplicates.find(i) != duplicates.end()) {
        continue;
      }
      NodeDef* node = optimized_graph->mutable_node(i);
      NodeDef* rep = nodes.FindOrAddRepresentative(node);
      if (rep == node) {
        continue;
      }
      // Take a snapshot of the fanout set: node_map is mutated below.
      const std::set<NodeDef*>& tmp = node_map.GetOutputs(node->name());
      std::vector<NodeDef*> fanouts(tmp.begin(), tmp.end());
      for (NodeDef* fanout : fanouts) {
        // Update consumers of node. Use a distinct index `j` here: the
        // enclosing loop already uses `i`, and shadowing it is error-prone
        // (`duplicates.insert(i)` below must refer to the node index).
        bool updated_fanout = false;
        for (int j = 0; j < fanout->input_size(); ++j) {
          string* fanout_input = fanout->mutable_input(j);

          const int position =
              NodePositionIfSameNode(*fanout_input, node->name());
          // position < -1 means this input does not refer to `node` at all.
          if (position < -1) continue;
          if (!updated_fanout) {
            // The signature of the fanout node will change. Remove it from
            // nodes.
            nodes.RemoveRepresentative(fanout);
          }
          updated_fanout = true;
          // Update name in-place, preserving output-port / control-dep form.
          if (position > 0) {
            *fanout_input = StrCat(rep->name(), ":", position);
          } else if (position == 0) {
            *fanout_input = rep->name();
          } else {
            *fanout_input = StrCat("^", rep->name());
          }
        }
        if (updated_fanout) {
          node_map.UpdateInput(fanout->name(), node->name(), rep->name());
          CanonicalizeNode(fanout);
        }
      }
      duplicates.insert(i);
      stop = false;
    }
  } while (!stop);

  // Delete duplicates
  if (fetch_nodes_known_ && !duplicates.empty()) {
    EraseNodesFromGraph(duplicates, optimized_graph);
  }

  return Status::OK();
}
|
||||
|
||||
Status CommonSubgraphElimination::Optimize(Cluster* /*cluster*/,
                                           const GrapplerItem& item,
                                           GraphDef* optimized_graph) {
  // Record whether the fetch set is known and which nodes must survive.
  fetch_nodes_known_ = !item.fetch.empty();
  nodes_to_preserve_ = item.NodesToPreserve();

  // Operate on a copy of the input graph.
  *optimized_graph = item.graph;

  // Topologically sorting first lets DedupComputations collapse larger
  // subgraphs, starting from roots with more inputs.
  TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph));
  GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();

  return DedupComputations(optimized_graph);
}
|
||||
|
||||
void CommonSubgraphElimination::Feedback(Cluster* /*cluster*/,
                                         const GrapplerItem& /*item*/,
                                         const GraphDef& /*optimized_graph*/,
                                         double /*result*/) {
  // Nothing to do for CommonSubgraphElimination.
}
|
||||
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
@ -0,0 +1,73 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_COMMON_SUBGRAPH_ELIMINATION_H_
|
||||
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_COMMON_SUBGRAPH_ELIMINATION_H_
|
||||
|
||||
#include <unordered_set>
|
||||
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
|
||||
#include "tensorflow/core/lib/gtl/flatset.h"
|
||||
#include "tensorflow/core/platform/hash.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
// Optimize TF computations by deduping equivalent subgraphs.
|
||||
class Cluster;
|
||||
struct GrapplerItem;
|
||||
|
||||
class CommonSubgraphElimination : public GraphOptimizer {
|
||||
public:
|
||||
CommonSubgraphElimination() {}
|
||||
|
||||
explicit CommonSubgraphElimination(RewriterConfig::Toggle opt_level)
|
||||
: opt_level_(opt_level) {}
|
||||
|
||||
~CommonSubgraphElimination() override {}
|
||||
|
||||
string name() const override { return "common_subgraph_elimination"; };
|
||||
|
||||
bool UsesFunctionLibrary() const override { return false; }
|
||||
|
||||
Status Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
GraphDef* optimized_graph) override;
|
||||
|
||||
void Feedback(Cluster* cluster, const GrapplerItem& item,
|
||||
const GraphDef& optimized_graph, double result) override;
|
||||
|
||||
private:
|
||||
friend class CommonSubgraphEliminationTest;
|
||||
|
||||
// Returns true if it is safe to dedup node from the graph.
|
||||
bool CanDedup(const NodeDef& node) const;
|
||||
|
||||
// Dedup redundant nodes in the graph.
|
||||
Status DedupComputations(GraphDef* optimized_graph);
|
||||
|
||||
RewriterConfig::Toggle opt_level_;
|
||||
|
||||
bool fetch_nodes_known_ = false;
|
||||
std::unordered_set<string> nodes_to_preserve_;
|
||||
};
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_COMMON_SUBGRAPH_ELIMINATION_H_
|
@ -0,0 +1,178 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/optimizers/common_subgraph_elimination.h"
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/cc/ops/array_ops.h"
|
||||
#include "tensorflow/cc/ops/math_ops.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
|
||||
#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
#include "tensorflow/core/grappler/utils/grappler_test.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
namespace {
|
||||
|
||||
void VerifyGraphsMatch(const GraphDef& original_graph,
|
||||
const GraphDef& optimized_graph, int line) {
|
||||
EXPECT_EQ(original_graph.node_size(), optimized_graph.node_size()) << line;
|
||||
for (int i = 0; i < original_graph.node_size(); ++i) {
|
||||
const NodeDef& original = original_graph.node(i);
|
||||
const NodeDef& optimized = optimized_graph.node(i);
|
||||
EXPECT_EQ(original.name(), optimized.name()) << line;
|
||||
EXPECT_EQ(original.op(), optimized.op()) << line;
|
||||
EXPECT_EQ(original.input_size(), optimized.input_size()) << line;
|
||||
for (int j = 0; j < original.input_size(); ++j) {
|
||||
EXPECT_EQ(original.input(j), optimized.input(j)) << line;
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
class CommonSubgraphEliminationTest : public ArithmeticOptimizerTest {};
|
||||
|
||||
TEST_F(CommonSubgraphEliminationTest, NoOp) {
  // This trivial graph is so basic there's nothing to optimize.
  TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
  GrapplerItem item;
  CHECK(fake_input.NextItem(&item));

  CommonSubgraphElimination optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);
  // The optimizer must leave the graph untouched, node for node.
  VerifyGraphsMatch(item.graph, output, __LINE__);
}
|
||||
|
||||
TEST_F(CommonSubgraphEliminationTest, OpDedupping) {
  // Two identical Const nodes feed a Div; the optimizer should collapse them
  // into a single constant and rewire the Div accordingly.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output c1 = ops::Const(s.WithOpName("c1"), {3.14, 2.7}, {1, 2});
  Output c2 = ops::Const(s.WithOpName("c2"), {3.14, 2.7}, {1, 2});
  Output div = ops::Div(s.WithOpName("div"), c1, c2);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch = {"div"};

  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  CommonSubgraphElimination optimizer;
  GraphDef output;
  OptimizeTwice(&optimizer, &item, &output);
  NodeMap node_map(&output);
  // c2 has been eliminated: only c1 and div remain.
  EXPECT_EQ(output.node_size(), 2);
  const NodeDef* new_c1 = node_map.GetNode("c1");
  ASSERT_NE(new_c1, nullptr);

  const NodeDef* new_div = node_map.GetNode("div");
  ASSERT_NE(new_div, nullptr);
  ASSERT_EQ(new_div->input_size(), 2);
  // Both Div inputs now point at the surviving constant.
  EXPECT_EQ(new_div->input(0), "c1");
  EXPECT_EQ(new_div->input(1), "c1");

  // Deduplication must not change the computed value.
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<double>(tensors[0], tensors_expected[0], 1e-6);
}
|
||||
|
||||
TEST_F(CommonSubgraphEliminationTest, OpDeduppingAssertAndCheckNumerics) {
  // Assert (and Print) are mistakenly labeled stateful; CanDedup special-cases
  // them, so the duplicate Assert should still be merged away.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output p = ops::Placeholder(s, DT_BOOL, ops::Placeholder::Shape({}));
  Output c = ops::Const(s.WithOpName("c"), {3.14, 2.7}, {1, 2});
  auto check1 = ops::CheckNumerics(s.WithOpName("check1"), c, "foo");
  auto check2 = ops::CheckNumerics(s.WithOpName("check2"), c, "foo");
  auto assert1 = ops::Assert(s.WithOpName("assert1"), p, {c});
  auto assert2 = ops::Assert(s.WithOpName("assert2"), p, {c});
  Output div = ops::Div(s.WithOpName("div").WithControlDependencies(
                            {assert1.operation, assert2.operation}),
                        check1, check2);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch = {"div"};
  Tensor bool_t(DT_BOOL, TensorShape({}));
  bool_t.scalar<bool>().setConstant(true);
  auto tensors_expected =
      EvaluateNodes(item.graph, item.fetch, {{"Placeholder", bool_t}});
  ASSERT_EQ(tensors_expected.size(), 1);

  CommonSubgraphElimination optimizer;
  GraphDef output;

  OptimizeTwice(&optimizer, &item, &output);
  NodeMap node_map(&output);

  // Of the seven original nodes, only the duplicate Assert is removed here;
  // both CheckNumerics nodes survive.
  EXPECT_EQ(output.node_size(), 6);
  const NodeDef* new_div = node_map.GetNode("div");
  ASSERT_NE(new_div, nullptr);
  ASSERT_EQ(new_div->input_size(), 3);
  EXPECT_EQ(new_div->input(0), "check1");
  EXPECT_EQ(new_div->input(1), "check2");
  // The two control dependencies collapsed into a single one on assert1.
  EXPECT_EQ(new_div->input(2), "^assert1");

  // The optimized graph must still evaluate to the same value.
  auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", bool_t}});
  EXPECT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<double>(tensors[0], tensors_expected[0], 1e-6);
}
|
||||
|
||||
TEST_F(CommonSubgraphEliminationTest, OpDedupCommutative) {
  // mul1 = c1 * c2 and mul2 = c2 * c1 differ only in input order; since Mul
  // is commutative, the optimizer should recognize them as duplicates.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
  Output c1 = ops::Const(s.WithOpName("c1"), {1.0f, 2.0f}, {1, 2});
  Output c2 = ops::Const(s.WithOpName("c2"), {3.0f, 4.0f}, {1, 2});
  Output mul1 = ops::Mul(s.WithOpName("mul1"), c1, c2);
  Output mul2 = ops::Mul(s.WithOpName("mul2"), c2, c1);
  Output div1 = ops::Div(s.WithOpName("div1"), mul1, mul2);
  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch = {"div1"};
  auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
  ASSERT_EQ(tensors_expected.size(), 1);

  CommonSubgraphElimination optimizer;
  GraphDef output;
  OptimizeTwice(&optimizer, &item, &output);
  NodeMap node_map(&output);

  // mul2 has been folded into mul1: c1, c2, mul1, div1 remain.
  EXPECT_EQ(output.node_size(), 4);
  const NodeDef* new_c1 = node_map.GetNode("c1");
  ASSERT_NE(new_c1, nullptr);
  const NodeDef* new_c2 = node_map.GetNode("c2");
  ASSERT_NE(new_c2, nullptr);
  const NodeDef* new_mul1 = node_map.GetNode("mul1");
  ASSERT_NE(new_mul1, nullptr);
  ASSERT_EQ(new_mul1->input_size(), 2);
  EXPECT_EQ(new_mul1->input(0), "c1");
  EXPECT_EQ(new_mul1->input(1), "c2");
  const NodeDef* new_div1 = node_map.GetNode("div1");
  ASSERT_NE(new_div1, nullptr);
  ASSERT_EQ(new_div1->input_size(), 2);
  // Both Div inputs now reference the surviving Mul.
  EXPECT_EQ(new_div1->input(0), "mul1");
  EXPECT_EQ(new_div1->input(1), "mul1");

  // The computed value must be unchanged by deduplication.
  auto tensors = EvaluateNodes(output, item.fetch);
  ASSERT_EQ(tensors.size(), 1);
  test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6);
}
|
||||
|
||||
} // namespace grappler
|
||||
} // namespace tensorflow
|
@ -383,7 +383,6 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
|
||||
op != "TensorArraySizeV3") {
|
||||
continue;
|
||||
}
|
||||
|
||||
const std::vector<OpInfo::TensorProperties>& output =
|
||||
properties.GetOutputProperties(node->name());
|
||||
const std::vector<OpInfo::TensorProperties>& input =
|
||||
@ -410,8 +409,16 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// TODO(rmlarsen): Remove this workaround for b/150861569
|
||||
// The bug involves an expression of the form Shape(ExpandDims(x)
|
||||
// with an incorrectly inferred zero-size first dimension.
|
||||
if (op == "Shape") {
|
||||
if (shape.dims() > 0 && shape.dim_size(0) == 0) continue;
|
||||
}
|
||||
|
||||
// Repurpose the existing node to be the constant.
|
||||
// Device placement is preserved.
|
||||
graph_modified_ = true;
|
||||
node->set_op("Const");
|
||||
node->clear_attr();
|
||||
(*node->mutable_attr())["dtype"].set_type(type);
|
||||
@ -424,9 +431,8 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
|
||||
// the original graph.
|
||||
string ctrl_dep =
|
||||
AddControlDependency(node->input(0), graph_, node_map_.get());
|
||||
node_map_->UpdateInput(node->name(), node->input(0), ctrl_dep);
|
||||
node->set_input(0, ctrl_dep);
|
||||
node_map_->AddOutput(NodeName(ctrl_dep), node->name());
|
||||
|
||||
// Done with the Shape/Size/Rank node, move to the next node.
|
||||
continue;
|
||||
}
|
||||
@ -458,6 +464,7 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
|
||||
continue;
|
||||
}
|
||||
|
||||
graph_modified_ = true;
|
||||
node->set_op("Const");
|
||||
*node->mutable_attr() = array_size->attr();
|
||||
node->set_input(0, AsControlDependency(NodeName(node->input(0))));
|
||||
@ -519,6 +526,7 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
|
||||
}
|
||||
*output->mutable_input(k) = const_name;
|
||||
node_map_->AddOutput(const_name, output->name());
|
||||
graph_modified_ = true;
|
||||
}
|
||||
if (node_name == shape_n_node->name() && port != port_idx) {
|
||||
direct_edges_exist = true;
|
||||
@ -3705,8 +3713,9 @@ Status ConstantFolding::AddQuantizedMatMulMinMaxOutConstNodes(
|
||||
}
|
||||
|
||||
Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
|
||||
const GrapplerItem& item,
|
||||
GrapplerItem* item,
|
||||
GraphDef* optimized_graph) {
|
||||
graph_ = &item->graph;
|
||||
node_map_.reset(new NodeMap(graph_));
|
||||
nodes_whitelist_.clear();
|
||||
// Fold fetch nodes iff it has a single fanout. Note that if a fetch node
|
||||
@ -3716,14 +3725,14 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
|
||||
// replace the node with multiple constants (each for one fanout) with
|
||||
// new names, and as a result users would not be able to fetch the node any
|
||||
// more with the original node name.
|
||||
for (const auto& fetch : item.fetch) {
|
||||
for (const auto& fetch : item->fetch) {
|
||||
const NodeDef* fetch_node = node_map_->GetNode(fetch);
|
||||
if (fetch_node && NumOutputs(*fetch_node, graph_) == 1) {
|
||||
nodes_whitelist_.insert(fetch_node->name());
|
||||
}
|
||||
}
|
||||
|
||||
GraphProperties properties(item);
|
||||
GraphProperties properties(*item);
|
||||
// It's possible to feed a placeholder with a tensor of any shape: make sure
|
||||
// that the shape inference deals with this conservatively unless we're in
|
||||
// aggressive mode.
|
||||
@ -3732,15 +3741,18 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
|
||||
/*aggressive_shape_inference=*/false,
|
||||
/*include_input_tensor_values=*/false,
|
||||
/*include_output_tensor_values=*/true);
|
||||
|
||||
const bool can_use_shape_info = s.ok();
|
||||
|
||||
absl::flat_hash_set<string> nodes_to_not_simplify;
|
||||
if (can_use_shape_info) {
|
||||
TF_RETURN_IF_ERROR(MaterializeShapes(properties));
|
||||
TF_RETURN_IF_ERROR(MaterializeConstants(properties));
|
||||
TF_RETURN_IF_ERROR(
|
||||
FoldGraph(properties, optimized_graph, &nodes_to_not_simplify));
|
||||
} else {
|
||||
*optimized_graph = *graph_;
|
||||
}
|
||||
absl::flat_hash_set<string> nodes_to_not_simplify;
|
||||
TF_RETURN_IF_ERROR(
|
||||
FoldGraph(properties, optimized_graph, &nodes_to_not_simplify));
|
||||
node_map_.reset(new NodeMap(optimized_graph));
|
||||
TF_RETURN_IF_ERROR(SimplifyGraph(can_use_shape_info, optimized_graph,
|
||||
&properties, &nodes_to_not_simplify));
|
||||
@ -3795,11 +3807,10 @@ Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
|
||||
graph_modified_ = false;
|
||||
item_to_optimize.graph.Swap(optimized_graph);
|
||||
graph_ = &item_to_optimize.graph;
|
||||
*optimized_graph = GraphDef();
|
||||
node_count = graph_->node_size();
|
||||
optimized_graph->Clear();
|
||||
node_count = item_to_optimize.graph.node_size();
|
||||
TF_RETURN_IF_ERROR(
|
||||
RunOptimizationPass(cluster, item_to_optimize, optimized_graph));
|
||||
RunOptimizationPass(cluster, &item_to_optimize, optimized_graph));
|
||||
} while (graph_modified_ || optimized_graph->node_size() != node_count);
|
||||
*optimized_graph->mutable_library() = item.graph.library();
|
||||
*optimized_graph->mutable_versions() = item.graph.versions();
|
||||
|
@ -131,8 +131,8 @@ class ConstantFolding : public GraphOptimizer {
|
||||
Status SimplifyNode(bool use_shape_info, NodeDef* node,
|
||||
GraphDef* optimized_graph, GraphProperties* properties);
|
||||
|
||||
Status RunOptimizationPass(Cluster* cluster, const GrapplerItem& item,
|
||||
GraphDef* output);
|
||||
Status RunOptimizationPass(Cluster* cluster, GrapplerItem* item,
|
||||
GraphDef* optimized_graph);
|
||||
|
||||
// Applies partial constant folding for Concat which is not commutative.
|
||||
// Returns true if the transformation applied successfully.
|
||||
|
@ -17,9 +17,12 @@ limitations under the License.
|
||||
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_GRAPH_OPTIMIZER_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
@ -28,6 +28,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
|
||||
#include "tensorflow/core/grappler/optimizers/auto_mixed_precision.h"
|
||||
#include "tensorflow/core/grappler/optimizers/auto_parallel.h"
|
||||
#include "tensorflow/core/grappler/optimizers/common_subgraph_elimination.h"
|
||||
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
|
||||
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
|
||||
#include "tensorflow/core/grappler/optimizers/debug_stripper.h"
|
||||
@ -184,6 +185,8 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer(
|
||||
MK_OPT("auto_mixed_precision",
|
||||
new AutoMixedPrecision(cfg_.auto_mixed_precision()));
|
||||
MK_OPT("memory", new MemoryOptimizer(RewriterConfig::MANUAL));
|
||||
MK_OPT("common_subgraph_elimination",
|
||||
new CommonSubgraphElimination(cfg_.common_subgraph_elimination()));
|
||||
MK_OPT("arithmetic", new ArithmeticOptimizer(cfg_.arithmetic_optimization()));
|
||||
MK_OPT("autoparallel", new AutoParallel(cfg_.auto_parallel().num_replicas()));
|
||||
MK_OPT("loop", new LoopOptimizer(cfg_.loop_optimization(), cpu_device_));
|
||||
@ -224,6 +227,11 @@ Status MetaOptimizer::InitializeOptimizers(
|
||||
cfg_.function_optimization(),
|
||||
/*lower_control_flow=*/!IsSingleThreadedExecutor()));
|
||||
}
|
||||
if (cfg_.common_subgraph_elimination() != RewriterConfig::OFF &&
|
||||
cfg_.arithmetic_optimization() != RewriterConfig::OFF) {
|
||||
optimizers->push_back(MakeUnique<CommonSubgraphElimination>(
|
||||
cfg_.common_subgraph_elimination()));
|
||||
}
|
||||
if (cfg_.debug_stripper() == RewriterConfig::ON) {
|
||||
optimizers->push_back(MakeUnique<DebugStripper>());
|
||||
}
|
||||
@ -812,6 +820,7 @@ bool MetaOptimizerEnabled(const ConfigProto& cfg) {
|
||||
rewrite_cfg.constant_folding() != RewriterConfig::OFF ||
|
||||
rewrite_cfg.shape_optimization() != RewriterConfig::OFF ||
|
||||
rewrite_cfg.remapping() != RewriterConfig::OFF ||
|
||||
rewrite_cfg.common_subgraph_elimination() != RewriterConfig::OFF ||
|
||||
rewrite_cfg.arithmetic_optimization() != RewriterConfig::OFF ||
|
||||
rewrite_cfg.loop_optimization() != RewriterConfig::OFF ||
|
||||
rewrite_cfg.dependency_optimization() != RewriterConfig::OFF ||
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/node_hash_map.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/core/framework/graph.pb.h"
|
||||
@ -65,8 +66,8 @@ class NodeMap {
|
||||
|
||||
private:
|
||||
const std::set<NodeDef*> empty_set_;
|
||||
gtl::FlatMap<string, NodeDef*> nodes_;
|
||||
gtl::FlatMap<string, std::set<NodeDef*>> outputs_;
|
||||
absl::node_hash_map<string, NodeDef*> nodes_;
|
||||
absl::node_hash_map<string, std::set<NodeDef*>> outputs_;
|
||||
};
|
||||
|
||||
// A vector with a set. The set stores the same elements as the vector, and
|
||||
|
@ -60,6 +60,9 @@ message RewriterConfig {
|
||||
// Remapping (default is ON)
|
||||
// Remap subgraphs onto more efficient implementations.
|
||||
Toggle remapping = 14;
|
||||
// Common subgraph elimination (default is ON)
|
||||
// e.g. Simplify arithmetic ops; merge ops with same value (like constants).
|
||||
Toggle common_subgraph_elimination = 24;
|
||||
// Arithmetic optimizations (default is ON)
|
||||
// e.g. Simplify arithmetic ops; merge ops with same value (like constants).
|
||||
Toggle arithmetic_optimization = 7;
|
||||
|
Loading…
Reference in New Issue
Block a user