Rolling forward after fixing C++ version problems.
Original public description: Add a WhileLoopAllReduceCodeMotion pass, which finds the all-reduce-then-accumulate pattern inside while loops and moves the all-reduce to the outside of the while loop. PiperOrigin-RevId: 349594327 Change-Id: I6f2b6c67d9b89d257585085f97a7c022298d5d1e
This commit is contained in:
parent
64c87d65ff
commit
f8975c3e43
@ -4665,6 +4665,48 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
# HLO pass that sinks accumulated all-reduces out of while loops.
cc_library(
    name = "while_loop_all_reduce_code_motion",
    srcs = ["while_loop_all_reduce_code_motion.cc"],
    hdrs = ["while_loop_all_reduce_code_motion.h"],
    deps = [
        ":call_graph",
        ":hlo",
        ":hlo_casting_utils",
        ":hlo_pass",
        ":hlo_query",
        ":pattern_matcher",
        "//tensorflow/compiler/xla:literal_util",
        # map_util.h is included directly by the .cc file (ContainsKey).
        "//tensorflow/compiler/xla:map_util",
        "//tensorflow/compiler/xla:status",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/core:lib_proto_parsing",
        "//tensorflow/core/platform:status",
        "@com_google_absl//absl/algorithm:container",
        # absl::flat_hash_map appears in the .cc file's function signatures.
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/types:span",
    ],
)
|
||||
|
||||
# Unit tests for the while-loop all-reduce code motion pass.
tf_cc_test(
    name = "while_loop_all_reduce_code_motion_test",
    srcs = ["while_loop_all_reduce_code_motion_test.cc"],
    deps = [
        ":hlo",
        ":hlo_casting_utils",
        ":hlo_matchers",
        ":hlo_verifier",
        ":while_loop_all_reduce_code_motion",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/tests:hlo_test_base",
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
        "//tensorflow/core:test",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/strings",
    ],
)
|
||||
|
||||
cc_library(
|
||||
name = "while_loop_invariant_code_motion",
|
||||
srcs = ["while_loop_invariant_code_motion.cc"],
|
||||
|
@ -0,0 +1,600 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/while_loop_all_reduce_code_motion.h"
|
||||
|
||||
#include <iterator>
|
||||
#include <stack>
|
||||
#include <tuple>
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/map_util.h"
|
||||
#include "tensorflow/compiler/xla/service/call_graph.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_query.h"
|
||||
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
|
||||
#include "tensorflow/compiler/xla/status.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
namespace {
|
||||
|
||||
struct AccumulationContext {
|
||||
HloInstruction* accumulation_instruction;
|
||||
HloInstruction* accumulation_buffer;
|
||||
std::vector<int> param_tuple_indices;
|
||||
};
|
||||
|
||||
// Describes whether an all-reduce instruction can be sinked from a while body
|
||||
// computation and all the accumulation uses of the all-reduce's result in the
|
||||
// while body if movable.
|
||||
struct MovableAllReduceContext {
|
||||
bool is_movable;
|
||||
// If movable, `accumulation_contexts` contains one accumulation
|
||||
// context for each accumulation in the while body that uses the all-reduce's
|
||||
// result. Otherwise, this field is undefined.
|
||||
std::vector<AccumulationContext> accumulation_contexts;
|
||||
};
|
||||
|
||||
// Checks if an all-reduce instruction is eligible for sinking and finds all of
|
||||
// the all-reduce's accumulation uses inside the while body if eligible.
|
||||
MovableAllReduceContext IsAllReduceMovable(HloInstruction* all_reduce,
|
||||
HloComputation* while_body) {
|
||||
auto all_reduce_is_summation = [](HloInstruction* all_reduce) -> bool {
|
||||
HloInstruction* to_apply_root = all_reduce->to_apply()->root_instruction();
|
||||
if (all_reduce->to_apply()->num_parameters() != 2) {
|
||||
return false;
|
||||
}
|
||||
return Match(to_apply_root,
|
||||
match::AddAnyOrder(match::Parameter(0), match::Parameter(1)));
|
||||
};
|
||||
|
||||
// We only support numerical types.
|
||||
const absl::InlinedVector<PrimitiveType, 12> kSupportedTypes{
|
||||
BF16, F16, F32, F64, S8, S16, S32, S64, U8, U16, U32, U64};
|
||||
|
||||
if (!absl::c_linear_search(kSupportedTypes,
|
||||
all_reduce->shape().element_type()) ||
|
||||
!all_reduce_is_summation(all_reduce)) {
|
||||
return MovableAllReduceContext{/*is_movable=*/false,
|
||||
/*accumulation_contexts=*/{}};
|
||||
}
|
||||
|
||||
struct BufferTupleIndex {
|
||||
bool unsupported_operation{false};
|
||||
std::vector<int> tuple_index;
|
||||
bool returned_from_computation{false};
|
||||
};
|
||||
|
||||
// If the instruction is a buffer forwarded from a tuple element of the
|
||||
// computation's parameter, returns the indices of the buffer in the parameter
|
||||
// tuple. The returned_from_computation field in the result is unused.
|
||||
auto get_origin_tuple_index =
|
||||
[](HloInstruction* instruction) -> BufferTupleIndex {
|
||||
// The returned_from_computation is never touched in this function.
|
||||
BufferTupleIndex result;
|
||||
while (!result.unsupported_operation) {
|
||||
switch (instruction->opcode()) {
|
||||
default:
|
||||
result.unsupported_operation = true;
|
||||
break;
|
||||
case HloOpcode::kBitcast:
|
||||
case HloOpcode::kConvert:
|
||||
case HloOpcode::kReshape:
|
||||
case HloOpcode::kTranspose:
|
||||
case HloOpcode::kDynamicReshape:
|
||||
instruction = instruction->mutable_operand(0);
|
||||
break;
|
||||
case HloOpcode::kGetTupleElement: {
|
||||
if (!result.tuple_index.empty()) {
|
||||
// Note that we don't support nested tuples as of now.
|
||||
result.unsupported_operation = true;
|
||||
} else {
|
||||
result.tuple_index.push_back(
|
||||
Cast<HloGetTupleElementInstruction>(instruction)
|
||||
->tuple_index());
|
||||
instruction = instruction->mutable_operand(0);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kParameter: {
|
||||
int parameter_number =
|
||||
Cast<HloParameterInstruction>(instruction)->parameter_number();
|
||||
CHECK_EQ(parameter_number, 0);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (instruction->opcode() == HloOpcode::kParameter) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return result;
|
||||
};
|
||||
|
||||
// If the instruction's result is returned from its parent computation with
|
||||
// only forwarding operations, returns the index of the result buffer in the
|
||||
// output parameter tuple.
|
||||
auto get_output_tuple_index =
|
||||
[](HloInstruction* instruction,
|
||||
HloComputation* while_body) -> BufferTupleIndex {
|
||||
BufferTupleIndex result;
|
||||
std::stack<HloInstruction*> to_visit;
|
||||
to_visit.push(instruction);
|
||||
while (!to_visit.empty() && !result.unsupported_operation) {
|
||||
HloInstruction* instruction = to_visit.top();
|
||||
to_visit.pop();
|
||||
for (HloInstruction* user : instruction->users()) {
|
||||
switch (user->opcode()) {
|
||||
case HloOpcode::kBitcast:
|
||||
case HloOpcode::kConvert:
|
||||
case HloOpcode::kReshape:
|
||||
case HloOpcode::kGetTupleElement:
|
||||
case HloOpcode::kTranspose:
|
||||
case HloOpcode::kSlice: {
|
||||
to_visit.push(user);
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kDynamicSlice: {
|
||||
if (user->operand_index(instruction) == 0) {
|
||||
to_visit.push(user);
|
||||
} else {
|
||||
result.unsupported_operation = true;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kTuple: {
|
||||
if (!result.tuple_index.empty()) {
|
||||
// Note that we don't support nested tuples as of now.
|
||||
result.unsupported_operation = true;
|
||||
} else {
|
||||
result.tuple_index.push_back(user->operand_index(instruction));
|
||||
if (while_body->root_instruction() == user) {
|
||||
if (result.returned_from_computation) {
|
||||
result.unsupported_operation = true;
|
||||
}
|
||||
result.returned_from_computation = true;
|
||||
} else {
|
||||
to_visit.push(user);
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
result.unsupported_operation = true;
|
||||
}
|
||||
if (result.unsupported_operation) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
return result;
|
||||
};
|
||||
|
||||
// Checks whether any buffer in the list of accumulation contexts is used in
|
||||
// the parent computation except for forwarding uses.
|
||||
auto is_buffer_used =
|
||||
[](absl::Span<const AccumulationContext> accumulation_contexts,
|
||||
HloComputation* while_body_computation) -> bool {
|
||||
std::vector<HloInstruction*> parameter_instructions;
|
||||
absl::c_copy_if(while_body_computation->instructions(),
|
||||
std::back_inserter(parameter_instructions),
|
||||
[](HloInstruction* instruction) -> bool {
|
||||
return instruction->opcode() == HloOpcode::kParameter;
|
||||
});
|
||||
for (const auto& accumulation : accumulation_contexts) {
|
||||
HloInstruction* accumulation_instruction =
|
||||
accumulation.accumulation_instruction;
|
||||
int tuple_index = accumulation.param_tuple_indices[0];
|
||||
std::stack<HloInstruction*> to_visit;
|
||||
// TODO(b/176437845): simplify the logic below by using
|
||||
// TuplePointsToAnalysis.
|
||||
for (HloInstruction* parameter_instruction : parameter_instructions) {
|
||||
// Iterate over all users of the while body parameter and find all
|
||||
// instructions that use the accumulation buffer, as specified by
|
||||
// tuple_index.
|
||||
// This logic could be simplied by using TuplePointsToAnalysis, which
|
||||
// we leave to a future CL (see TODO above).
|
||||
for (HloInstruction* user : parameter_instruction->users()) {
|
||||
if (auto* gte = DynCast<HloGetTupleElementInstruction>(user)) {
|
||||
if (gte->tuple_index() == tuple_index) {
|
||||
to_visit.push(user);
|
||||
}
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
while (!to_visit.empty()) {
|
||||
HloInstruction* instruction = to_visit.top();
|
||||
to_visit.pop();
|
||||
for (HloInstruction* user : instruction->users()) {
|
||||
switch (user->opcode()) {
|
||||
case HloOpcode::kBitcast:
|
||||
case HloOpcode::kConvert:
|
||||
case HloOpcode::kReshape:
|
||||
case HloOpcode::kTranspose:
|
||||
to_visit.push(user);
|
||||
break;
|
||||
case HloOpcode::kDynamicReshape: {
|
||||
if (instruction == user->operand(0)) {
|
||||
to_visit.push(user);
|
||||
} else {
|
||||
return true;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kAdd: {
|
||||
if (user != accumulation_instruction) {
|
||||
return true;
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
// Finds all accumulation contexts of the given all-reduce instruction if it
|
||||
// is movable.
|
||||
auto get_accumulation_contexts =
|
||||
[&get_origin_tuple_index, &get_output_tuple_index, &is_buffer_used](
|
||||
HloInstruction* all_reduce,
|
||||
HloComputation* while_body) -> MovableAllReduceContext {
|
||||
std::vector<AccumulationContext> accumulation_contexts;
|
||||
// DFS starting from the all-reduce instruction and stops at the first
|
||||
// non-triival uses of the all-reduce result or finds all accmululations
|
||||
// of the all-reduce result.
|
||||
std::stack<HloInstruction*> to_visit;
|
||||
// By default movable unless we find that it's not.
|
||||
bool is_all_reduce_movable = true;
|
||||
to_visit.push(all_reduce);
|
||||
|
||||
while (!to_visit.empty() && is_all_reduce_movable) {
|
||||
HloInstruction* instruction = to_visit.top();
|
||||
to_visit.pop();
|
||||
for (HloInstruction* user : instruction->users()) {
|
||||
switch (user->opcode()) {
|
||||
case HloOpcode::kBitcast:
|
||||
case HloOpcode::kConvert:
|
||||
case HloOpcode::kReshape:
|
||||
case HloOpcode::kGetTupleElement:
|
||||
case HloOpcode::kTranspose:
|
||||
case HloOpcode::kSlice: {
|
||||
to_visit.push(user);
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kDynamicSlice: {
|
||||
if (user->operand_index(instruction) == 0) {
|
||||
to_visit.push(user);
|
||||
} else {
|
||||
is_all_reduce_movable = false;
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kAdd: {
|
||||
int64 buffer_index = 1 - user->operand_index(instruction);
|
||||
HloInstruction* accumulation_buffer =
|
||||
user->mutable_operand(buffer_index);
|
||||
|
||||
auto origin_buffer_tuple_index =
|
||||
get_origin_tuple_index(accumulation_buffer);
|
||||
if (origin_buffer_tuple_index.unsupported_operation) {
|
||||
is_all_reduce_movable = false;
|
||||
break;
|
||||
}
|
||||
|
||||
auto output_buffer_tuple_index =
|
||||
get_output_tuple_index(user, while_body);
|
||||
if (!output_buffer_tuple_index.unsupported_operation &&
|
||||
output_buffer_tuple_index.returned_from_computation &&
|
||||
!origin_buffer_tuple_index.tuple_index.empty() &&
|
||||
ContainersEqual(origin_buffer_tuple_index.tuple_index,
|
||||
output_buffer_tuple_index.tuple_index)) {
|
||||
accumulation_contexts.push_back(AccumulationContext{
|
||||
user, accumulation_buffer,
|
||||
std::move(output_buffer_tuple_index.tuple_index)});
|
||||
} else {
|
||||
is_all_reduce_movable = false;
|
||||
}
|
||||
break;
|
||||
}
|
||||
default:
|
||||
is_all_reduce_movable = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (is_buffer_used(accumulation_contexts, while_body)) {
|
||||
is_all_reduce_movable = false;
|
||||
}
|
||||
return MovableAllReduceContext{is_all_reduce_movable,
|
||||
accumulation_contexts};
|
||||
};
|
||||
return get_accumulation_contexts(all_reduce, while_body);
|
||||
}
|
||||
|
||||
struct WhileInitContext {
|
||||
HloInstruction* while_init{nullptr};
|
||||
absl::flat_hash_map<int, HloInstruction*> tuple_index_to_old_buffer;
|
||||
};
|
||||
|
||||
// Creates a new while init instruction, which replaces each accumulation buffer
|
||||
// in the given accumulation contexts with a zero-initialized buffer. In other
|
||||
// words, we are accumulating all the deltas in the while loop with a zero
|
||||
// initial value.
|
||||
WhileInitContext CreateNewWhileInit(
|
||||
HloInstruction* old_while_instruction,
|
||||
const HloInstructionMap<std::vector<AccumulationContext>>&
|
||||
all_reduce_to_accumulations) {
|
||||
HloInstruction* old_while_init = old_while_instruction->mutable_operand(0);
|
||||
HloComputation* while_parent = old_while_instruction->parent();
|
||||
std::vector<HloInstruction*> new_while_init_elements(
|
||||
old_while_init->operand_count(), nullptr);
|
||||
for (const auto& all_reduce_and_accumulations_pair :
|
||||
all_reduce_to_accumulations) {
|
||||
const std::vector<AccumulationContext>& accumulations =
|
||||
all_reduce_and_accumulations_pair.second;
|
||||
for (auto& accumulation_context : accumulations) {
|
||||
CHECK_EQ(accumulation_context.param_tuple_indices.size(), 1);
|
||||
int tuple_index = accumulation_context.param_tuple_indices[0];
|
||||
HloInstruction* old_buffer = old_while_init->mutable_operand(tuple_index);
|
||||
HloInstruction* new_buffer = while_parent->AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateFromDimensions(
|
||||
old_buffer->shape().element_type(),
|
||||
old_buffer->shape().dimensions())));
|
||||
new_while_init_elements[tuple_index] = new_buffer;
|
||||
}
|
||||
}
|
||||
absl::flat_hash_map<int, HloInstruction*> tuple_index_to_old_buffer;
|
||||
for (int i = 0; i < old_while_init->operand_count(); i++) {
|
||||
if (!new_while_init_elements[i]) {
|
||||
new_while_init_elements[i] = old_while_init->mutable_operand(i);
|
||||
} else {
|
||||
tuple_index_to_old_buffer[i] = old_while_init->mutable_operand(i);
|
||||
}
|
||||
}
|
||||
HloInstruction* new_while_init = while_parent->AddInstruction(
|
||||
HloInstruction::CreateTuple(new_while_init_elements));
|
||||
return WhileInitContext{new_while_init, tuple_index_to_old_buffer};
|
||||
}
|
||||
|
||||
// Creates all the sinked all-reduce instructions in the while instruction's
|
||||
// parent computation. Returns a map that maps a tuple index of an accumulation
|
||||
// buffer to it's corresponding all-reduce.
|
||||
absl::flat_hash_map<int, HloInstruction*> CreateSinkedAllReduces(
|
||||
HloInstruction* new_while_instruction,
|
||||
const HloInstructionMap<std::vector<AccumulationContext>>&
|
||||
all_reduce_to_accumulations,
|
||||
const absl::flat_hash_map<int, HloInstruction*>&
|
||||
tuple_index_to_old_buffer) {
|
||||
HloComputation* while_parent = new_while_instruction->parent();
|
||||
absl::flat_hash_map<int, HloInstruction*> tuple_index_to_new_buffer;
|
||||
for (const auto& all_reduce_and_accumulations_pair :
|
||||
all_reduce_to_accumulations) {
|
||||
HloInstruction* loop_all_reduce = all_reduce_and_accumulations_pair.first;
|
||||
const std::vector<AccumulationContext>& accumulations =
|
||||
all_reduce_and_accumulations_pair.second;
|
||||
for (const auto& accumulation_context : accumulations) {
|
||||
CHECK_EQ(accumulation_context.param_tuple_indices.size(), 1);
|
||||
int tuple_index = accumulation_context.param_tuple_indices[0];
|
||||
const Shape& accumulation_buffer_shape =
|
||||
new_while_instruction->shape().tuple_shapes(tuple_index);
|
||||
HloInstruction* accumulation_buffer =
|
||||
while_parent->AddInstruction(HloInstruction::CreateGetTupleElement(
|
||||
accumulation_buffer_shape, new_while_instruction, tuple_index));
|
||||
HloAllReduceInstruction* old_all_reduce =
|
||||
Cast<HloAllReduceInstruction>(loop_all_reduce);
|
||||
HloInstruction* all_reduce_operand = accumulation_buffer;
|
||||
if (!ShapeUtil::SameElementType(old_all_reduce->shape(),
|
||||
accumulation_buffer_shape)) {
|
||||
Shape all_reduce_shape =
|
||||
ShapeUtil::MakeShape(old_all_reduce->shape().element_type(),
|
||||
accumulation_buffer_shape.dimensions());
|
||||
all_reduce_operand =
|
||||
while_parent->AddInstruction(HloInstruction::CreateConvert(
|
||||
all_reduce_shape, accumulation_buffer));
|
||||
}
|
||||
HloInstruction* new_all_reduce =
|
||||
while_parent->AddInstruction(HloInstruction::CreateAllReduce(
|
||||
all_reduce_operand->shape(), {all_reduce_operand},
|
||||
old_all_reduce->called_computations()[0],
|
||||
old_all_reduce->replica_groups(),
|
||||
old_all_reduce->constrain_layout(),
|
||||
hlo_query::NextChannelId(*(while_parent->parent())),
|
||||
old_all_reduce->use_global_device_ids()));
|
||||
HloInstruction* all_reduced_delta = new_all_reduce;
|
||||
if (!ShapeUtil::SameElementType(all_reduced_delta->shape(),
|
||||
accumulation_buffer_shape)) {
|
||||
all_reduced_delta =
|
||||
while_parent->AddInstruction(HloInstruction::CreateConvert(
|
||||
accumulation_buffer_shape, all_reduced_delta));
|
||||
}
|
||||
CHECK(ContainsKey(tuple_index_to_old_buffer, tuple_index));
|
||||
HloInstruction* old_buffer = tuple_index_to_old_buffer.at(tuple_index);
|
||||
CHECK(ShapeUtil::Equal(old_buffer->shape(), all_reduced_delta->shape()));
|
||||
HloInstruction* add_to_old_buffer =
|
||||
while_parent->AddInstruction(HloInstruction::CreateBinary(
|
||||
all_reduced_delta->shape(), HloOpcode::kAdd, old_buffer,
|
||||
all_reduced_delta));
|
||||
tuple_index_to_new_buffer[tuple_index] = add_to_old_buffer;
|
||||
}
|
||||
}
|
||||
return tuple_index_to_new_buffer;
|
||||
}
|
||||
|
||||
// Creates a tuple which is equivalent to the original while instruction's
|
||||
// output.
|
||||
HloInstruction* CreateNewWhileResult(
|
||||
HloInstruction* new_while_instruction,
|
||||
const absl::flat_hash_map<int, HloInstruction*>&
|
||||
tuple_index_to_new_buffer) {
|
||||
HloComputation* while_parent = new_while_instruction->parent();
|
||||
CHECK(new_while_instruction->shape().IsTuple());
|
||||
std::vector<HloInstruction*> new_while_result_elements(
|
||||
new_while_instruction->shape().tuple_shapes_size(), nullptr);
|
||||
for (int i = 0; i < new_while_result_elements.size(); i++) {
|
||||
if (ContainsKey(tuple_index_to_new_buffer, i)) {
|
||||
new_while_result_elements[i] = tuple_index_to_new_buffer.at(i);
|
||||
} else {
|
||||
HloInstruction* gte =
|
||||
while_parent->AddInstruction(HloInstruction::CreateGetTupleElement(
|
||||
new_while_instruction->shape().tuple_shapes(i),
|
||||
new_while_instruction, i));
|
||||
new_while_result_elements[i] = gte;
|
||||
}
|
||||
}
|
||||
HloInstruction* new_while_result = while_parent->AddInstruction(
|
||||
HloInstruction::CreateTuple(new_while_result_elements));
|
||||
return new_while_result;
|
||||
}
|
||||
|
||||
// Creates the sinked all-reduce instructions for all accumulation buffers. The
|
||||
// all-reduce outputs are then added to the original accumulation buffers.
|
||||
// Creates a tuple that groups the while loop output and the accumulated
|
||||
// buffers and replaces all uses of the old while with this new tuple.
|
||||
Status AddSinkedAllReducesAndReplaceWhile(
|
||||
HloInstruction* while_instruction,
|
||||
const HloInstructionMap<std::vector<AccumulationContext>>&
|
||||
all_reduce_to_accumulations) {
|
||||
// Note that we create all instructions before replacing and removing any old
|
||||
// instruction. This ensures that we do not accidentally access any deleted
|
||||
// instruction when creating new instructions.
|
||||
|
||||
// Step 1) create the new while init instruction, which uses zero-initialized
|
||||
// tensors as the accumulation buffers for the all-reduce.
|
||||
auto new_while_init_context =
|
||||
CreateNewWhileInit(while_instruction, all_reduce_to_accumulations);
|
||||
// Step 2) create the new while instruction.
|
||||
HloInstruction* new_while_instruction =
|
||||
while_instruction->parent()->AddInstruction(HloInstruction::CreateWhile(
|
||||
new_while_init_context.while_init->shape(),
|
||||
while_instruction->while_condition(), while_instruction->while_body(),
|
||||
new_while_init_context.while_init));
|
||||
// Step 3) create the new all-reduce instructions after the while loop.
|
||||
absl::flat_hash_map<int, HloInstruction*> tuple_index_to_new_buffer =
|
||||
CreateSinkedAllReduces(new_while_instruction, all_reduce_to_accumulations,
|
||||
new_while_init_context.tuple_index_to_old_buffer);
|
||||
// Step 4) create the tuple and replace the old while instruction for all of
|
||||
// its uses.
|
||||
HloInstruction* new_while_result =
|
||||
CreateNewWhileResult(new_while_instruction, tuple_index_to_new_buffer);
|
||||
TF_RETURN_IF_ERROR(while_instruction->parent()->ReplaceInstruction(
|
||||
while_instruction, new_while_result));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Runs the pass over `module`.
//
// For every computation used as the body of a tuple-shaped while instruction,
// finds the all-reduces whose results are only accumulated into loop-carried
// buffers, re-creates them after each calling while instruction, and replaces
// them in the body with their operand. The process repeats until an iteration
// makes no change, so all-reduces can sink through nested loops.
//
// Returns true iff the module was modified, or an error status if replacing
// an instruction fails.
StatusOr<bool> WhileLoopAllReduceCodeMotion::Run(HloModule* module) {
  bool is_changed = false;
  bool run_next_pass = true;
  // In case of MPMD, all-reduces might be cross-module and should preserve
  // their channel ID. Do not move all-reduces in this case since the channel
  // ID might be changed.
  if (module->config().num_partitions() > 1 &&
      !module->config().use_spmd_partitioning()) {
    return false;
  }
  // The while instruction's parent could be a while body for another while
  // loop. We recursively sink the all-reduce through nested while loops if
  // applicable by repeating this process.
  while (run_next_pass) {
    run_next_pass = false;
    std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
    // A computation could be the while body of multiple while instructions,
    // so we start from the computation and find all of its callers that are
    // kWhile instructions, if there are any.
    for (HloComputation* computation : module->computations()) {
      std::vector<HloInstruction*> computation_callers =
          call_graph->GetComputationCallers(computation);
      std::vector<HloInstruction*> while_caller_instructions;
      for (HloInstruction* caller_instruction : computation_callers) {
        // For simplicity, we only support while instructions whose shape is
        // tuple.
        if (caller_instruction->opcode() == HloOpcode::kWhile &&
            caller_instruction->shape().IsTuple() &&
            caller_instruction->while_body() == computation) {
          while_caller_instructions.push_back(caller_instruction);
        }
      }
      // Skip to next computation if this computation is not the while body of
      // any while instruction.
      if (while_caller_instructions.empty()) {
        continue;
      }
      std::vector<HloInstruction*> while_body_all_reduces;
      for (HloInstruction* while_body_instruction :
           computation->MakeInstructionPostOrder()) {
        if (auto* all_reduce_instruction =
                DynCast<HloAllReduceInstruction>(while_body_instruction)) {
          if (all_reduce_instruction->constrain_layout()) {
            // Stop the pass here, but still report any rewrite that an
            // earlier computation or iteration already performed.
            return is_changed;
          } else {
            while_body_all_reduces.push_back(all_reduce_instruction);
          }
        }
      }
      HloInstructionMap<std::vector<AccumulationContext>>
          all_reduce_to_accumulations;
      for (HloInstruction* all_reduce : while_body_all_reduces) {
        auto movable_all_reduce_context =
            IsAllReduceMovable(all_reduce, computation);
        if (movable_all_reduce_context.is_movable) {
          all_reduce_to_accumulations[all_reduce] =
              std::move(movable_all_reduce_context.accumulation_contexts);
        }
      }
      if (all_reduce_to_accumulations.empty()) {
        continue;
      }
      // For each while instruction calling this computation, create the
      // corresponding all-reduces after the while loop.
      for (HloInstruction* while_instruction : while_caller_instructions) {
        TF_RETURN_IF_ERROR(AddSinkedAllReducesAndReplaceWhile(
            while_instruction, all_reduce_to_accumulations));
        is_changed = true;
        run_next_pass = true;
      }
      // At last, remove the old all-reduce instructions in the while body by
      // replacing each one with its operand.
      for (const auto& all_reduce_accumulations_pair :
           all_reduce_to_accumulations) {
        HloInstruction* all_reduce = all_reduce_accumulations_pair.first;
        TF_RETURN_IF_ERROR(computation->ReplaceInstruction(
            all_reduce, all_reduce->mutable_operand(0)));
      }
    }
  }
  return is_changed;
}
|
||||
|
||||
} // namespace xla
|
@ -0,0 +1,59 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ALL_REDUCE_CODE_MOTION_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ALL_REDUCE_CODE_MOTION_H_
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// HLO pass that rewrites while loops to sink all-reduces that are only
|
||||
// accumulated into a buffer and not otherwise used in the loop body.
|
||||
// An all-reduce instruction can be sinked if its result is only added
|
||||
// to a number of accumulation buffers, and the accumulation buffers are not
|
||||
// used inside the loop.
|
||||
//
|
||||
// Pattern before this pass:
|
||||
// a = ...
|
||||
// while:
|
||||
// b = ...
|
||||
// c = all-reduce(b)
|
||||
// a += c
|
||||
// Pattern after this pass:
|
||||
// a = ...
|
||||
// d = 0
|
||||
// while:
|
||||
// b = ...
|
||||
// d += b
|
||||
// e = all-reduce(d)
|
||||
// a += e
|
||||
class WhileLoopAllReduceCodeMotion : public HloModulePass {
|
||||
public:
|
||||
explicit WhileLoopAllReduceCodeMotion() {}
|
||||
~WhileLoopAllReduceCodeMotion() override = default;
|
||||
|
||||
absl::string_view name() const override {
|
||||
static constexpr absl::string_view kName =
|
||||
"while-loop-all-reduce-code-motion";
|
||||
return kName;
|
||||
}
|
||||
StatusOr<bool> Run(HloModule* module) override;
|
||||
};
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ALL_REDUCE_CODE_MOTION_H_
|
@ -0,0 +1,658 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/while_loop_all_reduce_code_motion.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <iterator>
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
namespace op = ::xla::testing::opcode_matchers;
|
||||
using ::testing::Ne;
|
||||
using ::testing::NotNull;
|
||||
using ::testing::Property;
|
||||
using ::testing::SizeIs;
|
||||
|
||||
// Test fixture for the WhileLoopAllReduceCodeMotion pass. It adds no state or
// helpers of its own; the tests below only use what HloTestBase provides.
class WhileLoopAllReduceCodeMotionTest : public HloTestBase {};
|
||||
|
||||
// Checks the basic pattern: an all-reduce inside the while body whose result
// is only added into a loop-carried accumulation buffer. After the pass runs,
// the body must contain no all-reduce, and an all-reduce with the same
// attributes must exist in the entry computation, fed by tuple element 3.
TEST_F(WhileLoopAllReduceCodeMotionTest, AllReduceAccumulate) {
  constexpr absl::string_view kHloModule = R"(
HloModule accumulated_all_reduce

%reduction {
%x = f32[] parameter(0)
%y = f32[] parameter(1)
ROOT %add = f32[] add(f32[] %x, f32[] %y)
}

%while_condition {
%param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
%gte.0 = s32[] get-tuple-element(%param), index=0
%gte.1 = s32[] get-tuple-element(%param), index=1
ROOT result = pred[] compare(%gte.0, %gte.1), direction=LT
}

%while_body {
%param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
%gte.0 = s32[] get-tuple-element(%param), index=0
%gte.1 = s32[] get-tuple-element(%param), index=1
%gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2
%gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3
%all-reduce = f32[1024, 1024] all-reduce(f32[1024, 1024] %gte.2), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%reduction
%accumulation = f32[1024, 1024] add(f32[1024, 1024] %all-reduce, f32[1024, 1024] %gte.3)
%constant = s32[] constant(1)
%increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant)
ROOT %loop_result = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation)
}

ENTRY accumulated_all_reduce {
%param.0 = s32[] parameter(0)
%param.1 = f32[1024, 1024] parameter(1)
%constant.0 = s32[] constant(1)
%accumulation_buffer_init = f32[] constant(0)
%accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
%while_init = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, s32[] %param.0, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer)
ROOT %while = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kHloModule));
  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
                          WhileLoopAllReduceCodeMotion{}.Run(module.get()));
  ASSERT_TRUE(simplified_loop);
  TF_ASSERT_OK(
      HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/true)
          .Run(module.get())
          .status());

  // Returns the first entry-computation instruction with the given opcode, or
  // nullptr when there is none. Unlike dereferencing the result of
  // std::find_if, this never dereferences an end iterator, so the NotNull()
  // assertions below actually detect a missing instruction.
  auto find_first_in_entry = [&module](HloOpcode opcode) -> HloInstruction* {
    for (HloInstruction* instruction :
         module->entry_computation()->instructions()) {
      if (instruction->opcode() == opcode) {
        return instruction;
      }
    }
    return nullptr;
  };

  HloInstruction* transformed_while = find_first_in_entry(HloOpcode::kWhile);
  ASSERT_THAT(transformed_while, NotNull());
  EXPECT_THAT(transformed_while->while_body()->instructions(),
              Each(Not(op::AllReduce())));
  HloInstruction* accumulation_buffer =
      transformed_while->mutable_operand(0)->mutable_operand(3);
  EXPECT_THAT(accumulation_buffer, op::Constant());

  HloInstruction* entry_all_reduce =
      find_first_in_entry(HloOpcode::kAllReduce);
  ASSERT_THAT(entry_all_reduce, NotNull());
  HloAllReduceInstruction* moved_all_reduce =
      DynCast<HloAllReduceInstruction>(entry_all_reduce);
  ASSERT_THAT(moved_all_reduce, NotNull());

  // Fatal assertion: the DynCast result below is dereferenced, which is only
  // valid when the operand really is a get-tuple-element.
  ASSERT_THAT(moved_all_reduce->operand(0), op::GetTupleElement());
  EXPECT_EQ(DynCast<HloGetTupleElementInstruction>(
                moved_all_reduce->mutable_operand(0))
                ->tuple_index(),
            3);

  // Fatal assertion: replica_groups()[0] is indexed right after.
  ASSERT_THAT(moved_all_reduce->replica_groups(), SizeIs(1));
  // ElementsAre checks both the length and the values; the three-iterator
  // std::equal used previously did not compare lengths.
  EXPECT_THAT(moved_all_reduce->replica_groups()[0].replica_ids(),
              ::testing::ElementsAre(0, 1, 2, 3));
  EXPECT_FALSE(moved_all_reduce->constrain_layout());
  EXPECT_TRUE(moved_all_reduce->use_global_device_ids());

  HloComputation* reduction_computation =
      module->GetComputationWithName("reduction");
  ASSERT_THAT(reduction_computation, NotNull());
  EXPECT_EQ(moved_all_reduce->called_computations()[0], reduction_computation);
}
|
||||
|
||||
// Checks the pattern where one all-reduce result is sliced into three pieces
// and each piece is accumulated into its own loop-carried buffer. After the
// pass runs, the body must contain no all-reduce and the entry computation
// must contain three all-reduces with pairwise distinct channel IDs.
TEST_F(WhileLoopAllReduceCodeMotionTest, AllReduceSliceAccumulate) {
  constexpr absl::string_view kHloModule = R"(
HloModule accumulated_all_reduce

%reduction {
%x = f32[] parameter(0)
%y = f32[] parameter(1)
ROOT %add = f32[] add(f32[] %x, f32[] %y)
}

%while_condition {
%param = (s32[], s32[], f32[3, 1024, 1024], f32[1024, 1024], f32[1024, 1024], f32[1024, 1024]) parameter(0)
%gte.0 = s32[] get-tuple-element(%param), index=0
%gte.1 = s32[] get-tuple-element(%param), index=1
ROOT result = pred[] compare(%gte.0, %gte.1), direction=LT
}

%while_body {
%param = (s32[], s32[], f32[3, 1024, 1024], f32[1024, 1024], f32[1024, 1024], f32[1024, 1024]) parameter(0)
%gte.0 = s32[] get-tuple-element(%param), index=0
%gte.1 = s32[] get-tuple-element(%param), index=1
%gte.2 = f32[3, 1024, 1024] get-tuple-element(%param), index=2
%gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3
%gte.4 = f32[1024, 1024] get-tuple-element(%param), index=4
%gte.5 = f32[1024, 1024] get-tuple-element(%param), index=5
%all-reduce = f32[3, 1024, 1024] all-reduce(f32[3, 1024, 1024] %gte.2), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%reduction
%slice.0 = f32[1, 1024, 1024] slice(f32[3, 1024, 1024] %all-reduce), slice={[0:1], [0:1024], [0:1024]}
%reshape.0 = f32[1024, 1024] reshape(f32[1, 1024, 1024] %slice.0)
%slice.1 = f32[1, 1024, 1024] slice(f32[3, 1024, 1024] %all-reduce), slice={[1:2], [0:1024], [0:1024]}
%reshape.1 = f32[1024, 1024] reshape(f32[1, 1024, 1024] %slice.1)
%slice.2 = f32[1, 1024, 1024] slice(f32[3, 1024, 1024] %all-reduce), slice={[2:3], [0:1024], [0:1024]}
%reshape.2 = f32[1024, 1024] reshape(f32[1, 1024, 1024] %slice.2)
%accumulation.0 = f32[1024, 1024] add(f32[1024, 1024] %reshape.0, f32[1024, 1024] %gte.3)
%accumulation.1 = f32[1024, 1024] add(f32[1024, 1024] %reshape.1, f32[1024, 1024] %gte.4)
%accumulation.2 = f32[1024, 1024] add(f32[1024, 1024] %reshape.2, f32[1024, 1024] %gte.5)
%constant = s32[] constant(1)
%increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant)
ROOT %loop_result = (s32[], s32[], f32[3, 1024, 1024], f32[1024, 1024], f32[1024, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation.0, %accumulation.1, %accumulation.2)
}

ENTRY accumulated_all_reduce {
%param.0 = s32[] parameter(0)
%param.1 = f32[3, 1024, 1024] parameter(1)
%constant.0 = s32[] constant(1)
%accumulation_buffer_init = f32[] constant(0)
%accumulation_buffer.0 = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
%accumulation_buffer.1 = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
%accumulation_buffer.2 = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
%while_init = (s32[], s32[], f32[3, 1024, 1024], f32[1024, 1024], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, s32[] %param.0, f32[3, 1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer.0, f32[1024, 1024] %accumulation_buffer.1, f32[1024, 1024] %accumulation_buffer.2)
ROOT %while = (s32[], s32[], f32[3, 1024, 1024], f32[1024, 1024], f32[1024, 1024], f32[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kHloModule));
  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
                          WhileLoopAllReduceCodeMotion{}.Run(module.get()));
  ASSERT_TRUE(simplified_loop);
  TF_ASSERT_OK(
      HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/true)
          .Run(module.get())
          .status());

  // Locate the while with a plain loop so a missing instruction yields
  // nullptr instead of a dereferenced end iterator.
  HloInstruction* transformed_while = nullptr;
  for (HloInstruction* instruction :
       module->entry_computation()->instructions()) {
    if (instruction->opcode() == HloOpcode::kWhile) {
      transformed_while = instruction;
      break;
    }
  }
  ASSERT_THAT(transformed_while, NotNull());
  EXPECT_THAT(transformed_while->while_body()->instructions(),
              Each(Not(op::AllReduce())));

  std::vector<HloInstruction*> hoisted_all_reduces;
  absl::c_copy_if(module->entry_computation()->instructions(),
                  std::back_inserter(hoisted_all_reduces),
                  [](HloInstruction* instruction) {
                    return Value(instruction, op::AllReduce());
                  });
  // Fatal assertions: the channel IDs are unwrapped unconditionally below.
  ASSERT_THAT(hoisted_all_reduces, SizeIs(3));
  ASSERT_THAT(
      hoisted_all_reduces,
      Each(Pointee(Property(&HloInstruction::channel_id, Ne(absl::nullopt)))));

  // Check if added all-reduces have distinct channel IDs. The IDs are kept at
  // 64-bit width so that two different IDs cannot collapse into one value.
  absl::flat_hash_set<int64_t> unique_channel_ids;
  for (const HloInstruction* all_reduce : hoisted_all_reduces) {
    unique_channel_ids.insert(all_reduce->channel_id().value());
  }
  EXPECT_THAT(unique_channel_ids, SizeIs(3));
}
|
||||
|
||||
// Checks that users of the while result are rewired after the all-reduce is
// moved: the entry root multiply must read a get-tuple-element of a new tuple
// whose element 3 is an add.
TEST_F(WhileLoopAllReduceCodeMotionTest, AllReduceAccumulateUse) {
  constexpr absl::string_view kHloModule = R"(
HloModule accumulated_all_reduce

%reduction {
%x = f32[] parameter(0)
%y = f32[] parameter(1)
ROOT %add = f32[] add(f32[] %x, f32[] %y)
}

%while_condition {
%param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
%gte.0 = s32[] get-tuple-element(%param), index=0
%gte.1 = s32[] get-tuple-element(%param), index=1
ROOT result = pred[] compare(%gte.0, %gte.1), direction=LT
}

%while_body {
%param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
%gte.0 = s32[] get-tuple-element(%param), index=0
%gte.1 = s32[] get-tuple-element(%param), index=1
%gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2
%gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3
%all-reduce = f32[1024, 1024] all-reduce(f32[1024, 1024] %gte.2), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%reduction
%accumulation = f32[1024, 1024] add(f32[1024, 1024] %all-reduce, f32[1024, 1024] %gte.3)
%constant = s32[] constant(1)
%increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant)
ROOT %loop_result = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation)
}

ENTRY accumulated_all_reduce {
%param.0 = s32[] parameter(0)
%param.1 = f32[1024, 1024] parameter(1)
%constant.0 = s32[] constant(1)
%accumulation_buffer_init = f32[] constant(0)
%accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
%while_init = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, s32[] %param.0, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer)
%while = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body
%gte_while = f32[1024, 1024] get-tuple-element((s32[], s32[], f32[1024, 1024], f32[1024, 1024]) %while), index=3
ROOT %multiply = f32[1024, 1024] multiply(f32[1024, 1024] %gte_while, f32[1024, 1024] %param.1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kHloModule));
  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
                          WhileLoopAllReduceCodeMotion{}.Run(module.get()));
  ASSERT_TRUE(simplified_loop);
  TF_ASSERT_OK(
      HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/true)
          .Run(module.get())
          .status());

  // Locate the while with a plain loop so a missing instruction yields
  // nullptr instead of a dereferenced end iterator.
  HloInstruction* transformed_while = nullptr;
  for (HloInstruction* instruction :
       module->entry_computation()->instructions()) {
    if (instruction->opcode() == HloOpcode::kWhile) {
      transformed_while = instruction;
      break;
    }
  }
  ASSERT_THAT(transformed_while, NotNull());
  EXPECT_THAT(transformed_while->while_body()->instructions(),
              Each(Not(op::AllReduce())));

  // Each step is a fatal assertion because the next one dereferences the
  // operand it has just checked.
  HloInstruction* new_root = module->entry_computation()->root_instruction();
  ASSERT_THAT(new_root, op::Multiply());
  ASSERT_THAT(new_root->operand(0), op::GetTupleElement());
  ASSERT_THAT(new_root->operand(0)->operand(0), op::Tuple());
  EXPECT_THAT(new_root->operand(0)->operand(0)->operand(3), op::Add());
}
|
||||
|
||||
// Negative case: the all-reduce result feeds %accumulation and is then added
// a second time in %add.0, which is what the loop carries. The pass is
// expected to report that it changed nothing.
TEST_F(WhileLoopAllReduceCodeMotionTest, RepeatedlyAccumulatedAllReduce) {
  constexpr absl::string_view kRepeatedAccumulationHlo = R"(
HloModule accumulated_all_reduce

%reduction {
%x = f32[] parameter(0)
%y = f32[] parameter(1)
ROOT %add = f32[] add(f32[] %x, f32[] %y)
}

%while_condition {
%param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
%gte.0 = s32[] get-tuple-element(%param), index=0
%gte.1 = s32[] get-tuple-element(%param), index=1
ROOT result = pred[] compare(%gte.0, %gte.1), direction=LT
}

%while_body {
%param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
%gte.0 = s32[] get-tuple-element(%param), index=0
%gte.1 = s32[] get-tuple-element(%param), index=1
%gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2
%gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3
%all-reduce = f32[1024, 1024] all-reduce(f32[1024, 1024] %gte.2), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%reduction
%accumulation = f32[1024, 1024] add(f32[1024, 1024] %all-reduce, f32[1024, 1024] %gte.3)
%add.0 = f32[1024, 1024] add(f32[1024, 1024] %all-reduce, f32[1024, 1024] %accumulation)
%constant = s32[] constant(1)
%increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant)
ROOT %loop_result = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %add.0)
}

ENTRY accumulated_all_reduce {
%param.0 = s32[] parameter(0)
%param.1 = f32[1024, 1024] parameter(1)
%constant.0 = s32[] constant(1)
%accumulation_buffer_init = f32[] constant(0)
%accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
%while_init = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, s32[] %param.0, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer)
ROOT %while = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body
}
)";
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> parsed_module,
      ParseAndReturnVerifiedModule(kRepeatedAccumulationHlo));
  HloModule* const module_under_test = parsed_module.get();

  TF_ASSERT_OK_AND_ASSIGN(
      bool pass_changed_module,
      WhileLoopAllReduceCodeMotion{}.Run(module_under_test));
  EXPECT_FALSE(pass_changed_module);
}
|
||||
|
||||
// Checks the pattern where the all-reduce runs in bf16 between two converts
// and the accumulation happens in f32. After the pass runs, the body must
// contain no all-reduce, the entry all-reduce must keep its bf16 shape, and
// the entry add and both of its operands must be f32.
TEST_F(WhileLoopAllReduceCodeMotionTest, TypeCastAllReduceAccumulate) {
  constexpr absl::string_view kHloModule = R"(
HloModule accumulated_all_reduce

%reduction {
%x = bf16[] parameter(0)
%y = bf16[] parameter(1)
ROOT %add = bf16[] add(bf16[] %x, bf16[] %y)
}

%while_condition {
%param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
%gte.0 = s32[] get-tuple-element(%param), index=0
%gte.1 = s32[] get-tuple-element(%param), index=1
ROOT result = pred[] compare(%gte.0, %gte.1), direction=LT
}

%while_body {
%param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
%gte.0 = s32[] get-tuple-element(%param), index=0
%gte.1 = s32[] get-tuple-element(%param), index=1
%gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2
%gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3
%convert.0 = bf16[1024, 1024] convert(f32[1024, 1024] %gte.2)
%all-reduce = bf16[1024, 1024] all-reduce(bf16[1024, 1024] %convert.0), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%reduction
%convert.1 = f32[1024, 1024] convert(bf16[1024, 1024] %all-reduce)
%accumulation = f32[1024, 1024] add(f32[1024, 1024] %convert.1, f32[1024, 1024] %gte.3)
%constant = s32[] constant(1)
%increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant)
ROOT %loop_result = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation)
}

ENTRY accumulated_all_reduce {
%param.0 = s32[] parameter(0)
%param.1 = f32[1024, 1024] parameter(1)
%constant.0 = s32[] constant(1)
%accumulation_buffer_init = f32[] constant(0)
%accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
%while_init = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, s32[] %param.0, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer)
ROOT %while = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kHloModule));
  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
                          WhileLoopAllReduceCodeMotion{}.Run(module.get()));
  ASSERT_TRUE(simplified_loop);
  TF_ASSERT_OK(
      HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/true)
          .Run(module.get())
          .status());

  // Returns the first entry-computation instruction with the given opcode, or
  // nullptr when there is none. Unlike dereferencing the result of
  // std::find_if, this never dereferences an end iterator, so the NotNull()
  // assertions below actually detect a missing instruction.
  auto find_first_in_entry = [&module](HloOpcode opcode) -> HloInstruction* {
    for (HloInstruction* instruction :
         module->entry_computation()->instructions()) {
      if (instruction->opcode() == opcode) {
        return instruction;
      }
    }
    return nullptr;
  };

  HloInstruction* transformed_while = find_first_in_entry(HloOpcode::kWhile);
  ASSERT_THAT(transformed_while, NotNull());
  EXPECT_THAT(transformed_while->while_body()->instructions(),
              Each(Not(op::AllReduce())));
  HloInstruction* accumulation_buffer =
      transformed_while->mutable_operand(0)->mutable_operand(3);
  EXPECT_THAT(accumulation_buffer, op::Constant());

  HloInstruction* entry_all_reduce =
      find_first_in_entry(HloOpcode::kAllReduce);
  ASSERT_THAT(entry_all_reduce, NotNull());
  HloAllReduceInstruction* moved_all_reduce =
      DynCast<HloAllReduceInstruction>(entry_all_reduce);
  // Fatal assertion: the shape is read through this pointer right after.
  ASSERT_THAT(moved_all_reduce, NotNull());
  EXPECT_TRUE(ShapeUtil::Equal(moved_all_reduce->shape(),
                               ShapeUtil::MakeShape(BF16, {1024, 1024})));

  HloInstruction* add_delta_to_old_buffer =
      find_first_in_entry(HloOpcode::kAdd);
  ASSERT_THAT(add_delta_to_old_buffer, NotNull());
  EXPECT_TRUE(ShapeUtil::Equal(add_delta_to_old_buffer->shape(),
                               ShapeUtil::MakeShape(F32, {1024, 1024})));
  EXPECT_TRUE(ShapeUtil::Equal(add_delta_to_old_buffer->operand(0)->shape(),
                               ShapeUtil::MakeShape(F32, {1024, 1024})));
  EXPECT_TRUE(ShapeUtil::Equal(add_delta_to_old_buffer->operand(1)->shape(),
                               ShapeUtil::MakeShape(F32, {1024, 1024})));
}
|
||||
|
||||
// Checks a module in which two while instructions share the same body and
// condition. After the pass runs, the entry computation must still contain
// two whiles and must contain two all-reduces, and the body of the first
// while found must contain no all-reduce.
//
// NOTE(review): %while_init.1 is defined but never consumed; %while.1 takes
// %while_init.0, so the two loops are not chained through %gte.3. Confirm
// whether %while.1 was meant to use %while_init.1.
TEST_F(WhileLoopAllReduceCodeMotionTest, MultipleLoopCalls) {
  constexpr absl::string_view kHloModule = R"(
HloModule accumulated_all_reduce

%reduction {
%x = bf16[] parameter(0)
%y = bf16[] parameter(1)
ROOT %add = bf16[] add(bf16[] %x, bf16[] %y)
}

%while_condition {
%param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
%gte.0 = s32[] get-tuple-element(%param), index=0
%gte.1 = s32[] get-tuple-element(%param), index=1
ROOT result = pred[] compare(%gte.0, %gte.1), direction=LT
}

%while_body {
%param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
%gte.0 = s32[] get-tuple-element(%param), index=0
%gte.1 = s32[] get-tuple-element(%param), index=1
%gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2
%gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3
%convert.0 = bf16[1024, 1024] convert(f32[1024, 1024] %gte.2)
%all-reduce = bf16[1024, 1024] all-reduce(bf16[1024, 1024] %convert.0), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%reduction
%convert.1 = f32[1024, 1024] convert(bf16[1024, 1024] %all-reduce)
%accumulation = f32[1024, 1024] add(f32[1024, 1024] %convert.1, f32[1024, 1024] %gte.3)
%constant = s32[] constant(1)
%increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant)
ROOT %loop_result = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation)
}

ENTRY accumulated_all_reduce {
%param.0 = s32[] parameter(0)
%param.1 = f32[1024, 1024] parameter(1)
%constant.0 = s32[] constant(1)
%accumulation_buffer_init = f32[] constant(0)
%accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
%while_init.0 = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, s32[] %param.0, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer)
%while.0 = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) while(%while_init.0), condition=%while_condition, body=%while_body
%gte.3 = f32[1024, 1024] get-tuple-element(%while.0), index=3
%while_init.1 = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, s32[] %param.0, f32[1024, 1024] %param.1, f32[1024, 1024] %gte.3)
%while.1 = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) while(%while_init.0), condition=%while_condition, body=%while_body
ROOT %gte.4 = f32[1024, 1024] get-tuple-element((s32[], s32[], f32[1024, 1024], f32[1024, 1024])%while.1), index=3
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kHloModule));
  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
                          WhileLoopAllReduceCodeMotion{}.Run(module.get()));
  ASSERT_TRUE(simplified_loop);
  TF_ASSERT_OK(
      HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/true)
          .Run(module.get())
          .status());
  EXPECT_EQ(absl::c_count_if(module->entry_computation()->instructions(),
                             Matches(op::While())),
            2);
  EXPECT_EQ(absl::c_count_if(module->entry_computation()->instructions(),
                             Matches(op::AllReduce())),
            2);

  // Locate the while with a plain loop so a missing instruction yields
  // nullptr instead of a dereferenced end iterator.
  HloInstruction* transformed_while = nullptr;
  for (HloInstruction* instruction :
       module->entry_computation()->instructions()) {
    if (instruction->opcode() == HloOpcode::kWhile) {
      transformed_while = instruction;
      break;
    }
  }
  ASSERT_THAT(transformed_while, NotNull());
  EXPECT_THAT(transformed_while->while_body()->instructions(),
              Each(Not(op::AllReduce())));
}
|
||||
|
||||
// Checks a loop body with two independent all-reduce-then-accumulate chains,
// one in f32 and one in bf16. After the pass runs, the body must contain no
// all-reduce and the entry computation must contain two.
TEST_F(WhileLoopAllReduceCodeMotionTest, MultipleAllReduceAccumulate) {
  // The declared shape of %loop_result lists all six tuple elements so that
  // it agrees with its operands and with the while/parameter shapes.
  constexpr absl::string_view kHloModule = R"(
HloModule accumulated_all_reduce

%reduction.0 {
%x = f32[] parameter(0)
%y = f32[] parameter(1)
ROOT %add = f32[] add(f32[] %x, f32[] %y)
}

%reduction.1 {
%x = bf16[] parameter(0)
%y = bf16[] parameter(1)
ROOT %add = bf16[] add(bf16[] %x, bf16[] %y)
}

%while_condition {
%param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024], bf16[1024, 1024], bf16[1024, 1024]) parameter(0)
%gte.0 = s32[] get-tuple-element(%param), index=0
%gte.1 = s32[] get-tuple-element(%param), index=1
ROOT result = pred[] compare(%gte.0, %gte.1), direction=LT
}

