Properly support nest phi reduction in reverse order.

If we replaced node B with node C, then replace node A with node B, we
should redirect node A to node C instead.

PiperOrigin-RevId: 317010443
Change-Id: I165496a3d1f6571815bfd61d096e26cbba39125a
This commit is contained in:
Yunxing Dai 2020-06-17 18:43:41 -07:00 committed by TensorFlower Gardener
parent a50001edf9
commit ae20f08da9
4 changed files with 52 additions and 2 deletions

View File

@ -1007,6 +1007,8 @@ void HloDataflowAnalysis::OptimizePhiValues() {
HloValue::Id phi_id = values[0]->id();
HloValue::Id new_id = phi_graph_.FindOptimizedValue(phi_id);
if (new_id != phi_id) {
VLOG(1) << "Replacing " << values[0]->ToString() << " with "
<< GetValue(new_id).ToString();
value_set->Clear();
const HloValue& new_value = GetValue(new_id);
value_set->AddValue(&new_value);

View File

@ -20,10 +20,11 @@ limitations under the License.
namespace xla {
HloValue::Id PhiGraph::GetOptimizedId(const HloValue& value) {
Node* node = value_id_to_node_[value.id()];
CHECK(!node->mark_as_dead);
return node->value_id;
}
// Returns true if the input to a hlo value is the same as `inputs`.
// Returns true if the inputs to a hlo value are the same as `inputs`.
bool PhiGraph::InputsEqualTo(const HloValue& value,
absl::Span<const HloValue* const> inputs) {
auto iter = value_id_to_node_.find(value.id());
@ -42,6 +43,7 @@ bool PhiGraph::InputsEqualTo(const HloValue& value,
HloValue::Id PhiGraph::FindOptimizedValue(const HloValue::Id id) {
auto iter = value_id_to_node_.find(id);
CHECK(iter != value_id_to_node_.end());
CHECK(!iter->second->mark_as_dead);
return iter->second->value_id;
}
@ -66,6 +68,17 @@ PhiGraph::Node* PhiGraph::CreateOrReuseNode(const HloValue& value) {
void PhiGraph::ReplaceNodeWith(PhiGraph::Node* node, PhiGraph::Node* replace) {
// Update users.
CHECK(node->is_phi);
if (node->mark_as_dead) {
// The node has already been replaced with another.
return;
}
if (replace->mark_as_dead) {
// The node we are placing with has already been replaced with another node.
auto iter = value_id_to_node_.find(replace->value_id);
CHECK(iter != value_id_to_node_.end());
return ReplaceNodeWith(node, iter->second);
}
CHECK(!replace->mark_as_dead);
for (Node* user : node->users) {
absl::c_replace(user->operands, node, replace);
}
@ -74,6 +87,7 @@ void PhiGraph::ReplaceNodeWith(PhiGraph::Node* node, PhiGraph::Node* replace) {
for (Node* operand : node->operands) {
absl::c_replace(operand->users, node, replace);
}
for (HloValue::Id value_id : node_to_value_id_[node]) {
CHECK(value_id_to_node_.contains(value_id));
value_id_to_node_[value_id] = replace;
@ -115,6 +129,8 @@ std::string PhiGraph::ToString() {
}
void PhiGraph::Optimize() {
VLOG(2) << "Optimizing phi graph:";
XLA_VLOG_LINES(2, ToString());
// Set up users for each node.
for (auto& node : node_storage_) {
for (Node* input : node->operands) {
@ -141,6 +157,8 @@ void PhiGraph::Optimize() {
Node* node_ptr = node.get();
VLOG(2) << "Optimizing: " << node_ptr->value_id;
CHECK_GE(node_ptr->operands.size(), 1);
// Remove self-referencing ids from users and operands.
@ -167,6 +185,9 @@ void PhiGraph::Optimize() {
[&](Node* elem) { return elem == node_ptr->operands[0]; });
if (all_inputs_are_same) {
VLOG(1) << "All inputs to node " << node_ptr->value_id
<< " are the same, replacing it with "
<< node_ptr->operands[0]->value_id;
ReplaceNodeWith(node_ptr, node_ptr->operands[0]);
changed = true;
continue;
@ -223,6 +244,8 @@ void PhiGraph::Optimize() {
CHECK_EQ(node, non_phi);
continue;
}
VLOG(1) << "Replace node " << node->value_id
<< " in the closure with node " << non_phi->value_id;
ReplaceNodeWith(node, non_phi);
changed = true;
}

View File

@ -90,7 +90,7 @@ class PhiGraph {
// to that phi.
absl::flat_hash_map<Node*, std::vector<HloValue::Id>> node_to_value_id_;
// A mapping between a HloValue and node in the phi graph.
// A mapping from a HloValue to node in the phi graph.
absl::flat_hash_map<HloValue::Id, Node*> value_id_to_node_;
std::vector<std::unique_ptr<Node>> node_storage_;
};

View File

@ -82,5 +82,30 @@ TEST_F(PhiGraphTest, CircularPhi) {
EXPECT_EQ(D.id(), phi_graph.FindOptimizedValue(C.id()));
}
TEST_F(PhiGraphTest, NestedPhiReduction) {
// def A = phi(B, C)
// def B = phi(C, E)
// def C = phi(A, B)
// def D = non-phi
// def E = Phi(D, D)
// 1. Replace E with D
// 2. Replace A B and C with E/D
PhiGraph phi_graph;
HloValue A = NewHloValue(true);
HloValue B = NewHloValue(true);
HloValue C = NewHloValue(true);
HloValue D = NewHloValue(false);
HloValue E = NewHloValue(true);
phi_graph.RegisterPhi(A, {&B, &C});
phi_graph.RegisterPhi(B, {&E, &C});
phi_graph.RegisterPhi(C, {&A, &B});
phi_graph.RegisterPhi(E, {&D, &D});
phi_graph.Optimize();
EXPECT_EQ(D.id(), phi_graph.FindOptimizedValue(A.id()));
EXPECT_EQ(D.id(), phi_graph.FindOptimizedValue(B.id()));
EXPECT_EQ(D.id(), phi_graph.FindOptimizedValue(C.id()));
EXPECT_EQ(D.id(), phi_graph.FindOptimizedValue(E.id()));
}
} // namespace
} // namespace xla