From ae20f08da9ce9e7336ab97cc9f77ce7a1c13ad12 Mon Sep 17 00:00:00 2001 From: Yunxing Dai Date: Wed, 17 Jun 2020 18:43:41 -0700 Subject: [PATCH] Properly support nest phi reduction in reverse order. If we replaced node B with node C, then replace node A with node B, we should redirect node A to node C instead. PiperOrigin-RevId: 317010443 Change-Id: I165496a3d1f6571815bfd61d096e26cbba39125a --- .../xla/service/hlo_dataflow_analysis.cc | 2 ++ .../compiler/xla/service/hlo_phi_graph.cc | 25 ++++++++++++++++++- .../compiler/xla/service/hlo_phi_graph.h | 2 +- .../xla/service/hlo_phi_graph_test.cc | 25 +++++++++++++++++++ 4 files changed, 52 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index f19882c9347..a46d20d5808 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -1007,6 +1007,8 @@ void HloDataflowAnalysis::OptimizePhiValues() { HloValue::Id phi_id = values[0]->id(); HloValue::Id new_id = phi_graph_.FindOptimizedValue(phi_id); if (new_id != phi_id) { + VLOG(1) << "Replacing " << values[0]->ToString() << " with " + << GetValue(new_id).ToString(); value_set->Clear(); const HloValue& new_value = GetValue(new_id); value_set->AddValue(&new_value); diff --git a/tensorflow/compiler/xla/service/hlo_phi_graph.cc b/tensorflow/compiler/xla/service/hlo_phi_graph.cc index 9b69771dab2..a2cba3d1bff 100644 --- a/tensorflow/compiler/xla/service/hlo_phi_graph.cc +++ b/tensorflow/compiler/xla/service/hlo_phi_graph.cc @@ -20,10 +20,11 @@ limitations under the License. namespace xla { HloValue::Id PhiGraph::GetOptimizedId(const HloValue& value) { Node* node = value_id_to_node_[value.id()]; + CHECK(!node->mark_as_dead); return node->value_id; } -// Returns true if the input to a hlo value is the same as `inputs`. +// Returns true if the inputs to a hlo value are the same as `inputs`. bool PhiGraph::InputsEqualTo(const HloValue& value, absl::Span inputs) { auto iter = value_id_to_node_.find(value.id()); @@ -42,6 +43,7 @@ bool PhiGraph::InputsEqualTo(const HloValue& value, HloValue::Id PhiGraph::FindOptimizedValue(const HloValue::Id id) { auto iter = value_id_to_node_.find(id); CHECK(iter != value_id_to_node_.end()); + CHECK(!iter->second->mark_as_dead); return iter->second->value_id; } @@ -66,6 +68,17 @@ PhiGraph::Node* PhiGraph::CreateOrReuseNode(const HloValue& value) { void PhiGraph::ReplaceNodeWith(PhiGraph::Node* node, PhiGraph::Node* replace) { // Update users. CHECK(node->is_phi); + if (node->mark_as_dead) { + // The node has already been replaced with another. + return; + } + if (replace->mark_as_dead) { + // The node we are placing with has already been replaced with another node. + auto iter = value_id_to_node_.find(replace->value_id); + CHECK(iter != value_id_to_node_.end()); + return ReplaceNodeWith(node, iter->second); + } + CHECK(!replace->mark_as_dead); for (Node* user : node->users) { absl::c_replace(user->operands, node, replace); } @@ -74,6 +87,7 @@ void PhiGraph::ReplaceNodeWith(PhiGraph::Node* node, PhiGraph::Node* replace) { for (Node* operand : node->operands) { absl::c_replace(operand->users, node, replace); } + for (HloValue::Id value_id : node_to_value_id_[node]) { CHECK(value_id_to_node_.contains(value_id)); value_id_to_node_[value_id] = replace; @@ -115,6 +129,8 @@ std::string PhiGraph::ToString() { } void PhiGraph::Optimize() { + VLOG(2) << "Optimizing phi graph:"; + XLA_VLOG_LINES(2, ToString()); // Set up users for each node. for (auto& node : node_storage_) { for (Node* input : node->operands) { @@ -141,6 +157,8 @@ void PhiGraph::Optimize() { Node* node_ptr = node.get(); + VLOG(2) << "Optimizing: " << node_ptr->value_id; + CHECK_GE(node_ptr->operands.size(), 1); // Remove self-referencing ids from users and operands. @@ -167,6 +185,9 @@ void PhiGraph::Optimize() { [&](Node* elem) { return elem == node_ptr->operands[0]; }); if (all_inputs_are_same) { + VLOG(1) << "All inputs to node " << node_ptr->value_id + << " are the same, replacing it with " + << node_ptr->operands[0]->value_id; ReplaceNodeWith(node_ptr, node_ptr->operands[0]); changed = true; continue; @@ -223,6 +244,8 @@ void PhiGraph::Optimize() { CHECK_EQ(node, non_phi); continue; } + VLOG(1) << "Replace node " << node->value_id + << " in the closure with node " << non_phi->value_id; ReplaceNodeWith(node, non_phi); changed = true; } diff --git a/tensorflow/compiler/xla/service/hlo_phi_graph.h b/tensorflow/compiler/xla/service/hlo_phi_graph.h index a0eb994438e..ca0d5c5009c 100644 --- a/tensorflow/compiler/xla/service/hlo_phi_graph.h +++ b/tensorflow/compiler/xla/service/hlo_phi_graph.h @@ -90,7 +90,7 @@ class PhiGraph { // to that phi. absl::flat_hash_map> node_to_value_id_; - // A mapping between a HloValue and node in the phi graph. + // A mapping from a HloValue to node in the phi graph. absl::flat_hash_map value_id_to_node_; std::vector> node_storage_; }; diff --git a/tensorflow/compiler/xla/service/hlo_phi_graph_test.cc b/tensorflow/compiler/xla/service/hlo_phi_graph_test.cc index 41f0454fe55..ee7300b160b 100644 --- a/tensorflow/compiler/xla/service/hlo_phi_graph_test.cc +++ b/tensorflow/compiler/xla/service/hlo_phi_graph_test.cc @@ -82,5 +82,30 @@ TEST_F(PhiGraphTest, CircularPhi) { EXPECT_EQ(D.id(), phi_graph.FindOptimizedValue(C.id())); } +TEST_F(PhiGraphTest, NestedPhiReduction) { + // def A = phi(B, C) + // def B = phi(C, E) + // def C = phi(A, B) + // def D = non-phi + // def E = Phi(D, D) + // 1. Replace E with D + // 2. Replace A B and C with E/D + PhiGraph phi_graph; + HloValue A = NewHloValue(true); + HloValue B = NewHloValue(true); + HloValue C = NewHloValue(true); + HloValue D = NewHloValue(false); + HloValue E = NewHloValue(true); + phi_graph.RegisterPhi(A, {&B, &C}); + phi_graph.RegisterPhi(B, {&E, &C}); + phi_graph.RegisterPhi(C, {&A, &B}); + phi_graph.RegisterPhi(E, {&D, &D}); + phi_graph.Optimize(); + EXPECT_EQ(D.id(), phi_graph.FindOptimizedValue(A.id())); + EXPECT_EQ(D.id(), phi_graph.FindOptimizedValue(B.id())); + EXPECT_EQ(D.id(), phi_graph.FindOptimizedValue(C.id())); + EXPECT_EQ(D.id(), phi_graph.FindOptimizedValue(E.id())); +} + } // namespace } // namespace xla