%while_body {
%param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024], bf16[1024, 1024], bf16[1024, 1024]) parameter(0)
%gte.0 = s32[] get-tuple-element(%param), index=0
%gte.1 = s32[] get-tuple-element(%param), index=1
%gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2
%gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3
%gte.4 = bf16[1024, 1024] get-tuple-element(%param), index=4
%gte.5 = bf16[1024, 1024] get-tuple-element(%param), index=5
%all-reduce.0 = f32[1024, 1024] all-reduce(f32[1024, 1024] %gte.2), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%reduction.0
%accumulation.0 = f32[1024, 1024] add(f32[1024, 1024] %all-reduce.0, f32[1024, 1024] %gte.3)
%all-reduce.1 = bf16[1024, 1024] all-reduce(bf16[1024, 1024] %gte.4), channel_id=2, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%reduction.1
%accumulation.1 = bf16[1024, 1024] add(bf16[1024, 1024] %all-reduce.1, bf16[1024, 1024] %gte.5)
%constant = s32[] constant(1)
%increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant)
ROOT %loop_result = (s32[], s32[], f32[1024, 1024], f32[1024, 1024], bf16[1024, 1024], bf16[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation.0, %gte.4, %accumulation.1)
}

ENTRY accumulated_all_reduce {
%param.0 = s32[] parameter(0)
%param.1 = f32[1024, 1024] parameter(1)
%param.2 = bf16[1024, 1024] parameter(2)
%constant.0 = s32[] constant(1)
%accumulation_buffer.0 = f32[1024, 1024] constant({...})
%accumulation_buffer.1 = bf16[1024, 1024] constant({...})
%while_init = (s32[], s32[], f32[1024, 1024], f32[1024, 1024], bf16[1024, 1024], bf16[1024, 1024]) tuple(s32[] %constant.0, s32[] %param.0, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer.0, bf16[1024, 1024] %param.2, bf16[1024, 1024] %accumulation_buffer.1)
ROOT %while = (s32[], s32[], f32[1024, 1024], f32[1024, 1024], bf16[1024, 1024], bf16[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kHloModule));
  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
                          WhileLoopAllReduceCodeMotion{}.Run(module.get()));
  ASSERT_TRUE(simplified_loop);
  TF_ASSERT_OK(
      HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/true)
          .Run(module.get())
          .status());

  // Locate the while with a plain loop so a missing instruction yields
  // nullptr instead of a dereferenced end iterator.
  HloInstruction* transformed_while = nullptr;
  for (HloInstruction* instruction :
       module->entry_computation()->instructions()) {
    if (instruction->opcode() == HloOpcode::kWhile) {
      transformed_while = instruction;
      break;
    }
  }
  ASSERT_THAT(transformed_while, NotNull());
  // Both all-reduces should have been sunk.
  EXPECT_THAT(transformed_while->while_body()->instructions(),
              Each(Not(op::AllReduce())));
  HloInstruction* accumulation_buffer =
      transformed_while->mutable_operand(0)->mutable_operand(3);
  EXPECT_THAT(accumulation_buffer, op::Constant());
  EXPECT_EQ(absl::c_count_if(module->entry_computation()->instructions(),
                             Matches(op::AllReduce())),
            2);
}
|
||||
|
||||
// Checks a loop body with two all-reduces where only one can be moved: the
// bf16 accumulation is added to %gte.4 again in %add.0 before being carried.
// After the pass runs, exactly one all-reduce must remain in the body and
// exactly one must exist in the entry computation.
TEST_F(WhileLoopAllReduceCodeMotionTest, MixMovableAllReduceWithNotMovable) {
  // The declared shape of %loop_result lists all six tuple elements so that
  // it agrees with its operands and with the while/parameter shapes.
  constexpr absl::string_view kHloModule = R"(
HloModule accumulated_all_reduce

%reduction.0 {
%x = f32[] parameter(0)
%y = f32[] parameter(1)
ROOT %add = f32[] add(f32[] %x, f32[] %y)
}

%reduction.1 {
%x = bf16[] parameter(0)
%y = bf16[] parameter(1)
ROOT %add = bf16[] add(bf16[] %x, bf16[] %y)
}

%while_condition {
%param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024], bf16[1024, 1024], bf16[1024, 1024]) parameter(0)
%gte.0 = s32[] get-tuple-element(%param), index=0
%gte.1 = s32[] get-tuple-element(%param), index=1
ROOT result = pred[] compare(%gte.0, %gte.1), direction=LT
}

