Properly support nest phi reduction in reverse order.
If we replaced node B with node C, then replace node A with node B, we should redirect node A to node C instead. PiperOrigin-RevId: 317010443 Change-Id: I165496a3d1f6571815bfd61d096e26cbba39125a
This commit is contained in:
parent
a50001edf9
commit
ae20f08da9
@ -1007,6 +1007,8 @@ void HloDataflowAnalysis::OptimizePhiValues() {
|
||||
HloValue::Id phi_id = values[0]->id();
|
||||
HloValue::Id new_id = phi_graph_.FindOptimizedValue(phi_id);
|
||||
if (new_id != phi_id) {
|
||||
VLOG(1) << "Replacing " << values[0]->ToString() << " with "
|
||||
<< GetValue(new_id).ToString();
|
||||
value_set->Clear();
|
||||
const HloValue& new_value = GetValue(new_id);
|
||||
value_set->AddValue(&new_value);
|
||||
|
@ -20,10 +20,11 @@ limitations under the License.
|
||||
namespace xla {
|
||||
HloValue::Id PhiGraph::GetOptimizedId(const HloValue& value) {
|
||||
Node* node = value_id_to_node_[value.id()];
|
||||
CHECK(!node->mark_as_dead);
|
||||
return node->value_id;
|
||||
}
|
||||
|
||||
// Returns true if the input to a hlo value is the same as `inputs`.
|
||||
// Returns true if the inputs to a hlo value are the same as `inputs`.
|
||||
bool PhiGraph::InputsEqualTo(const HloValue& value,
|
||||
absl::Span<const HloValue* const> inputs) {
|
||||
auto iter = value_id_to_node_.find(value.id());
|
||||
@ -42,6 +43,7 @@ bool PhiGraph::InputsEqualTo(const HloValue& value,
|
||||
HloValue::Id PhiGraph::FindOptimizedValue(const HloValue::Id id) {
|
||||
auto iter = value_id_to_node_.find(id);
|
||||
CHECK(iter != value_id_to_node_.end());
|
||||
CHECK(!iter->second->mark_as_dead);
|
||||
return iter->second->value_id;
|
||||
}
|
||||
|
||||
@ -66,6 +68,17 @@ PhiGraph::Node* PhiGraph::CreateOrReuseNode(const HloValue& value) {
|
||||
void PhiGraph::ReplaceNodeWith(PhiGraph::Node* node, PhiGraph::Node* replace) {
|
||||
// Update users.
|
||||
CHECK(node->is_phi);
|
||||
if (node->mark_as_dead) {
|
||||
// The node has already been replaced with another.
|
||||
return;
|
||||
}
|
||||
if (replace->mark_as_dead) {
|
||||
// The node we are placing with has already been replaced with another node.
|
||||
auto iter = value_id_to_node_.find(replace->value_id);
|
||||
CHECK(iter != value_id_to_node_.end());
|
||||
return ReplaceNodeWith(node, iter->second);
|
||||
}
|
||||
CHECK(!replace->mark_as_dead);
|
||||
for (Node* user : node->users) {
|
||||
absl::c_replace(user->operands, node, replace);
|
||||
}
|
||||
@ -74,6 +87,7 @@ void PhiGraph::ReplaceNodeWith(PhiGraph::Node* node, PhiGraph::Node* replace) {
|
||||
for (Node* operand : node->operands) {
|
||||
absl::c_replace(operand->users, node, replace);
|
||||
}
|
||||
|
||||
for (HloValue::Id value_id : node_to_value_id_[node]) {
|
||||
CHECK(value_id_to_node_.contains(value_id));
|
||||
value_id_to_node_[value_id] = replace;
|
||||
@ -115,6 +129,8 @@ std::string PhiGraph::ToString() {
|
||||
}
|
||||
|
||||
void PhiGraph::Optimize() {
|
||||
VLOG(2) << "Optimizing phi graph:";
|
||||
XLA_VLOG_LINES(2, ToString());
|
||||
// Set up users for each node.
|
||||
for (auto& node : node_storage_) {
|
||||
for (Node* input : node->operands) {
|
||||
@ -141,6 +157,8 @@ void PhiGraph::Optimize() {
|
||||
|
||||
Node* node_ptr = node.get();
|
||||
|
||||
VLOG(2) << "Optimizing: " << node_ptr->value_id;
|
||||
|
||||
CHECK_GE(node_ptr->operands.size(), 1);
|
||||
|
||||
// Remove self-referencing ids from users and operands.
|
||||
@ -167,6 +185,9 @@ void PhiGraph::Optimize() {
|
||||
[&](Node* elem) { return elem == node_ptr->operands[0]; });
|
||||
|
||||
if (all_inputs_are_same) {
|
||||
VLOG(1) << "All inputs to node " << node_ptr->value_id
|
||||
<< " are the same, replacing it with "
|
||||
<< node_ptr->operands[0]->value_id;
|
||||
ReplaceNodeWith(node_ptr, node_ptr->operands[0]);
|
||||
changed = true;
|
||||
continue;
|
||||
@ -223,6 +244,8 @@ void PhiGraph::Optimize() {
|
||||
CHECK_EQ(node, non_phi);
|
||||
continue;
|
||||
}
|
||||
VLOG(1) << "Replace node " << node->value_id
|
||||
<< " in the closure with node " << non_phi->value_id;
|
||||
ReplaceNodeWith(node, non_phi);
|
||||
changed = true;
|
||||
}
|
||||
|
@ -90,7 +90,7 @@ class PhiGraph {
|
||||
// to that phi.
|
||||
absl::flat_hash_map<Node*, std::vector<HloValue::Id>> node_to_value_id_;
|
||||
|
||||
// A mapping between a HloValue and node in the phi graph.
|
||||
// A mapping from a HloValue to node in the phi graph.
|
||||
absl::flat_hash_map<HloValue::Id, Node*> value_id_to_node_;
|
||||
std::vector<std::unique_ptr<Node>> node_storage_;
|
||||
};
|
||||
|
@ -82,5 +82,30 @@ TEST_F(PhiGraphTest, CircularPhi) {
|
||||
EXPECT_EQ(D.id(), phi_graph.FindOptimizedValue(C.id()));
|
||||
}
|
||||
|
||||
TEST_F(PhiGraphTest, NestedPhiReduction) {
|
||||
// def A = phi(B, C)
|
||||
// def B = phi(C, E)
|
||||
// def C = phi(A, B)
|
||||
// def D = non-phi
|
||||
// def E = Phi(D, D)
|
||||
// 1. Replace E with D
|
||||
// 2. Replace A B and C with E/D
|
||||
PhiGraph phi_graph;
|
||||
HloValue A = NewHloValue(true);
|
||||
HloValue B = NewHloValue(true);
|
||||
HloValue C = NewHloValue(true);
|
||||
HloValue D = NewHloValue(false);
|
||||
HloValue E = NewHloValue(true);
|
||||
phi_graph.RegisterPhi(A, {&B, &C});
|
||||
phi_graph.RegisterPhi(B, {&E, &C});
|
||||
phi_graph.RegisterPhi(C, {&A, &B});
|
||||
phi_graph.RegisterPhi(E, {&D, &D});
|
||||
phi_graph.Optimize();
|
||||
EXPECT_EQ(D.id(), phi_graph.FindOptimizedValue(A.id()));
|
||||
EXPECT_EQ(D.id(), phi_graph.FindOptimizedValue(B.id()));
|
||||
EXPECT_EQ(D.id(), phi_graph.FindOptimizedValue(C.id()));
|
||||
EXPECT_EQ(D.id(), phi_graph.FindOptimizedValue(E.id()));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
Loading…
x
Reference in New Issue
Block a user