diff --git a/tensorflow/core/grappler/BUILD b/tensorflow/core/grappler/BUILD index f8ab8748285..618cc9a990f 100644 --- a/tensorflow/core/grappler/BUILD +++ b/tensorflow/core/grappler/BUILD @@ -40,6 +40,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/container:node_hash_map", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:span", ], diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index f6384c35360..8b8df527041 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -695,7 +695,7 @@ bool IsFreeOfSideEffect(const NodeDef& node) { bool ModifiesInputsInPlace(const NodeDef& node) { // Some nodes do in-place updates on regular tensor inputs. - string op_name = node.op(); + const string& op_name = node.op(); // Ops that modify resource variables effectively modify one of their inputs. if (op_name == "AssignVariableOp" || op_name == "AssignAddVariableOp" || @@ -706,8 +706,10 @@ bool ModifiesInputsInPlace(const NodeDef& node) { return false; } - std::transform(op_name.begin(), op_name.end(), op_name.begin(), ::tolower); - if (absl::StrContains(op_name, "inplace")) { + string lower_op_name = op_name; + std::transform(lower_op_name.begin(), lower_op_name.end(), + lower_op_name.begin(), ::tolower); + if (absl::StrContains(lower_op_name, "inplace")) { return true; } return GetBoolAttr(node, "in_place") || GetBoolAttr(node, "inplace"); @@ -862,20 +864,25 @@ bool NeverForwardsInputs(const NodeDef& node) { (new gtl::FlatSet<string>{"ArgMax", "ArgMin", "AudioSpectrogram", + "AvgPool", "BatchMatMul", "BatchMatMulV2", + "BatchNormWithGlobalNormalization", "BatchToSpace", "BatchToSpaceND", "Bincount", "BroadcastArgs", "BroadcastGradientArgs", + "Bucketize", "CTCBeamSearchDecoder", "CTCGreedyDecoder", "CTCLoss", + "CompareAndBitpack", "ComplexAbs", "Concat", "ConcatOffset", "ConcatV2", + "Conv2D", "Copy", "CopyHost", "Cross", @@ -890,8 +897,8 @@ bool NeverForwardsInputs(const NodeDef& node) { "CudnnRNNParamsToCanonicalV2", "CudnnRNNV2", "CudnnRNNV3", - "CumSum", "CumProd", + "CumSum", "DebugNanCount", "DebugNumericSummary", "DecodeProtoV2", @@ -920,15 +927,25 @@ bool NeverForwardsInputs(const NodeDef& node) { "LowerBound", "MatMul", "MatrixDiag", - "MatrixDiagV2", "MatrixDiagPart", "MatrixDiagPartV2", + "MatrixDiagV2", "Mfcc", + "Multinomial", "OneHot", "Pack", + "ParameterizedTruncatedNormal", "PopulationCount", + "RandomGamma", + "RandomPoisson", + "RandomPoissonV2", + "RandomStandardNormal", + "RandomUniform", + "RandomUniformInt", "Range", "Rank", + "RequantizationRange", + "Requantize", "ReverseSequence", "Shape", "ShapeN", @@ -939,6 +956,7 @@ bool NeverForwardsInputs(const NodeDef& node) { "SparseMatMul", "Split", "SplitV", + "TruncatedNormal", "Unique", "UniqueV2", "UniqueWithCounts", @@ -946,23 +964,7 @@ bool NeverForwardsInputs(const NodeDef& node) { "Unpack", "UnravelIndex", "UpperBound", - "Where", - "CompareAndBitpack", - "Requantize", - "RequantizationRange", - "Bucketize", - "AvgPool", - "BatchNormWithGlobalNormalization", - "Conv2D", - "RandomUniform", - "RandomUniformInt", - "RandomStandardNormal", - "ParameterizedTruncatedNormal", - "TruncatedNormal", - "Multinomial", - "RandomGamma", - "RandomPoisson", - "RandomPoissonV2"})); + "Where"})); const string& op_name = node.op(); return kNonForwardingOps->count(op_name) > 0 || absl::StrContains(op_name, "Segment") || diff --git 
a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index 9be609b3970..56b7754355c 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -302,6 +302,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":arithmetic_optimizer", + ":common_subgraph_elimination", ":constant_folding", ":model_pruner", "//tensorflow/core:test", @@ -334,6 +335,59 @@ tf_cuda_cc_test( ], ) +cc_library( + name = "common_subgraph_elimination", + srcs = ["common_subgraph_elimination.cc"], + hdrs = [ + "common_subgraph_elimination.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":graph_optimizer", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:graph_topology_view", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:op_types", + "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/utils:canonicalizer", + "//tensorflow/core/grappler/utils:topological_sort", + "//tensorflow/core/grappler/utils:traversal", + "//tensorflow/core/platform:hash", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + ], +) + +tf_cuda_cc_test( + name = "common_subgraph_elimination_test", + size = "small", + srcs = ["common_subgraph_elimination_test.cc"], + deps = [ + ":arithmetic_optimizer_test_utils", + ":common_subgraph_elimination", + ":model_pruner", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/core:all_kernels", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder", + "//tensorflow/core/grappler/utils:grappler_test", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "dependency_optimizer", srcs = ["dependency_optimizer.cc"], @@ -605,6 +659,7 @@ cc_library( ":arithmetic_optimizer", ":auto_mixed_precision", ":auto_parallel", + ":common_subgraph_elimination", ":constant_folding", ":custom_graph_optimizer_registry", ":debug_stripper", diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 3281f97457f..0b9701ca0c3 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -3462,215 +3462,6 @@ class RemoveStackSliceSameAxis : public ArithmeticOptimizerStage { } // namespace -class UniqueNodes { - public: - NodeDef* FindOrAddRepresentative(NodeDef* node) { - uint64 sig = ComputeSignature(*node); - std::vector<NodeDef*>& candidates = rep_[sig]; - for (auto& candidate : candidates) { - if ((candidate == node) || SameNode(*candidate, *node)) { - return candidate; - } - } - candidates.push_back(node); - return node; - } - - void RemoveRepresentative(NodeDef* node) { - auto it = memoized_signatures_.find(node); - if (it == memoized_signatures_.end()) return; - - std::vector<NodeDef*>& candidates = rep_[it->second]; - for (int i = 0; i < candidates.size(); ++i) { - if (candidates[i] == node) { - std::swap(candidates[i], candidates[candidates.size() - 1]); - 
candidates.resize(candidates.size() - 1); - break; - } - } - memoized_signatures_.erase(node); - } - - private: - uint64 ComputeSignature(const NodeDef& node); - bool SameNode(const NodeDef& node1, const NodeDef& node2) const; - - absl::flat_hash_map<uint64, std::vector<NodeDef*>> rep_; - absl::flat_hash_map<const NodeDef*, uint64> memoized_signatures_; -}; - -uint64 UniqueNodes::ComputeSignature(const NodeDef& node) { - auto it = memoized_signatures_.find(&node); - if (it != memoized_signatures_.end()) return it->second; - - uint64 h = Hash64(node.op()); - h = Hash64Combine(Hash64(node.device()), h); - - for (const auto& input : node.input()) { - const TensorId input_tensor = ParseTensorName(input); - uint64 input_hash = Hash64Combine( - Hash64(input_tensor.node().data(), input_tensor.node().size()), - std::hash<int>()(input_tensor.index())); - h = Hash64CombineUnordered(input_hash, h); - } - for (const auto& attr : node.attr()) { - uint64 attr_hash = - Hash64Combine(Hash64(attr.first), FastAttrValueHash(attr.second)); - h = Hash64CombineUnordered(attr_hash, h); - } - memoized_signatures_.emplace(&node, h); - return h; -} - -// PRECONDITION: -// Node input orders are assumed to be canonicalized, i.e. control inputs for -// all nodes as well as regular inputs for commutative nodes must be sorted. -bool UniqueNodes::SameNode(const NodeDef& node1, const NodeDef& node2) const { - if (node1.op() != node2.op()) { - return false; - } - if (node1.device() != node2.device()) { - return false; - } - if (node1.input_size() != node2.input_size()) { - return false; - } - if (node1.attr_size() != node2.attr_size()) { - return false; - } - - // Compare inputs. - auto it1 = node1.input().begin(); - auto it2 = node2.input().begin(); - for (; it1 != node1.input().end(); ++it1, ++it2) { - if (*it1 != *it2) return false; - } - - // Compare attributes. 
- for (const auto& attr1 : node1.attr()) { - auto it = node2.attr().find(attr1.first); - if (it == node2.attr().end()) return false; - if (!FastAreAttrValuesEqual(attr1.second, it->second)) return false; - } - - return true; -} - -bool ArithmeticOptimizer::CanDedup(const NodeDef& node) const { - if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) { - return false; - } - if (IsEnter(node) || IsExit(node)) { - return false; - } - if (node.device().find("SPU") != string::npos) { - return false; - } - if (IsAssert(node) || IsPrint(node)) { - return true; - } - return IsFreeOfSideEffect(node); -} - -void ArithmeticOptimizer::DedupComputations() { - CanonicalizeGraph(optimized_graph_); - - GraphTopologyView graph_view; - if (!graph_view.InitializeFromGraph(*optimized_graph_).ok()) { - LOG(WARNING) << "Failed to initialize GraphTopologyView."; - return; - } - - // Populate feed_inplace_op; - absl::flat_hash_set<const NodeDef*> feeds_inplace_op; - for (const NodeDef& root : optimized_graph_->node()) { - if (feeds_inplace_op.find(&root) != feeds_inplace_op.end()) continue; - - if (ModifiesInputsInPlace(root)) { - const auto is_continue_traversal = [&](const NodeDef* node) -> bool { - return node->op() == root.op() || !NeverForwardsInputs(*node); - }; - - DfsTraversal(graph_view, {&root}, TraversalDirection::kFollowInputs, - DfsPredicates::Advance(is_continue_traversal), - DfsCallbacks::PreOrder([&](const NodeDef* node) { - feeds_inplace_op.insert(node); - })); - } - } - - bool stop = true; - std::set<int> duplicates; - UniqueNodes nodes; - do { - stop = true; - for (int i = 0; i < optimized_graph_->node_size(); ++i) { - if (duplicates.find(i) != duplicates.end()) { - continue; - } - NodeDef* node = optimized_graph_->mutable_node(i); - if (!CanDedup(*node) || - feeds_inplace_op.find(node) != feeds_inplace_op.end()) { - continue; - } - NodeDef* rep = nodes.FindOrAddRepresentative(node); - if (rep == node) { - continue; - } - // If either node or rep feeds an inplace op, deduping them may cause data - // races. For example: If we dedup nodes initializing two independent - // inplace accumulations, they will write to the same buffer, clobbering - // each other's results. - if (feeds_inplace_op.find(rep) != feeds_inplace_op.end()) { - continue; - } - const std::set<NodeDef*>& tmp = node_map_->GetOutputs(node->name()); - std::vector<NodeDef*> fanouts(tmp.begin(), tmp.end()); - for (NodeDef* fanout : fanouts) { - // Update consumers of node. - bool updated_fanout = false; - for (int i = 0; i < fanout->input_size(); ++i) { - string* fanout_input = fanout->mutable_input(i); - - const int position = - NodePositionIfSameNode(*fanout_input, node->name()); - // Update name in-place. - if (position < -1) { - continue; - } else { - if (!updated_fanout) { - // The signature of the fanout node will change. Remove it from - // nodes. - nodes.RemoveRepresentative(fanout); - } - updated_fanout = true; - if (position > 0) { - *fanout_input = StrCat(rep->name(), ":", position); - } else if (position == 0) { - *fanout_input = rep->name(); - } else { - *fanout_input = StrCat("^", rep->name()); - } - } - } - if (updated_fanout) { - node_map_->UpdateInput(fanout->name(), node->name(), rep->name()); - CanonicalizeNode(fanout); - } - } - duplicates.insert(i); - stop = false; - } - } while (!stop); - - // Delete duplicates - if (fetch_nodes_known_ && !duplicates.empty()) { - EraseNodesFromGraph(duplicates, optimized_graph_); - // Rebuild the NodeMap which was invalidated by the node swapping above. 
- node_map_.reset(new NodeMap(optimized_graph_)); - } -} - Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) { SetVector<NodeDef*> nodes_to_simplify; nodes_to_simplify.Reserve(optimized_graph_->node_size()); @@ -3818,11 +3609,6 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/, TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph_)); GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED(); - if (options_.dedup_computations) { - DedupComputations(); - GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED(); - } - graph_properties_.reset(new GraphProperties(optimized_item)); const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE; const Status status = diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index a421daa88a5..50896b11923 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -104,115 +104,6 @@ TEST_F(ArithmeticOptimizerTest, NoOp) { VerifyGraphsMatch(item.graph, output, __LINE__); } -TEST_F(ArithmeticOptimizerTest, OpDedupping) { - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output c1 = ops::Const(s.WithOpName("c1"), {3.14, 2.7}, {1, 2}); - Output c2 = ops::Const(s.WithOpName("c2"), {3.14, 2.7}, {1, 2}); - Output div = ops::Div(s.WithOpName("div"), c1, c2); - GrapplerItem item; - TF_CHECK_OK(s.ToGraphDef(&item.graph)); - item.fetch = {"div"}; - - auto tensors_expected = EvaluateNodes(item.graph, item.fetch); - ASSERT_EQ(tensors_expected.size(), 1); - - ArithmeticOptimizer optimizer; - GraphDef output; - OptimizeTwice(&optimizer, &item, &output); - NodeMap node_map(&output); - EXPECT_EQ(output.node_size(), 2); - const NodeDef* new_c1 = node_map.GetNode("c1"); - ASSERT_NE(new_c1, nullptr); - - const NodeDef* new_div = node_map.GetNode("div"); - ASSERT_NE(new_div, nullptr); - ASSERT_EQ(new_div->input_size(), 2); - EXPECT_EQ(new_div->input(0), "c1"); - EXPECT_EQ(new_div->input(1), "c1"); - - auto tensors = EvaluateNodes(output, item.fetch); - ASSERT_EQ(tensors.size(), 1); - test::ExpectTensorNear<double>(tensors[0], tensors_expected[0], 1e-6); -} - -TEST_F(ArithmeticOptimizerTest, OpDeduppingAssertAndCheckNumerics) { - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output p = ops::Placeholder(s, DT_BOOL, ops::Placeholder::Shape({})); - Output c = ops::Const(s.WithOpName("c"), {3.14, 2.7}, {1, 2}); - auto check1 = ops::CheckNumerics(s.WithOpName("check1"), c, "foo"); - auto check2 = ops::CheckNumerics(s.WithOpName("check2"), c, "foo"); - auto assert1 = ops::Assert(s.WithOpName("assert1"), p, {c}); - auto assert2 = ops::Assert(s.WithOpName("assert2"), p, {c}); - Output div = ops::Div(s.WithOpName("div").WithControlDependencies( - {assert1.operation, assert2.operation}), - check1, check2); - GrapplerItem item; - TF_CHECK_OK(s.ToGraphDef(&item.graph)); - item.fetch = {"div"}; - Tensor bool_t(DT_BOOL, TensorShape({})); - bool_t.scalar<bool>().setConstant(true); - auto tensors_expected = - EvaluateNodes(item.graph, item.fetch, {{"Placeholder", bool_t}}); - ASSERT_EQ(tensors_expected.size(), 1); - - ArithmeticOptimizer optimizer; - GraphDef output; - - OptimizeTwice(&optimizer, &item, &output); - NodeMap node_map(&output); - - EXPECT_EQ(output.node_size(), 6); - const NodeDef* new_div = node_map.GetNode("div"); - ASSERT_NE(new_div, nullptr); - ASSERT_EQ(new_div->input_size(), 3); - EXPECT_EQ(new_div->input(0), "check1"); - EXPECT_EQ(new_div->input(1), "check2"); - 
EXPECT_EQ(new_div->input(2), "^assert1"); - - auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", bool_t}}); - EXPECT_EQ(tensors.size(), 1); - test::ExpectTensorNear<double>(tensors[0], tensors_expected[0], 1e-6); -} - -TEST_F(ArithmeticOptimizerTest, OpDedupCommutative) { - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output c1 = ops::Const(s.WithOpName("c1"), {1.0f, 2.0f}, {1, 2}); - Output c2 = ops::Const(s.WithOpName("c2"), {3.0f, 4.0f}, {1, 2}); - Output mul1 = ops::Mul(s.WithOpName("mul1"), c1, c2); - Output mul2 = ops::Mul(s.WithOpName("mul2"), c2, c1); - Output div1 = ops::Div(s.WithOpName("div1"), mul1, mul2); - GrapplerItem item; - TF_CHECK_OK(s.ToGraphDef(&item.graph)); - item.fetch = {"div1"}; - auto tensors_expected = EvaluateNodes(item.graph, item.fetch); - ASSERT_EQ(tensors_expected.size(), 1); - - ArithmeticOptimizer optimizer; - GraphDef output; - OptimizeTwice(&optimizer, &item, &output); - NodeMap node_map(&output); - - EXPECT_EQ(output.node_size(), 4); - const NodeDef* new_c1 = node_map.GetNode("c1"); - ASSERT_NE(new_c1, nullptr); - const NodeDef* new_c2 = node_map.GetNode("c2"); - ASSERT_NE(new_c2, nullptr); - const NodeDef* new_mul1 = node_map.GetNode("mul1"); - ASSERT_NE(new_mul1, nullptr); - ASSERT_EQ(new_mul1->input_size(), 2); - EXPECT_EQ(new_mul1->input(0), "c1"); - EXPECT_EQ(new_mul1->input(1), "c2"); - const NodeDef* new_div1 = node_map.GetNode("div1"); - ASSERT_NE(new_div1, nullptr); - ASSERT_EQ(new_div1->input_size(), 2); - EXPECT_EQ(new_div1->input(0), "mul1"); - EXPECT_EQ(new_div1->input(1), "mul1"); - - auto tensors = EvaluateNodes(output, item.fetch); - ASSERT_EQ(tensors.size(), 1); - test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6); -} - TEST_F(ArithmeticOptimizerTest, ReplaceMulWithSquare) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output c = ops::Const(s.WithOpName("c"), {1.0f, 2.0f}, {1, 2}); @@ -474,6 +365,7 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) { Output id = ops::Identity(s.WithOpName("id"), add6); GrapplerItem item; + item.fetch = {"id"}; TF_CHECK_OK(s.ToGraphDef(&item.graph)); const std::vector<string> devices{ @@ -488,16 +380,16 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) { DisableAddToAddNCombining(&optimizer); GraphDef output; - OptimizeTwice(&optimizer, &item, &output); + DedupAndOptimizeTwiceAndPrune(&optimizer, &item, &output); // We expect the following rewrite(s) to occur: // // Mul(p, // Add_6(Add_4(Const(2), Const(2)), - // Add_5(Const(2), Const(2)))) + // Add_5(Const(2), Const(2))) NodeMap node_map(&output); - EXPECT_EQ(output.node_size(), 17); + EXPECT_EQ(output.node_size(), 8); const NodeDef* id_node = node_map.GetNode("id"); ASSERT_NE(id_node, nullptr); @@ -507,8 +399,8 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) { const NodeDef* mul_node = node_map.GetNode(HoistMulName("Add_6")); ASSERT_NE(mul_node, nullptr); ASSERT_EQ(mul_node->input_size(), 2); - EXPECT_EQ(mul_node->input(0), HoistAddName("Add_6")); - EXPECT_EQ(mul_node->input(1), "Placeholder"); + EXPECT_EQ(mul_node->input(0), "Placeholder"); + EXPECT_EQ(mul_node->input(1), HoistAddName("Add_6")); const NodeDef* add_6_node = node_map.GetNode(HoistAddName("Add_6")); ASSERT_NE(add_6_node, nullptr); diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h index 4d3ba976c4f..73bb5a0d97c 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h +++ 
b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_ARITHMETIC_OPTIMIZER_TEST_UTILS_H_ #include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h" +#include "tensorflow/core/grappler/optimizers/common_subgraph_elimination.h" #include "tensorflow/core/grappler/optimizers/constant_folding.h" #include "tensorflow/core/grappler/optimizers/model_pruner.h" #include "tensorflow/core/grappler/utils/grappler_test.h" @@ -27,9 +28,9 @@ namespace grappler { class ArithmeticOptimizerTest : public GrapplerTest { protected: - // Optimize a graph using ArithmeticOptimizer and prune all the nodes that no + // Optimize a graph using the given optimizer and prune all the nodes that no // longer have any output consumers. - void OptimizeAndPrune(ArithmeticOptimizer* optimizer, GrapplerItem* item, + void OptimizeAndPrune(GraphOptimizer* optimizer, GrapplerItem* item, GraphDef* output) { TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output)); item->graph.Swap(output); @@ -37,8 +38,23 @@ class ArithmeticOptimizerTest : public GrapplerTest { TF_EXPECT_OK(ModelPruner().Optimize(nullptr, *item, output)); } - // Run ArithmeticOptimizer twice to make sure the rewrite is idempotent. - void OptimizeTwice(ArithmeticOptimizer* optimizer, GrapplerItem* item, + // Run CommonSubgraphElimination, then the optimizer twice to make sure the rewrite is idempotent, then prune. + void DedupAndOptimizeTwiceAndPrune(GraphOptimizer* optimizer, + GrapplerItem* item, GraphDef* output) { + TF_EXPECT_OK(CommonSubgraphElimination().Optimize(nullptr, *item, output)); + item->graph.Swap(output); + output->Clear(); + TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output)); + item->graph.Swap(output); + output->Clear(); + TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output)); + item->graph.Swap(output); + output->Clear(); + TF_EXPECT_OK(ModelPruner().Optimize(nullptr, *item, output)); + } + + // Run the optimizer twice to make sure the rewrite is idempotent. + void OptimizeTwice(GraphOptimizer* optimizer, GrapplerItem* item, GraphDef* output) { TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output)); item->graph.Swap(output); @@ -46,9 +62,9 @@ class ArithmeticOptimizerTest : public GrapplerTest { TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output)); } - // Run ArithmeticOptimizer twice to make sure the rewrite is idempotent. + // Run the optimizer twice to make sure the rewrite is idempotent. // Optionally run a constant folding pass before pruning. - void OptimizeTwiceAndPrune(ArithmeticOptimizer* optimizer, GrapplerItem* item, + void OptimizeTwiceAndPrune(GraphOptimizer* optimizer, GrapplerItem* item, GraphDef* output, bool const_folding = false) { TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output)); diff --git a/tensorflow/core/grappler/optimizers/common_subgraph_elimination.cc b/tensorflow/core/grappler/optimizers/common_subgraph_elimination.cc new file mode 100644 index 00000000000..8924e4c6bea --- /dev/null +++ b/tensorflow/core/grappler/optimizers/common_subgraph_elimination.cc @@ -0,0 +1,291 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/common_subgraph_elimination.h" + +#include <set> +#include <string> +#include <unordered_set> +#include <utility> +#include <vector> + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/graph/tensor_id.h" +#include "tensorflow/core/grappler/graph_topology_view.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/op_types.h" +#include "tensorflow/core/grappler/optimizers/graph_optimizer.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/grappler/utils/canonicalizer.h" +#include "tensorflow/core/grappler/utils/topological_sort.h" +#include "tensorflow/core/grappler/utils/traversal.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/hash.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/strcat.h" +#include "tensorflow/core/platform/stringpiece.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace grappler { +class Cluster; +} // namespace grappler +} // namespace tensorflow + +using tensorflow::strings::StrCat; + +namespace tensorflow { +namespace grappler { + +class UniqueNodes { + public: + NodeDef* FindOrAddRepresentative(NodeDef* node) { + uint64 sig = ComputeSignature(*node); + std::vector<NodeDef*>& candidates = rep_[sig]; + for (auto& candidate : candidates) { + if ((candidate == node) || SameNode(*candidate, *node)) { + return candidate; + } + } + candidates.push_back(node); + return node; + } + + void RemoveRepresentative(NodeDef* node) { + auto it = memoized_signatures_.find(node); + if (it == memoized_signatures_.end()) return; + + std::vector<NodeDef*>& candidates = rep_[it->second]; + for (int i = 0; i < candidates.size(); ++i) { + if (candidates[i] == node) { + std::swap(candidates[i], candidates[candidates.size() - 1]); + candidates.resize(candidates.size() - 1); + break; + } + } + memoized_signatures_.erase(node); + } + + private: + uint64 ComputeSignature(const NodeDef& node); + bool SameNode(const NodeDef& node1, const NodeDef& node2) const; + + absl::flat_hash_map<uint64, std::vector<NodeDef*>> rep_; + absl::flat_hash_map<const NodeDef*, uint64> memoized_signatures_; +}; + +uint64 UniqueNodes::ComputeSignature(const NodeDef& node) { + auto it = memoized_signatures_.find(&node); + if (it != memoized_signatures_.end()) return it->second; + + uint64 h = Hash64(node.op()); + h = Hash64Combine(Hash64(node.device()), h); + + for (const auto& input : node.input()) { + const TensorId input_tensor = ParseTensorName(input); + uint64 input_hash = Hash64Combine( + Hash64(input_tensor.node().data(), input_tensor.node().size()), + std::hash<int>()(input_tensor.index())); + h = Hash64CombineUnordered(input_hash, h); + } + for (const auto& attr : 
node.attr()) { + uint64 attr_hash = + Hash64Combine(Hash64(attr.first), FastAttrValueHash(attr.second)); + h = Hash64CombineUnordered(attr_hash, h); + } + memoized_signatures_.emplace(&node, h); + return h; +} + +// PRECONDITION: +// Node input orders are assumed to be canonicalized, i.e. control inputs for +// all nodes as well as regular inputs for commutative nodes must be sorted. +bool UniqueNodes::SameNode(const NodeDef& node1, const NodeDef& node2) const { + if (node1.op() != node2.op()) { + return false; + } + if (node1.device() != node2.device()) { + return false; + } + if (node1.input_size() != node2.input_size()) { + return false; + } + if (node1.attr_size() != node2.attr_size()) { + return false; + } + + // Compare inputs. + auto it1 = node1.input().begin(); + auto it2 = node2.input().begin(); + for (; it1 != node1.input().end(); ++it1, ++it2) { + if (*it1 != *it2) return false; + } + + // Compare attributes. + for (const auto& attr1 : node1.attr()) { + auto it = node2.attr().find(attr1.first); + if (it == node2.attr().end()) return false; + if (!FastAreAttrValuesEqual(attr1.second, it->second)) return false; + } + + return true; +} + +bool CommonSubgraphElimination::CanDedup(const NodeDef& node) const { + if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end()) { + return false; + } + if (IsEnter(node) || IsExit(node)) { + return false; + } + if (node.device().find("SPU") != string::npos) { + return false; + } + // Workaround for Assert and Print mistakenly being labeled as stateful. + if (IsAssert(node) || IsPrint(node)) { + return true; + } + return IsFreeOfSideEffect(node); +} + +Status CommonSubgraphElimination::DedupComputations(GraphDef* optimized_graph) { + CanonicalizeGraph(optimized_graph); + + GraphTopologyView graph_view; + if (!graph_view.InitializeFromGraph(*optimized_graph).ok()) { + LOG(WARNING) << "Failed to initialize GraphTopologyView."; + return Status::OK(); + } + + // If either node or rep feeds an inplace op, deduping them may cause data + // races. For example: If we dedup nodes initializing two independent + // inplace accumulations, they will write to the same buffer, clobbering + // each other's results. 
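+ // First mark every node that transitively feeds an op that updates its + // inputs in place (walking through ops that may forward their inputs); + // the dedup loop below skips all marked nodes.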
+ absl::flat_hash_set<const NodeDef*> feeds_inplace_op; + for (int i = 0; i < optimized_graph->node_size(); ++i) { + const NodeDef& root = optimized_graph->node(i); + if (feeds_inplace_op.find(&root) != feeds_inplace_op.end()) continue; + if (ModifiesInputsInPlace(root)) { + const auto is_continue_traversal = [&](const NodeDef* node) -> bool { + return node->op() == root.op() || !NeverForwardsInputs(*node); + }; + + DfsTraversal(graph_view, {&root}, TraversalDirection::kFollowInputs, + DfsPredicates::Advance(is_continue_traversal), + DfsCallbacks::PreOrder([&](const NodeDef* node) { + feeds_inplace_op.insert(node); + })); + } + } + + std::vector<bool> can_dedup(optimized_graph->node_size()); + for (int i = 0; i < optimized_graph->node_size(); ++i) { + const NodeDef& node = optimized_graph->node(i); + can_dedup[i] = (feeds_inplace_op.find(&node) == feeds_inplace_op.end()) && + CanDedup(node); + } + + bool stop = true; + std::set<int> duplicates; + UniqueNodes nodes; + NodeMap node_map(optimized_graph); + do { + stop = true; + for (int i = 0; i < optimized_graph->node_size(); ++i) { + if (!can_dedup[i] || duplicates.find(i) != duplicates.end()) { + continue; + } + NodeDef* node = optimized_graph->mutable_node(i); + NodeDef* rep = nodes.FindOrAddRepresentative(node); + if (rep == node) { + continue; + } + const std::set<NodeDef*>& tmp = node_map.GetOutputs(node->name()); + std::vector<NodeDef*> fanouts(tmp.begin(), tmp.end()); + for (NodeDef* fanout : fanouts) { + // Update consumers of node. + bool updated_fanout = false; + for (int i = 0; i < fanout->input_size(); ++i) { + string* fanout_input = fanout->mutable_input(i); + + const int position = + NodePositionIfSameNode(*fanout_input, node->name()); + // Update name in-place. + if (position < -1) { + continue; + } else { + if (!updated_fanout) { + // The signature of the fanout node will change. Remove it from + // nodes. + nodes.RemoveRepresentative(fanout); + } + updated_fanout = true; + if (position > 0) { + *fanout_input = StrCat(rep->name(), ":", position); + } else if (position == 0) { + *fanout_input = rep->name(); + } else { + *fanout_input = StrCat("^", rep->name()); + } + } + } + if (updated_fanout) { + node_map.UpdateInput(fanout->name(), node->name(), rep->name()); + CanonicalizeNode(fanout); + } + } + duplicates.insert(i); + stop = false; + } + } while (!stop); + + // Delete duplicates + if (fetch_nodes_known_ && !duplicates.empty()) { + EraseNodesFromGraph(duplicates, optimized_graph); + } + + return Status::OK(); +} + +Status CommonSubgraphElimination::Optimize(Cluster* /*cluster*/, + const GrapplerItem& item, + GraphDef* optimized_graph) { + // Set up helper data structures. + nodes_to_preserve_ = item.NodesToPreserve(); + fetch_nodes_known_ = !item.fetch.empty(); + *optimized_graph = item.graph; + + // Perform topological sort on the graph in order to help DedupComputations + // optimize larger subgraphs starting from the roots with more inputs. + TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph)); + GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED(); + + return DedupComputations(optimized_graph); +} + +void CommonSubgraphElimination::Feedback(Cluster* /*cluster*/, + const GrapplerItem& /*item*/, + const GraphDef& /*optimized_graph*/, + double /*result*/) { + // Nothing to do for CommonSubgraphElimination.
+} + +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/common_subgraph_elimination.h b/tensorflow/core/grappler/optimizers/common_subgraph_elimination.h new file mode 100644 index 00000000000..eec6ba79b3f --- /dev/null +++ b/tensorflow/core/grappler/optimizers/common_subgraph_elimination.h @@ -0,0 +1,73 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_COMMON_SUBGRAPH_ELIMINATION_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_COMMON_SUBGRAPH_ELIMINATION_H_ + +#include <unordered_set> + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/grappler/optimizers/graph_optimizer.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/platform/hash.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/protobuf/rewriter_config.pb.h" + +namespace tensorflow { +namespace grappler { + +class Cluster; +struct GrapplerItem; + +// Optimize TF computations by deduping equivalent subgraphs. +class CommonSubgraphElimination : public GraphOptimizer { + public: + CommonSubgraphElimination() {} + + explicit CommonSubgraphElimination(RewriterConfig::Toggle opt_level) + : opt_level_(opt_level) {} + + ~CommonSubgraphElimination() override {} + + string name() const override { return "common_subgraph_elimination"; } + + bool UsesFunctionLibrary() const override { return false; } + + Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) override; + + void Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimized_graph, double result) override; + + private: + friend class CommonSubgraphEliminationTest; + + // Returns true if it is safe to dedup node from the graph. + bool CanDedup(const NodeDef& node) const; + + // Dedup redundant nodes in the graph. + Status DedupComputations(GraphDef* optimized_graph); + + RewriterConfig::Toggle opt_level_; + + bool fetch_nodes_known_ = false; + std::unordered_set<string> nodes_to_preserve_; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_COMMON_SUBGRAPH_ELIMINATION_H_ diff --git a/tensorflow/core/grappler/optimizers/common_subgraph_elimination_test.cc b/tensorflow/core/grappler/optimizers/common_subgraph_elimination_test.cc new file mode 100644 index 00000000000..3341a8abe56 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/common_subgraph_elimination_test.cc @@ -0,0 +1,178 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/common_subgraph_elimination.h" + +#include "absl/strings/str_cat.h" +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/math_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h" +#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer_test_utils.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/grappler/utils/grappler_test.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { + +namespace { + +void VerifyGraphsMatch(const GraphDef& original_graph, + const GraphDef& optimized_graph, int line) { + EXPECT_EQ(original_graph.node_size(), optimized_graph.node_size()) << line; + for (int i = 0; i < original_graph.node_size(); ++i) { + const NodeDef& original = original_graph.node(i); + const NodeDef& optimized = optimized_graph.node(i); + EXPECT_EQ(original.name(), optimized.name()) << line; + EXPECT_EQ(original.op(), optimized.op()) << line; + EXPECT_EQ(original.input_size(), optimized.input_size()) << line; + for (int j = 0; j < original.input_size(); ++j) { + EXPECT_EQ(original.input(j), optimized.input(j)) << line; + } + } +} +} // namespace + +class CommonSubgraphEliminationTest : public ArithmeticOptimizerTest {}; + +TEST_F(CommonSubgraphEliminationTest, NoOp) { + // This trivial graph is so basic there's nothing to optimize. 
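+ // CommonSubgraphElimination should return it node-for-node unchanged.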
+ TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"}); + GrapplerItem item; + CHECK(fake_input.NextItem(&item)); + + CommonSubgraphElimination optimizer; + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + VerifyGraphsMatch(item.graph, output, __LINE__); +} + +TEST_F(CommonSubgraphEliminationTest, OpDedupping) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output c1 = ops::Const(s.WithOpName("c1"), {3.14, 2.7}, {1, 2}); + Output c2 = ops::Const(s.WithOpName("c2"), {3.14, 2.7}, {1, 2}); + Output div = ops::Div(s.WithOpName("div"), c1, c2); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + item.fetch = {"div"}; + + auto tensors_expected = EvaluateNodes(item.graph, item.fetch); + ASSERT_EQ(tensors_expected.size(), 1); + + CommonSubgraphElimination optimizer; + GraphDef output; + OptimizeTwice(&optimizer, &item, &output); + NodeMap node_map(&output); + EXPECT_EQ(output.node_size(), 2); + const NodeDef* new_c1 = node_map.GetNode("c1"); + ASSERT_NE(new_c1, nullptr); + + const NodeDef* new_div = node_map.GetNode("div"); + ASSERT_NE(new_div, nullptr); + ASSERT_EQ(new_div->input_size(), 2); + EXPECT_EQ(new_div->input(0), "c1"); + EXPECT_EQ(new_div->input(1), "c1"); + + auto tensors = EvaluateNodes(output, item.fetch); + ASSERT_EQ(tensors.size(), 1); + test::ExpectTensorNear<double>(tensors[0], tensors_expected[0], 1e-6); +} + +TEST_F(CommonSubgraphEliminationTest, OpDeduppingAssertAndCheckNumerics) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output p = ops::Placeholder(s, DT_BOOL, ops::Placeholder::Shape({})); + Output c = ops::Const(s.WithOpName("c"), {3.14, 2.7}, {1, 2}); + auto check1 = ops::CheckNumerics(s.WithOpName("check1"), c, "foo"); + auto check2 = ops::CheckNumerics(s.WithOpName("check2"), c, "foo"); + auto assert1 = ops::Assert(s.WithOpName("assert1"), p, {c}); + auto assert2 = ops::Assert(s.WithOpName("assert2"), p, {c}); + Output div = ops::Div(s.WithOpName("div").WithControlDependencies( + {assert1.operation, assert2.operation}), + check1, check2); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + item.fetch = {"div"}; + Tensor bool_t(DT_BOOL, TensorShape({})); + bool_t.scalar<bool>().setConstant(true); + auto tensors_expected = + EvaluateNodes(item.graph, item.fetch, {{"Placeholder", bool_t}}); + ASSERT_EQ(tensors_expected.size(), 1); + + CommonSubgraphElimination optimizer; + GraphDef output; + + OptimizeTwice(&optimizer, &item, &output); + NodeMap node_map(&output); + + EXPECT_EQ(output.node_size(), 6); + const NodeDef* new_div = node_map.GetNode("div"); + ASSERT_NE(new_div, nullptr); + ASSERT_EQ(new_div->input_size(), 3); + EXPECT_EQ(new_div->input(0), "check1"); + EXPECT_EQ(new_div->input(1), "check2"); + EXPECT_EQ(new_div->input(2), "^assert1"); + + auto tensors = EvaluateNodes(output, item.fetch, {{"Placeholder", bool_t}}); + EXPECT_EQ(tensors.size(), 1); + test::ExpectTensorNear<double>(tensors[0], tensors_expected[0], 1e-6); +} + +TEST_F(CommonSubgraphEliminationTest, OpDedupCommutative) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output c1 = ops::Const(s.WithOpName("c1"), {1.0f, 2.0f}, {1, 2}); + Output c2 = ops::Const(s.WithOpName("c2"), {3.0f, 4.0f}, {1, 2}); + Output mul1 = ops::Mul(s.WithOpName("mul1"), c1, c2); + Output mul2 = ops::Mul(s.WithOpName("mul2"), c2, c1); + Output div1 = ops::Div(s.WithOpName("div1"), mul1, mul2); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + item.fetch = {"div1"}; + 
auto tensors_expected = EvaluateNodes(item.graph, item.fetch); + ASSERT_EQ(tensors_expected.size(), 1); + + CommonSubgraphElimination optimizer; + GraphDef output; + OptimizeTwice(&optimizer, &item, &output); + NodeMap node_map(&output); + + EXPECT_EQ(output.node_size(), 4); + const NodeDef* new_c1 = node_map.GetNode("c1"); + ASSERT_NE(new_c1, nullptr); + const NodeDef* new_c2 = node_map.GetNode("c2"); + ASSERT_NE(new_c2, nullptr); + const NodeDef* new_mul1 = node_map.GetNode("mul1"); + ASSERT_NE(new_mul1, nullptr); + ASSERT_EQ(new_mul1->input_size(), 2); + EXPECT_EQ(new_mul1->input(0), "c1"); + EXPECT_EQ(new_mul1->input(1), "c2"); + const NodeDef* new_div1 = node_map.GetNode("div1"); + ASSERT_NE(new_div1, nullptr); + ASSERT_EQ(new_div1->input_size(), 2); + EXPECT_EQ(new_div1->input(0), "mul1"); + EXPECT_EQ(new_div1->input(1), "mul1"); + + auto tensors = EvaluateNodes(output, item.fetch); + ASSERT_EQ(tensors.size(), 1); + test::ExpectTensorNear<float>(tensors[0], tensors_expected[0], 1e-6); +} + +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 3cf20ca7dab..77b2caf87b0 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -383,7 +383,6 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) { op != "TensorArraySizeV3") { continue; } - const std::vector<OpInfo::TensorProperties>& output = properties.GetOutputProperties(node->name()); const std::vector<OpInfo::TensorProperties>& input = @@ -410,8 +409,16 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) { continue; } + // TODO(rmlarsen): Remove this workaround for b/150861569 + // The bug involves an expression of the form Shape(ExpandDims(x)) + // with an incorrectly inferred zero-size first dimension. + if (op == "Shape") { + if (shape.dims() > 0 && shape.dim_size(0) == 0) continue; + } + + // Repurpose the existing node to be the constant. // Device placement is preserved. + graph_modified_ = true; node->set_op("Const"); node->clear_attr(); (*node->mutable_attr())["dtype"].set_type(type); @@ -424,9 +431,8 @@ // the original graph. string ctrl_dep = AddControlDependency(node->input(0), graph_, node_map_.get()); + node_map_->UpdateInput(node->name(), node->input(0), ctrl_dep); node->set_input(0, ctrl_dep); - node_map_->AddOutput(NodeName(ctrl_dep), node->name()); - - // Done with the Shape/Size/Rank node, move to the next node.
continue; } @@ -458,6 +464,7 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) { continue; } + graph_modified_ = true; node->set_op("Const"); *node->mutable_attr() = array_size->attr(); node->set_input(0, AsControlDependency(NodeName(node->input(0)))); @@ -519,6 +526,7 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) { } *output->mutable_input(k) = const_name; node_map_->AddOutput(const_name, output->name()); + graph_modified_ = true; } if (node_name == shape_n_node->name() && port != port_idx) { direct_edges_exist = true; @@ -3705,8 +3713,9 @@ Status ConstantFolding::AddQuantizedMatMulMinMaxOutConstNodes( } Status ConstantFolding::RunOptimizationPass(Cluster* cluster, - const GrapplerItem& item, + GrapplerItem* item, GraphDef* optimized_graph) { + graph_ = &item->graph; node_map_.reset(new NodeMap(graph_)); nodes_whitelist_.clear(); // Fold fetch nodes iff it has a single fanout. Note that if a fetch node @@ -3716,14 +3725,14 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster, // replace the node with multiple constants (each for one fanout) with // new names, and as a result users would not be able to fetch the node any // more with the original node name. - for (const auto& fetch : item.fetch) { + for (const auto& fetch : item->fetch) { const NodeDef* fetch_node = node_map_->GetNode(fetch); if (fetch_node && NumOutputs(*fetch_node, graph_) == 1) { nodes_whitelist_.insert(fetch_node->name()); } } - GraphProperties properties(item); + GraphProperties properties(*item); // It's possible to feed a placeholder with a tensor of any shape: make sure // that the shape inference deals with this conservatively unless we're in // aggressive mode. @@ -3732,15 +3741,18 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster, /*aggressive_shape_inference=*/false, /*include_input_tensor_values=*/false, /*include_output_tensor_values=*/true); + const bool can_use_shape_info = s.ok(); + absl::flat_hash_set<string> nodes_to_not_simplify; if (can_use_shape_info) { TF_RETURN_IF_ERROR(MaterializeShapes(properties)); TF_RETURN_IF_ERROR(MaterializeConstants(properties)); + TF_RETURN_IF_ERROR( + FoldGraph(properties, optimized_graph, &nodes_to_not_simplify)); + } else { + *optimized_graph = *graph_; } - absl::flat_hash_set<string> nodes_to_not_simplify; - TF_RETURN_IF_ERROR( - FoldGraph(properties, optimized_graph, &nodes_to_not_simplify)); node_map_.reset(new NodeMap(optimized_graph)); TF_RETURN_IF_ERROR(SimplifyGraph(can_use_shape_info, optimized_graph, &properties, &nodes_to_not_simplify)); @@ -3795,11 +3807,10 @@ Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item, GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED(); graph_modified_ = false; item_to_optimize.graph.Swap(optimized_graph); - graph_ = &item_to_optimize.graph; - *optimized_graph = GraphDef(); - node_count = graph_->node_size(); + optimized_graph->Clear(); + node_count = item_to_optimize.graph.node_size(); TF_RETURN_IF_ERROR( - RunOptimizationPass(cluster, item_to_optimize, optimized_graph)); + RunOptimizationPass(cluster, &item_to_optimize, optimized_graph)); } while (graph_modified_ || optimized_graph->node_size() != node_count); *optimized_graph->mutable_library() = item.graph.library(); *optimized_graph->mutable_versions() = item.graph.versions(); diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h index 88c4094bb1a..074f0c5f057 100644 --- 
a/tensorflow/core/grappler/optimizers/constant_folding.h +++ b/tensorflow/core/grappler/optimizers/constant_folding.h @@ -131,8 +131,8 @@ class ConstantFolding : public GraphOptimizer { Status SimplifyNode(bool use_shape_info, NodeDef* node, GraphDef* optimized_graph, GraphProperties* properties); - Status RunOptimizationPass(Cluster* cluster, const GrapplerItem& item, - GraphDef* output); + Status RunOptimizationPass(Cluster* cluster, GrapplerItem* item, + GraphDef* optimized_graph); // Applies partial constant folding for Concat which is not commutative. // Returns true if the transformation applied successfully. diff --git a/tensorflow/core/grappler/optimizers/graph_optimizer.h b/tensorflow/core/grappler/optimizers/graph_optimizer.h index 238606ee673..de678d0a390 100644 --- a/tensorflow/core/grappler/optimizers/graph_optimizer.h +++ b/tensorflow/core/grappler/optimizers/graph_optimizer.h @@ -17,9 +17,12 @@ limitations under the License. #define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_GRAPH_OPTIMIZER_H_ #include <string> + #include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace grappler { diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index da83e413ff6..82758d1f970 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h" #include "tensorflow/core/grappler/optimizers/auto_mixed_precision.h" #include "tensorflow/core/grappler/optimizers/auto_parallel.h" +#include "tensorflow/core/grappler/optimizers/common_subgraph_elimination.h" #include "tensorflow/core/grappler/optimizers/constant_folding.h" #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" #include "tensorflow/core/grappler/optimizers/debug_stripper.h" @@ -184,6 +185,8 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer( MK_OPT("auto_mixed_precision", new AutoMixedPrecision(cfg_.auto_mixed_precision())); MK_OPT("memory", new MemoryOptimizer(RewriterConfig::MANUAL)); + MK_OPT("common_subgraph_elimination", + new CommonSubgraphElimination(cfg_.common_subgraph_elimination())); MK_OPT("arithmetic", new ArithmeticOptimizer(cfg_.arithmetic_optimization())); MK_OPT("autoparallel", new AutoParallel(cfg_.auto_parallel().num_replicas())); MK_OPT("loop", new LoopOptimizer(cfg_.loop_optimization(), cpu_device_)); @@ -224,6 +227,11 @@ Status MetaOptimizer::InitializeOptimizers( cfg_.function_optimization(), /*lower_control_flow=*/!IsSingleThreadedExecutor())); } + if (cfg_.common_subgraph_elimination() != RewriterConfig::OFF && + cfg_.arithmetic_optimization() != RewriterConfig::OFF) { + optimizers->push_back(MakeUnique<CommonSubgraphElimination>( + cfg_.common_subgraph_elimination())); + } if (cfg_.debug_stripper() == RewriterConfig::ON) { optimizers->push_back(MakeUnique<DebugStripper>()); } @@ -812,6 +820,7 @@ bool MetaOptimizerEnabled(const ConfigProto& cfg) { rewrite_cfg.constant_folding() != RewriterConfig::OFF || rewrite_cfg.shape_optimization() != RewriterConfig::OFF || rewrite_cfg.remapping() != RewriterConfig::OFF || + rewrite_cfg.common_subgraph_elimination() != RewriterConfig::OFF || 
rewrite_cfg.arithmetic_optimization() != RewriterConfig::OFF || rewrite_cfg.loop_optimization() != RewriterConfig::OFF || rewrite_cfg.dependency_optimization() != RewriterConfig::OFF || diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h index 87835245762..a50c6f71fee 100644 --- a/tensorflow/core/grappler/utils.h +++ b/tensorflow/core/grappler/utils.h @@ -23,6 +23,7 @@ limitations under the License. #include <utility> #include <vector> +#include "absl/container/node_hash_map.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" #include "tensorflow/core/framework/graph.pb.h" @@ -65,8 +66,8 @@ class NodeMap { private: const std::set<NodeDef*> empty_set_; - gtl::FlatMap<string, NodeDef*> nodes_; - gtl::FlatMap<string, std::set<NodeDef*>> outputs_; + absl::node_hash_map<string, NodeDef*> nodes_; + absl::node_hash_map<string, std::set<NodeDef*>> outputs_; }; // A vector with a set. The set stores the same elements as the vector, and diff --git a/tensorflow/core/protobuf/rewriter_config.proto b/tensorflow/core/protobuf/rewriter_config.proto index e657a184b65..38c3ad7ae57 100644 --- a/tensorflow/core/protobuf/rewriter_config.proto +++ b/tensorflow/core/protobuf/rewriter_config.proto @@ -60,6 +60,9 @@ message RewriterConfig { // Remapping (default is ON) // Remap subgraphs onto more efficient implementations. Toggle remapping = 14; + // Common subgraph elimination (default is ON) + // e.g. Dedup equivalent nodes; merge ops with same value (like constants). + Toggle common_subgraph_elimination = 24; // Arithmetic optimizations (default is ON) // e.g. Simplify arithmetic ops; merge ops with same value (like constants). Toggle arithmetic_optimization = 7;
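The new pass can be driven standalone, the same way the new unit tests above invoke it. A minimal sketch under that assumption; the RunCommonSubgraphElimination helper name is hypothetical, but the Optimize call mirrors the tests (no Cluster is required, so nullptr is passed):

```cpp
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/common_subgraph_elimination.h"
#include "tensorflow/core/platform/status.h"

namespace tensorflow {
namespace grappler {

// Hypothetical helper: dedup equivalent subgraphs in item's graph, writing
// the deduplicated graph to optimized_graph, as the new unit tests do.
Status RunCommonSubgraphElimination(const GrapplerItem& item,
                                    GraphDef* optimized_graph) {
  CommonSubgraphElimination optimizer;
  return optimizer.Optimize(/*cluster=*/nullptr, item, optimized_graph);
}

}  // namespace grappler
}  // namespace tensorflow
```

The pass can also be toggled per session through the RewriterConfig field added above. A sketch assuming the standard protoc-generated accessors for the new field:

```cpp
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"

// Opt a session out of the new pass by setting its toggle to OFF.
void DisableCommonSubgraphElimination(tensorflow::ConfigProto* config) {
  tensorflow::RewriterConfig* rewriter_config =
      config->mutable_graph_options()->mutable_rewrite_options();
  rewriter_config->set_common_subgraph_elimination(
      tensorflow::RewriterConfig::OFF);
}
```

Note that, per the meta_optimizer.cc hunk above, the default pipeline instantiates CommonSubgraphElimination only when both the common_subgraph_elimination and arithmetic_optimization toggles are not OFF, so turning off arithmetic optimization also disables this pass.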