%while_body {
%param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024], bf16[1024, 1024], bf16[1024, 1024]) parameter(0)
%gte.0 = s32[] get-tuple-element(%param), index=0
%gte.1 = s32[] get-tuple-element(%param), index=1
%gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2
%gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3
%gte.4 = bf16[1024, 1024] get-tuple-element(%param), index=4
%gte.5 = bf16[1024, 1024] get-tuple-element(%param), index=5
%all-reduce.0 = f32[1024, 1024] all-reduce(f32[1024, 1024] %gte.2), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%reduction.0
%accumulation.0 = f32[1024, 1024] add(f32[1024, 1024] %all-reduce.0, f32[1024, 1024] %gte.3)
%all-reduce.1 = bf16[1024, 1024] all-reduce(bf16[1024, 1024] %gte.4), channel_id=2, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%reduction.1
%accumulation.1 = bf16[1024, 1024] add(bf16[1024, 1024] %all-reduce.1, bf16[1024, 1024] %gte.5)
%add.0 = bf16[1024, 1024] add(bf16[1024, 1024] %accumulation.1, bf16[1024, 1024] %gte.4)
%constant = s32[] constant(1)
%increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant)
ROOT %loop_result = (s32[], s32[], f32[1024, 1024], f32[1024, 1024], bf16[1024, 1024], bf16[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation.0, %gte.4, %add.0)
}

ENTRY accumulated_all_reduce {
%param.0 = s32[] parameter(0)
%param.1 = f32[1024, 1024] parameter(1)
%param.2 = bf16[1024, 1024] parameter(2)
%constant.0 = s32[] constant(1)
%accumulation_buffer.0 = f32[1024, 1024] constant({...})
%accumulation_buffer.1 = bf16[1024, 1024] constant({...})
%while_init = (s32[], s32[], f32[1024, 1024], f32[1024, 1024], bf16[1024, 1024], bf16[1024, 1024]) tuple(s32[] %constant.0, s32[] %param.0, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer.0, bf16[1024, 1024] %param.2, bf16[1024, 1024] %accumulation_buffer.1)
ROOT %while = (s32[], s32[], f32[1024, 1024], f32[1024, 1024], bf16[1024, 1024], bf16[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kHloModule));
  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
                          WhileLoopAllReduceCodeMotion{}.Run(module.get()));
  ASSERT_TRUE(simplified_loop);
  TF_ASSERT_OK(
      HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/true)
          .Run(module.get())
          .status());

  // Locate the while with a plain loop so a missing instruction yields
  // nullptr instead of a dereferenced end iterator.
  HloInstruction* transformed_while = nullptr;
  for (HloInstruction* instruction :
       module->entry_computation()->instructions()) {
    if (instruction->opcode() == HloOpcode::kWhile) {
      transformed_while = instruction;
      break;
    }
  }
  ASSERT_THAT(transformed_while, NotNull());
  // One all-reduce is movable and the other is not movable.
  EXPECT_EQ(absl::c_count_if(transformed_while->while_body()->instructions(),
                             Matches(op::AllReduce())),
            1);
  HloInstruction* accumulation_buffer =
      transformed_while->mutable_operand(0)->mutable_operand(3);
  EXPECT_THAT(accumulation_buffer, op::Constant());
  EXPECT_EQ(absl::c_count_if(module->entry_computation()->instructions(),
                             Matches(op::AllReduce())),
            1);
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
Loading…
x
Reference in New Issue
Block a user