Rolling forward after fixing C++ version problems.

Original public description:

Add a WhileLoopAllReduceCodeMotion pass, which finds the all-reduce-then-accumulate pattern inside while loops and moves the all-reduce to the outside of the while loop.

PiperOrigin-RevId: 349594327
Change-Id: I6f2b6c67d9b89d257585085f97a7c022298d5d1e
This commit is contained in:
Jinliang Wei 2020-12-30 13:19:02 -08:00 committed by TensorFlower Gardener
parent 64c87d65ff
commit f8975c3e43
4 changed files with 1359 additions and 0 deletions

View File

@ -4665,6 +4665,48 @@ tf_cc_test(
],
)
# HLO pass that sinks all-reduce-then-accumulate patterns out of while loops.
# Every header the .cc file includes directly, or whose symbols it names
# directly (map_util.h, absl::flat_hash_map, absl::InlinedVector), is listed as
# an explicit dependency instead of being picked up transitively.
cc_library(
    name = "while_loop_all_reduce_code_motion",
    srcs = ["while_loop_all_reduce_code_motion.cc"],
    hdrs = ["while_loop_all_reduce_code_motion.h"],
    deps = [
        ":call_graph",
        ":hlo",
        ":hlo_casting_utils",
        ":hlo_pass",
        ":hlo_query",
        ":pattern_matcher",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:map_util",
        "//tensorflow/compiler/xla:status",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/core:lib_proto_parsing",
        "//tensorflow/core/platform:status",
        "@com_google_absl//absl/algorithm:container",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/container:inlined_vector",
        "@com_google_absl//absl/types:span",
    ],
)
# Unit tests for the while_loop_all_reduce_code_motion pass. The test source
# parses HLO text, runs the pass and inspects the rewritten module.
tf_cc_test(
name = "while_loop_all_reduce_code_motion_test",
srcs = ["while_loop_all_reduce_code_motion_test.cc"],
deps = [
":hlo",
":hlo_casting_utils",
":hlo_matchers",
":hlo_verifier",
":while_loop_all_reduce_code_motion",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/strings",
],
)
cc_library(
name = "while_loop_invariant_code_motion",
srcs = ["while_loop_invariant_code_motion.cc"],

View File

@ -0,0 +1,600 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/while_loop_all_reduce_code_motion.h"
#include <iterator>
#include <stack>
#include <tuple>
#include "absl/algorithm/container.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_query.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/status.h"
namespace xla {
namespace {
// Describes one accumulation of an all-reduce result inside a while body.
struct AccumulationContext {
// The add instruction that combines the (possibly forwarded) all-reduce result
// with the accumulation buffer.
HloInstruction* accumulation_instruction;
// The other operand of `accumulation_instruction`, i.e. the running buffer
// that the all-reduce result is added to.
HloInstruction* accumulation_buffer;
// Index of the accumulation buffer in the while body's parameter tuple. The
// code that consumes this struct CHECKs that it holds exactly one index, since
// nested tuples are not supported.
std::vector<int> param_tuple_indices;
};
// Result of analysing a single all-reduce inside a while body: whether it can
// be sunk out of the loop and, if it can, every accumulation in the body that
// consumes its result.
struct MovableAllReduceContext {
// True when the all-reduce may be moved out of the while body.
bool is_movable;
// If movable, `accumulation_contexts` contains one accumulation
// context for each accumulation in the while body that uses the all-reduce's
// result. Otherwise, this field is undefined and must not be relied upon.
std::vector<AccumulationContext> accumulation_contexts;
};
// Checks if an all-reduce instruction is eligible for sinking and finds all of
// the all-reduce's accumulation uses inside the while body if eligible.
//
// `all_reduce` is the candidate instruction and `while_body` is the computation
// that contains it. The result's `is_movable` is false when the element type is
// unsupported, when the reduction is not a two-parameter add, when the result
// reaches anything other than forwarding ops followed by an accumulating add,
// or when an accumulation buffer has other uses in the body.
MovableAllReduceContext IsAllReduceMovable(HloInstruction* all_reduce,
HloComputation* while_body) {
// The reduction is treated as a summation when its computation has exactly two
// parameters and its root matches an add of parameter 0 and parameter 1
// (AddAnyOrder presumably accepts either operand order).
auto all_reduce_is_summation = [](HloInstruction* all_reduce) -> bool {
HloInstruction* to_apply_root = all_reduce->to_apply()->root_instruction();
if (all_reduce->to_apply()->num_parameters() != 2) {
return false;
}
return Match(to_apply_root,
match::AddAnyOrder(match::Parameter(0), match::Parameter(1)));
};
// We only support numerical types.
const absl::InlinedVector<PrimitiveType, 12> kSupportedTypes{
BF16, F16, F32, F64, S8, S16, S32, S64, U8, U16, U32, U64};
if (!absl::c_linear_search(kSupportedTypes,
all_reduce->shape().element_type()) ||
!all_reduce_is_summation(all_reduce)) {
return MovableAllReduceContext{/*is_movable=*/false,
/*accumulation_contexts=*/{}};
}
// Outcome of tracing a buffer to or from a tuple of the while body.
struct BufferTupleIndex {
// Set when the trace ran into an instruction it cannot look through.
bool unsupported_operation{false};
// Tuple index collected along the trace; at most one entry because nested
// tuples are rejected.
std::vector<int> tuple_index;
// Set when the traced value reaches the root tuple of the while body.
bool returned_from_computation{false};
};
// If the instruction is a buffer forwarded from a tuple element of the
// computation's parameter, returns the indices of the buffer in the parameter
// tuple. The returned_from_computation field in the result is unused.
auto get_origin_tuple_index =
[](HloInstruction* instruction) -> BufferTupleIndex {
// The returned_from_computation is never touched in this function.
BufferTupleIndex result;
// Walk up the operand(0) chain until a parameter is reached or an unsupported
// instruction is seen.
while (!result.unsupported_operation) {
switch (instruction->opcode()) {
default:
result.unsupported_operation = true;
break;
// Forwarding ops: keep following their first operand.
case HloOpcode::kBitcast:
case HloOpcode::kConvert:
case HloOpcode::kReshape:
case HloOpcode::kTranspose:
case HloOpcode::kDynamicReshape:
instruction = instruction->mutable_operand(0);
break;
case HloOpcode::kGetTupleElement: {
if (!result.tuple_index.empty()) {
// Note that we don't support nested tuples as of now.
result.unsupported_operation = true;
} else {
result.tuple_index.push_back(
Cast<HloGetTupleElementInstruction>(instruction)
->tuple_index());
instruction = instruction->mutable_operand(0);
}
break;
}
case HloOpcode::kParameter: {
// The traced parameter is required to be parameter number 0.
int parameter_number =
Cast<HloParameterInstruction>(instruction)->parameter_number();
CHECK_EQ(parameter_number, 0);
break;
}
}
// Reaching a parameter ends the walk; the `break` statements above only leave
// the switch.
if (instruction->opcode() == HloOpcode::kParameter) {
break;
}
}
return result;
};
// If the instruction's result is returned from its parent computation with
// only forwarding operations, returns the index of the result buffer in the
// output parameter tuple.
auto get_output_tuple_index =
[](HloInstruction* instruction,
HloComputation* while_body) -> BufferTupleIndex {
BufferTupleIndex result;
// Depth-first walk over the users of `instruction`.
std::stack<HloInstruction*> to_visit;
to_visit.push(instruction);
while (!to_visit.empty() && !result.unsupported_operation) {
HloInstruction* instruction = to_visit.top();
to_visit.pop();
for (HloInstruction* user : instruction->users()) {
switch (user->opcode()) {
case HloOpcode::kBitcast:
case HloOpcode::kConvert:
case HloOpcode::kReshape:
case HloOpcode::kGetTupleElement:
case HloOpcode::kTranspose:
case HloOpcode::kSlice: {
to_visit.push(user);
break;
}
case HloOpcode::kDynamicSlice: {
// Only followed when the traced value is operand 0 of the dynamic-slice;
// any other operand position is rejected.
if (user->operand_index(instruction) == 0) {
to_visit.push(user);
} else {
result.unsupported_operation = true;
}
break;
}
case HloOpcode::kTuple: {
if (!result.tuple_index.empty()) {
// Note that we don't support nested tuples as of now.
result.unsupported_operation = true;
} else {
result.tuple_index.push_back(user->operand_index(instruction));
// A tuple that is the body's root means the value is returned from the
// computation; any other tuple is traced further.
if (while_body->root_instruction() == user) {
if (result.returned_from_computation) {
result.unsupported_operation = true;
}
result.returned_from_computation = true;
} else {
to_visit.push(user);
}
}
break;
}
default:
result.unsupported_operation = true;
}
// Stop scanning the remaining users once the trace has failed.
if (result.unsupported_operation) {
break;
}
}
}
return result;
};
// Checks whether any buffer in the list of accumulation contexts is used in
// the parent computation except for forwarding uses.
auto is_buffer_used =
[](absl::Span<const AccumulationContext> accumulation_contexts,
HloComputation* while_body_computation) -> bool {
// Collect every parameter instruction of the while body.
std::vector<HloInstruction*> parameter_instructions;
absl::c_copy_if(while_body_computation->instructions(),
std::back_inserter(parameter_instructions),
[](HloInstruction* instruction) -> bool {
return instruction->opcode() == HloOpcode::kParameter;
});
for (const auto& accumulation : accumulation_contexts) {
HloInstruction* accumulation_instruction =
accumulation.accumulation_instruction;
int tuple_index = accumulation.param_tuple_indices[0];
std::stack<HloInstruction*> to_visit;
// TODO(b/176437845): simplify the logic below by using
// TuplePointsToAnalysis.
for (HloInstruction* parameter_instruction : parameter_instructions) {
// Iterate over all users of the while body parameter and find all
// instructions that use the accumulation buffer, as specified by
// tuple_index.
// This logic could be simplified by using TuplePointsToAnalysis, which
// we leave to a future CL (see TODO above).
for (HloInstruction* user : parameter_instruction->users()) {
if (auto* gte = DynCast<HloGetTupleElementInstruction>(user)) {
if (gte->tuple_index() == tuple_index) {
to_visit.push(user);
}
} else {
// A parameter user that is not a get-tuple-element is conservatively
// treated as a use of the buffer.
return true;
}
}
}
// Follow forwarding ops from each matching get-tuple-element; the only
// non-forwarding user allowed is the accumulation add itself.
while (!to_visit.empty()) {
HloInstruction* instruction = to_visit.top();
to_visit.pop();
for (HloInstruction* user : instruction->users()) {
switch (user->opcode()) {
case HloOpcode::kBitcast:
case HloOpcode::kConvert:
case HloOpcode::kReshape:
case HloOpcode::kTranspose:
to_visit.push(user);
break;
case HloOpcode::kDynamicReshape: {
// Forwarding only when the buffer is operand 0 of the dynamic-reshape.
if (instruction == user->operand(0)) {
to_visit.push(user);
} else {
return true;
}
break;
}
case HloOpcode::kAdd: {
if (user != accumulation_instruction) {
return true;
}
break;
}
default:
return true;
}
}
}
}
return false;
};
// Finds all accumulation contexts of the given all-reduce instruction if it
// is movable.
auto get_accumulation_contexts =
[&get_origin_tuple_index, &get_output_tuple_index, &is_buffer_used](
HloInstruction* all_reduce,
HloComputation* while_body) -> MovableAllReduceContext {
std::vector<AccumulationContext> accumulation_contexts;
// DFS starting from the all-reduce instruction and stops at the first
// non-trivial uses of the all-reduce result or finds all accumulations
// of the all-reduce result.
std::stack<HloInstruction*> to_visit;
// By default movable unless we find that it's not.
bool is_all_reduce_movable = true;
to_visit.push(all_reduce);
while (!to_visit.empty() && is_all_reduce_movable) {
HloInstruction* instruction = to_visit.top();
to_visit.pop();
for (HloInstruction* user : instruction->users()) {
switch (user->opcode()) {
case HloOpcode::kBitcast:
case HloOpcode::kConvert:
case HloOpcode::kReshape:
case HloOpcode::kGetTupleElement:
case HloOpcode::kTranspose:
case HloOpcode::kSlice: {
to_visit.push(user);
break;
}
case HloOpcode::kDynamicSlice: {
if (user->operand_index(instruction) == 0) {
to_visit.push(user);
} else {
is_all_reduce_movable = false;
break;
}
break;
}
case HloOpcode::kAdd: {
// The add has two operands, so `1 - index` selects the operand that is not
// the all-reduce result; that operand is the candidate accumulation buffer.
int64 buffer_index = 1 - user->operand_index(instruction);
HloInstruction* accumulation_buffer =
user->mutable_operand(buffer_index);
auto origin_buffer_tuple_index =
get_origin_tuple_index(accumulation_buffer);
if (origin_buffer_tuple_index.unsupported_operation) {
is_all_reduce_movable = false;
break;
}
auto output_buffer_tuple_index =
get_output_tuple_index(user, while_body);
// The add is an accumulation only when the buffer comes from tuple element i
// of the body parameter and the sum is written back to element i of the
// body's root tuple.
if (!output_buffer_tuple_index.unsupported_operation &&
output_buffer_tuple_index.returned_from_computation &&
!origin_buffer_tuple_index.tuple_index.empty() &&
ContainersEqual(origin_buffer_tuple_index.tuple_index,
output_buffer_tuple_index.tuple_index)) {
accumulation_contexts.push_back(AccumulationContext{
user, accumulation_buffer,
std::move(output_buffer_tuple_index.tuple_index)});
} else {
is_all_reduce_movable = false;
}
break;
}
default:
is_all_reduce_movable = false;
}
}
}
// The accumulation buffers must not be read anywhere else in the body.
if (is_buffer_used(accumulation_contexts, while_body)) {
is_all_reduce_movable = false;
}
return MovableAllReduceContext{is_all_reduce_movable,
accumulation_contexts};
};
return get_accumulation_contexts(all_reduce, while_body);
}
// Bundles the replacement init tuple of a rewritten while loop with the
// original values of the tuple elements that were swapped out.
struct WhileInitContext {
// The newly created init tuple for the replacement while instruction.
HloInstruction* while_init{nullptr};
// Maps each tuple index whose element was replaced to the element the
// original while init held at that index.
absl::flat_hash_map<int, HloInstruction*> tuple_index_to_old_buffer;
};
// Creates a new while init instruction, which replaces each accumulation buffer
// in the given accumulation contexts with a zero-initialized buffer. In other
// words, we are accumulating all the deltas in the while loop with a zero
// initial value.
//
// `old_while_instruction` is the while being rewritten and
// `all_reduce_to_accumulations` maps each movable all-reduce to its
// accumulations. Returns the new init tuple together with a map from every
// replaced tuple index to the value the original init held at that index.
//
// The original init does not have to be a kTuple instruction: when it is any
// other tuple-shaped instruction (e.g. a parameter or the result of another
// while), its elements are read through get-tuple-element instructions instead
// of through its operands.
WhileInitContext CreateNewWhileInit(
    HloInstruction* old_while_instruction,
    const HloInstructionMap<std::vector<AccumulationContext>>&
        all_reduce_to_accumulations) {
  HloInstruction* old_while_init = old_while_instruction->mutable_operand(0);
  HloComputation* while_parent = old_while_instruction->parent();
  const Shape& old_init_shape = old_while_init->shape();
  CHECK(old_init_shape.IsTuple());
  // Size the new tuple from the init's shape rather than its operand count, so
  // that non-kTuple init instructions are handled as well.
  const int num_elements = old_init_shape.tuple_shapes_size();
  std::vector<HloInstruction*> new_while_init_elements(num_elements, nullptr);
  for (const auto& all_reduce_and_accumulations_pair :
       all_reduce_to_accumulations) {
    const std::vector<AccumulationContext>& accumulations =
        all_reduce_and_accumulations_pair.second;
    for (const AccumulationContext& accumulation_context : accumulations) {
      CHECK_EQ(accumulation_context.param_tuple_indices.size(), 1);
      int tuple_index = accumulation_context.param_tuple_indices[0];
      CHECK_GE(tuple_index, 0);
      CHECK_LT(tuple_index, num_elements);
      // The zero buffer takes the element type and dimensions of the tuple
      // element it replaces.
      const Shape& old_buffer_shape = old_init_shape.tuple_shapes(tuple_index);
      HloInstruction* new_buffer = while_parent->AddInstruction(
          HloInstruction::CreateConstant(LiteralUtil::CreateFromDimensions(
              old_buffer_shape.element_type(), old_buffer_shape.dimensions())));
      new_while_init_elements[tuple_index] = new_buffer;
    }
  }
  const bool init_is_tuple_instruction =
      old_while_init->opcode() == HloOpcode::kTuple;
  absl::flat_hash_map<int, HloInstruction*> tuple_index_to_old_buffer;
  for (int i = 0; i < num_elements; i++) {
    // Element `i` of the original init: the tuple operand when the init is a
    // kTuple instruction, otherwise a get-tuple-element on the init.
    HloInstruction* old_element =
        init_is_tuple_instruction
            ? old_while_init->mutable_operand(i)
            : while_parent->AddInstruction(
                  HloInstruction::CreateGetTupleElement(
                      old_init_shape.tuple_shapes(i), old_while_init, i));
    if (!new_while_init_elements[i]) {
      // Not an accumulation buffer: forward the original element unchanged.
      new_while_init_elements[i] = old_element;
    } else {
      // Replaced by a zero buffer: remember the original value so that it can
      // be added back after the loop.
      tuple_index_to_old_buffer[i] = old_element;
    }
  }
  HloInstruction* new_while_init = while_parent->AddInstruction(
      HloInstruction::CreateTuple(new_while_init_elements));
  return WhileInitContext{new_while_init, tuple_index_to_old_buffer};
}
// Emits, in the computation that owns `new_while_instruction`, one all-reduce
// per accumulation recorded in `all_reduce_to_accumulations`. Each new
// all-reduce consumes the matching element of the new while's result, and its
// output is added to the buffer that `tuple_index_to_old_buffer` holds for the
// same tuple index. Returns a map from accumulation-buffer tuple index to that
// final add instruction.
absl::flat_hash_map<int, HloInstruction*> CreateSinkedAllReduces(
    HloInstruction* new_while_instruction,
    const HloInstructionMap<std::vector<AccumulationContext>>&
        all_reduce_to_accumulations,
    const absl::flat_hash_map<int, HloInstruction*>&
        tuple_index_to_old_buffer) {
  HloComputation* parent_computation = new_while_instruction->parent();
  absl::flat_hash_map<int, HloInstruction*> index_to_sum;
  for (const auto& entry : all_reduce_to_accumulations) {
    HloInstruction* body_all_reduce = entry.first;
    for (const AccumulationContext& context : entry.second) {
      CHECK_EQ(context.param_tuple_indices.size(), 1);
      const int buffer_index = context.param_tuple_indices[0];
      const Shape& buffer_shape =
          new_while_instruction->shape().tuple_shapes(buffer_index);

      // Read the accumulated deltas out of the new while's result.
      HloInstruction* accumulated_delta = parent_computation->AddInstruction(
          HloInstruction::CreateGetTupleElement(
              buffer_shape, new_while_instruction, buffer_index));
      HloAllReduceInstruction* source_all_reduce =
          Cast<HloAllReduceInstruction>(body_all_reduce);

      // Reduce in the element type of the original all-reduce, converting the
      // buffer first when the two element types differ.
      HloInstruction* reduce_input = accumulated_delta;
      const bool input_needs_convert = !ShapeUtil::SameElementType(
          source_all_reduce->shape(), buffer_shape);
      if (input_needs_convert) {
        Shape converted_shape = ShapeUtil::MakeShape(
            source_all_reduce->shape().element_type(),
            buffer_shape.dimensions());
        reduce_input = parent_computation->AddInstruction(
            HloInstruction::CreateConvert(converted_shape, accumulated_delta));
      }

      // The sunk all-reduce copies the attributes of the original one but is
      // given a fresh channel id.
      HloInstruction* sunk_all_reduce = parent_computation->AddInstruction(
          HloInstruction::CreateAllReduce(
              reduce_input->shape(), {reduce_input},
              source_all_reduce->called_computations()[0],
              source_all_reduce->replica_groups(),
              source_all_reduce->constrain_layout(),
              hlo_query::NextChannelId(*(parent_computation->parent())),
              source_all_reduce->use_global_device_ids()));

      // Convert back to the buffer's element type if the reduction ran in a
      // different one.
      HloInstruction* reduced_delta = sunk_all_reduce;
      const bool output_needs_convert = !ShapeUtil::SameElementType(
          reduced_delta->shape(), buffer_shape);
      if (output_needs_convert) {
        reduced_delta = parent_computation->AddInstruction(
            HloInstruction::CreateConvert(buffer_shape, reduced_delta));
      }

      // Add the reduced deltas onto the buffer's original initial value.
      auto old_buffer_entry = tuple_index_to_old_buffer.find(buffer_index);
      CHECK(old_buffer_entry != tuple_index_to_old_buffer.end());
      HloInstruction* initial_buffer = old_buffer_entry->second;
      CHECK(ShapeUtil::Equal(initial_buffer->shape(), reduced_delta->shape()));
      index_to_sum[buffer_index] = parent_computation->AddInstruction(
          HloInstruction::CreateBinary(reduced_delta->shape(), HloOpcode::kAdd,
                                       initial_buffer, reduced_delta));
    }
  }
  return index_to_sum;
}
// Builds a tuple that stands in for the original while instruction's output.
// For every tuple index present in `tuple_index_to_new_buffer` the mapped
// instruction is used; every other element is read from
// `new_while_instruction` through a get-tuple-element.
HloInstruction* CreateNewWhileResult(
    HloInstruction* new_while_instruction,
    const absl::flat_hash_map<int, HloInstruction*>&
        tuple_index_to_new_buffer) {
  HloComputation* parent_computation = new_while_instruction->parent();
  const Shape& while_shape = new_while_instruction->shape();
  CHECK(while_shape.IsTuple());
  const int num_elements = while_shape.tuple_shapes_size();
  std::vector<HloInstruction*> result_elements;
  result_elements.reserve(num_elements);
  for (int index = 0; index < num_elements; ++index) {
    auto replacement = tuple_index_to_new_buffer.find(index);
    if (replacement != tuple_index_to_new_buffer.end()) {
      // Accumulation buffer: use the value computed after the loop.
      result_elements.push_back(replacement->second);
      continue;
    }
    // Untouched element: forward it from the new while's result.
    result_elements.push_back(parent_computation->AddInstruction(
        HloInstruction::CreateGetTupleElement(while_shape.tuple_shapes(index),
                                              new_while_instruction, index)));
  }
  return parent_computation->AddInstruction(
      HloInstruction::CreateTuple(result_elements));
}
// Rewrites `while_instruction` so that the all-reduces listed in
// `all_reduce_to_accumulations` run once after the loop instead of on every
// iteration. A replacement while is created whose accumulation buffers start at
// zero, the all-reduces are applied to its outputs and added to the original
// buffers, and a tuple equivalent to the old result replaces every use of the
// old while. Returns the status of that final replacement.
Status AddSinkedAllReducesAndReplaceWhile(
    HloInstruction* while_instruction,
    const HloInstructionMap<std::vector<AccumulationContext>>&
        all_reduce_to_accumulations) {
  // Every new instruction is created before anything old is replaced or
  // removed, so no deleted instruction is touched while building.
  HloComputation* parent_computation = while_instruction->parent();

  // 1) Init tuple with zero-filled accumulation buffers.
  WhileInitContext init_context =
      CreateNewWhileInit(while_instruction, all_reduce_to_accumulations);
  HloInstruction* zeroed_init = init_context.while_init;

  // 2) Replacement while, reusing the original condition and body.
  HloInstruction* replacement_while =
      parent_computation->AddInstruction(HloInstruction::CreateWhile(
          zeroed_init->shape(), while_instruction->while_condition(),
          while_instruction->while_body(), zeroed_init));

  // 3) All-reduces placed after the replacement while.
  absl::flat_hash_map<int, HloInstruction*> index_to_final_buffer =
      CreateSinkedAllReduces(replacement_while, all_reduce_to_accumulations,
                             init_context.tuple_index_to_old_buffer);

  // 4) Result tuple that takes over all uses of the original while.
  HloInstruction* replacement_result =
      CreateNewWhileResult(replacement_while, index_to_final_buffer);
  TF_RETURN_IF_ERROR(parent_computation->ReplaceInstruction(
      while_instruction, replacement_result));
  return Status::OK();
}
} // namespace
// Sinks every movable all-reduce out of the while bodies of `module`.
//
// Returns true if the module was modified. Returns false without touching the
// module when it is partitioned without SPMD partitioning. If a while body
// contains an all-reduce with constrain_layout set, the pass stops and reports
// whatever had already been changed up to that point. Errors from rewriting
// instructions are propagated.
StatusOr<bool> WhileLoopAllReduceCodeMotion::Run(HloModule* module) {
  bool is_changed = false;
  bool run_next_pass = true;
  // In case of MPMD, all-reduces might be cross-module and should preserve
  // their channel ID. Do not move all-reduces in this case since the channel
  // ID might be changed.
  if (module->config().num_partitions() > 1 &&
      !module->config().use_spmd_partitioning()) {
    return false;
  }
  // The while instruction's parent could be a while body for another while
  // loop. We recursively sink the all-reduce through nested while loops if
  // applicable by repeating this process.
  while (run_next_pass) {
    run_next_pass = false;
    std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
    // A computation could be the while body of multiple while instructions,
    // so we start from the computation and find all of its callers that is a
    // kWhile if there is any.
    for (HloComputation* computation : module->computations()) {
      std::vector<HloInstruction*> computation_callers =
          call_graph->GetComputationCallers(computation);
      std::vector<HloInstruction*> while_caller_instructions;
      for (HloInstruction* caller_instruction : computation_callers) {
        // For simplicity, we only support while instructions whose shape is
        // tuple.
        if (caller_instruction->opcode() == HloOpcode::kWhile &&
            caller_instruction->shape().IsTuple() &&
            caller_instruction->while_body() == computation) {
          while_caller_instructions.push_back(caller_instruction);
        }
      }
      // Skip to next computation if this computation is not the while body of
      // any while instruction.
      if (while_caller_instructions.empty()) {
        continue;
      }
      std::vector<HloInstruction*> while_body_all_reduces;
      for (HloInstruction* while_body_instruction :
           computation->MakeInstructionPostOrder()) {
        if (auto* all_reduce_instruction =
                DynCast<HloAllReduceInstruction>(while_body_instruction)) {
          if (all_reduce_instruction->constrain_layout()) {
            // Stop here, but report rewrites made for earlier computations or
            // in earlier iterations: the return value must say whether the
            // module was modified.
            return is_changed;
          }
          while_body_all_reduces.push_back(all_reduce_instruction);
        }
      }
      HloInstructionMap<std::vector<AccumulationContext>>
          all_reduce_to_accumulations;
      for (HloInstruction* all_reduce : while_body_all_reduces) {
        auto movable_all_reduce_context =
            IsAllReduceMovable(all_reduce, computation);
        if (movable_all_reduce_context.is_movable) {
          all_reduce_to_accumulations[all_reduce] =
              std::move(movable_all_reduce_context.accumulation_contexts);
        }
      }
      if (all_reduce_to_accumulations.empty()) {
        continue;
      }
      // For each while instruction calling this computation, create the
      // corresponding all-reduces after the while loop.
      for (HloInstruction* while_instruction : while_caller_instructions) {
        TF_RETURN_IF_ERROR(AddSinkedAllReducesAndReplaceWhile(
            while_instruction, all_reduce_to_accumulations));
        is_changed = true;
        run_next_pass = true;
      }
      // At last, remove the old all-reduce instructions in the while body by
      // forwarding each one's operand to its users.
      for (const auto& all_reduce_accumulations_pair :
           all_reduce_to_accumulations) {
        HloInstruction* all_reduce = all_reduce_accumulations_pair.first;
        TF_RETURN_IF_ERROR(computation->ReplaceInstruction(
            all_reduce, all_reduce->mutable_operand(0)));
      }
    }
  }
  return is_changed;
}
} // namespace xla

View File

@ -0,0 +1,59 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ALL_REDUCE_CODE_MOTION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ALL_REDUCE_CODE_MOTION_H_
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/statusor.h"
namespace xla {
// HLO pass that moves all-reduces out of while loops when their results are
// only accumulated. An all-reduce qualifies when its result is added to one or
// more accumulation buffers and nothing else, and those buffers are not
// otherwise used inside the loop body.
//
// Before the pass:
//   a = ...
//   while:
//     b = ...
//     c = all-reduce(b)
//     a += c
//
// After the pass:
//   a = ...
//   d = 0
//   while:
//     b = ...
//     d += b
//   e = all-reduce(d)
//   a += e
class WhileLoopAllReduceCodeMotion : public HloModulePass {
 public:
  explicit WhileLoopAllReduceCodeMotion() {}
  ~WhileLoopAllReduceCodeMotion() override = default;

  // Name under which this pass identifies itself.
  absl::string_view name() const override {
    return "while-loop-all-reduce-code-motion";
  }

  // Runs the pass on `module`; the result tells whether the module changed.
  StatusOr<bool> Run(HloModule* module) override;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ALL_REDUCE_CODE_MOTION_H_

View File

@ -0,0 +1,658 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/while_loop_all_reduce_code_motion.h"
#include <algorithm>
#include <iterator>
#include "absl/algorithm/container.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
namespace {
namespace op = ::xla::testing::opcode_matchers;
using ::testing::Ne;
using ::testing::NotNull;
using ::testing::Property;
using ::testing::SizeIs;
// Test fixture for WhileLoopAllReduceCodeMotion. It adds nothing of its own;
// the tests presumably get ParseAndReturnVerifiedModule from HloTestBase.
class WhileLoopAllReduceCodeMotionTest : public HloTestBase {};
// A single all-reduce whose result is only added to an accumulation buffer must
// be moved out of the while body. The hoisted all-reduce must read tuple
// element 3 of the new while and keep the original replica groups, flags and
// reduction computation.
TEST_F(WhileLoopAllReduceCodeMotionTest, AllReduceAccumulate) {
  constexpr absl::string_view kHloModule = R"(
HloModule accumulated_all_reduce
%reduction {
%x = f32[] parameter(0)
%y = f32[] parameter(1)
ROOT %add = f32[] add(f32[] %x, f32[] %y)
}
%while_condition {
%param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
%gte.0 = s32[] get-tuple-element(%param), index=0
%gte.1 = s32[] get-tuple-element(%param), index=1
ROOT result = pred[] compare(%gte.0, %gte.1), direction=LT
}
%while_body {
%param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
%gte.0 = s32[] get-tuple-element(%param), index=0
%gte.1 = s32[] get-tuple-element(%param), index=1
%gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2
%gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3
%all-reduce = f32[1024, 1024] all-reduce(f32[1024, 1024] %gte.2), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%reduction
%accumulation = f32[1024, 1024] add(f32[1024, 1024] %all-reduce, f32[1024, 1024] %gte.3)
%constant = s32[] constant(1)
%increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant)
ROOT %loop_result = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation)
}
ENTRY accumulated_all_reduce {
%param.0 = s32[] parameter(0)
%param.1 = f32[1024, 1024] parameter(1)
%constant.0 = s32[] constant(1)
%accumulation_buffer_init = f32[] constant(0)
%accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
%while_init = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, s32[] %param.0, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer)
ROOT %while = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kHloModule));
  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
                          WhileLoopAllReduceCodeMotion{}.Run(module.get()));
  ASSERT_TRUE(simplified_loop);
  TF_ASSERT_OK(
      HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/true)
          .Run(module.get())
          .status());
  // The module is not modified below, so one snapshot of the entry
  // computation's instruction range serves both searches.
  auto entry_instructions = module->entry_computation()->instructions();
  auto while_it =
      std::find_if(entry_instructions.begin(), entry_instructions.end(),
                   [](HloInstruction* instruction) -> bool {
                     return Value(instruction, op::While());
                   });
  // Fail the test instead of dereferencing end() when no while is present.
  ASSERT_TRUE(while_it != entry_instructions.end());
  HloInstruction* transformed_while = *while_it;
  ASSERT_THAT(transformed_while, NotNull());
  EXPECT_THAT(transformed_while->while_body()->instructions(),
              Each(Not(op::AllReduce())));
  HloInstruction* accumulation_buffer =
      transformed_while->mutable_operand(0)->mutable_operand(3);
  EXPECT_THAT(accumulation_buffer, op::Constant());
  auto all_reduce_it =
      std::find_if(entry_instructions.begin(), entry_instructions.end(),
                   [](HloInstruction* instruction) {
                     return Value(instruction, op::AllReduce());
                   });
  ASSERT_TRUE(all_reduce_it != entry_instructions.end());
  HloAllReduceInstruction* moved_all_reduce =
      DynCast<HloAllReduceInstruction>(*all_reduce_it);
  ASSERT_THAT(moved_all_reduce, NotNull());
  // Fatal assertion: the DynCast below yields null when the operand is not a
  // get-tuple-element.
  ASSERT_THAT(moved_all_reduce->operand(0), op::GetTupleElement());
  EXPECT_EQ(DynCast<HloGetTupleElementInstruction>(
                moved_all_reduce->mutable_operand(0))
                ->tuple_index(),
            3);
  // Fatal assertions: the checks below index group 0 and compare exactly four
  // replica ids.
  ASSERT_THAT(moved_all_reduce->replica_groups(), SizeIs(1));
  ASSERT_EQ(moved_all_reduce->replica_groups()[0].replica_ids_size(), 4);
  EXPECT_TRUE(
      std::equal(moved_all_reduce->replica_groups()[0].replica_ids().begin(),
                 moved_all_reduce->replica_groups()[0].replica_ids().end(),
                 std::vector<int>{0, 1, 2, 3}.begin()));
  EXPECT_FALSE(moved_all_reduce->constrain_layout());
  EXPECT_TRUE(moved_all_reduce->use_global_device_ids());
  HloComputation* reduction_computation =
      module->GetComputationWithName("reduction");
  ASSERT_THAT(reduction_computation, NotNull());
  EXPECT_EQ(moved_all_reduce->called_computations()[0], reduction_computation);
}
// One all-reduce whose result is sliced into three pieces, each accumulated
// into its own buffer, must be replaced by three all-reduces outside the loop,
// each with its own channel id.
TEST_F(WhileLoopAllReduceCodeMotionTest, AllReduceSliceAccumulate) {
  constexpr absl::string_view kHloModule = R"(
HloModule accumulated_all_reduce
%reduction {
%x = f32[] parameter(0)
%y = f32[] parameter(1)
ROOT %add = f32[] add(f32[] %x, f32[] %y)
}
%while_condition {
%param = (s32[], s32[], f32[3, 1024, 1024], f32[1024, 1024], f32[1024, 1024], f32[1024, 1024]) parameter(0)
%gte.0 = s32[] get-tuple-element(%param), index=0
%gte.1 = s32[] get-tuple-element(%param), index=1
ROOT result = pred[] compare(%gte.0, %gte.1), direction=LT
}
%while_body {
%param = (s32[], s32[], f32[3, 1024, 1024], f32[1024, 1024], f32[1024, 1024], f32[1024, 1024]) parameter(0)
%gte.0 = s32[] get-tuple-element(%param), index=0
%gte.1 = s32[] get-tuple-element(%param), index=1
%gte.2 = f32[3, 1024, 1024] get-tuple-element(%param), index=2
%gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3
%gte.4 = f32[1024, 1024] get-tuple-element(%param), index=4
%gte.5 = f32[1024, 1024] get-tuple-element(%param), index=5
%all-reduce = f32[3, 1024, 1024] all-reduce(f32[3, 1024, 1024] %gte.2), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%reduction
%slice.0 = f32[1, 1024, 1024] slice(f32[3, 1024, 1024] %all-reduce), slice={[0:1], [0:1024], [0:1024]}
%reshape.0 = f32[1024, 1024] reshape(f32[1, 1024, 1024] %slice.0)
%slice.1 = f32[1, 1024, 1024] slice(f32[3, 1024, 1024] %all-reduce), slice={[1:2], [0:1024], [0:1024]}
%reshape.1 = f32[1024, 1024] reshape(f32[1, 1024, 1024] %slice.1)
%slice.2 = f32[1, 1024, 1024] slice(f32[3, 1024, 1024] %all-reduce), slice={[2:3], [0:1024], [0:1024]}
%reshape.2 = f32[1024, 1024] reshape(f32[1, 1024, 1024] %slice.2)
%accumulation.0 = f32[1024, 1024] add(f32[1024, 1024] %reshape.0, f32[1024, 1024] %gte.3)
%accumulation.1 = f32[1024, 1024] add(f32[1024, 1024] %reshape.1, f32[1024, 1024] %gte.4)
%accumulation.2 = f32[1024, 1024] add(f32[1024, 1024] %reshape.2, f32[1024, 1024] %gte.5)
%constant = s32[] constant(1)
%increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant)
ROOT %loop_result = (s32[], s32[], f32[3, 1024, 1024], f32[1024, 1024], f32[1024, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation.0, %accumulation.1, %accumulation.2)
}
ENTRY accumulated_all_reduce {
%param.0 = s32[] parameter(0)
%param.1 = f32[3, 1024, 1024] parameter(1)
%constant.0 = s32[] constant(1)
%accumulation_buffer_init = f32[] constant(0)
%accumulation_buffer.0 = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
%accumulation_buffer.1 = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
%accumulation_buffer.2 = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
%while_init = (s32[], s32[], f32[3, 1024, 1024], f32[1024, 1024], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, s32[] %param.0, f32[3, 1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer.0, f32[1024, 1024] %accumulation_buffer.1, f32[1024, 1024] %accumulation_buffer.2)
ROOT %while = (s32[], s32[], f32[3, 1024, 1024], f32[1024, 1024], f32[1024, 1024], f32[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kHloModule));
  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
                          WhileLoopAllReduceCodeMotion{}.Run(module.get()));
  ASSERT_TRUE(simplified_loop);
  TF_ASSERT_OK(
      HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/true)
          .Run(module.get())
          .status());
  auto entry_instructions = module->entry_computation()->instructions();
  auto while_it =
      std::find_if(entry_instructions.begin(), entry_instructions.end(),
                   [](HloInstruction* instruction) -> bool {
                     return Value(instruction, op::While());
                   });
  // Fail the test instead of dereferencing end() when no while is present.
  ASSERT_TRUE(while_it != entry_instructions.end());
  HloInstruction* transformed_while = *while_it;
  ASSERT_THAT(transformed_while, NotNull());
  EXPECT_THAT(transformed_while->while_body()->instructions(),
              Each(Not(op::AllReduce())));
  std::vector<HloInstruction*> hoisted_all_reduces;
  absl::c_copy_if(module->entry_computation()->instructions(),
                  std::back_inserter(hoisted_all_reduces),
                  [](HloInstruction* instruction) {
                    return Value(instruction, op::AllReduce());
                  });
  // Fatal assertion: elements 0..2 are indexed below.
  ASSERT_THAT(hoisted_all_reduces, SizeIs(3));
  ASSERT_THAT(
      hoisted_all_reduces,
      Each(Pointee(Property(&HloInstruction::channel_id, Ne(absl::nullopt)))));
  // Check if added all-reduces have distinct channel IDs.
  absl::flat_hash_set<int> unique_channel_ids = {
      hoisted_all_reduces[0]->channel_id().value(),
      hoisted_all_reduces[1]->channel_id().value(),
      hoisted_all_reduces[2]->channel_id().value()};
  EXPECT_THAT(unique_channel_ids, SizeIs(3));
}
// The accumulated all-reduce result (tuple index 3 of the loop state) is
// consumed by a multiply after the loop. The test expects the pass to report a
// change, the module to still verify, the while body to contain no all-reduce,
// and the multiply to read element 3 of a tuple whose element 3 is an add.
TEST_F(WhileLoopAllReduceCodeMotionTest, AllReduceAccumulateUse) {
  constexpr absl::string_view kHloModule = R"(
HloModule accumulated_all_reduce
%reduction {
%x = f32[] parameter(0)
%y = f32[] parameter(1)
ROOT %add = f32[] add(f32[] %x, f32[] %y)
}
%while_condition {
%param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
%gte.0 = s32[] get-tuple-element(%param), index=0
%gte.1 = s32[] get-tuple-element(%param), index=1
ROOT result = pred[] compare(%gte.0, %gte.1), direction=LT
}
%while_body {
%param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
%gte.0 = s32[] get-tuple-element(%param), index=0
%gte.1 = s32[] get-tuple-element(%param), index=1
%gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2
%gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3
%all-reduce = f32[1024, 1024] all-reduce(f32[1024, 1024] %gte.2), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%reduction
%accumulation = f32[1024, 1024] add(f32[1024, 1024] %all-reduce, f32[1024, 1024] %gte.3)
%constant = s32[] constant(1)
%increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant)
ROOT %loop_result = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation)
}
ENTRY accumulated_all_reduce {
%param.0 = s32[] parameter(0)
%param.1 = f32[1024, 1024] parameter(1)
%constant.0 = s32[] constant(1)
%accumulation_buffer_init = f32[] constant(0)
%accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
%while_init = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, s32[] %param.0, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer)
%while = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body
%gte_while = f32[1024, 1024] get-tuple-element((s32[], s32[], f32[1024, 1024], f32[1024, 1024]) %while), index=3
ROOT %multiply = f32[1024, 1024] multiply(f32[1024, 1024] %gte_while, f32[1024, 1024] %param.1)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kHloModule));
  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
                          WhileLoopAllReduceCodeMotion{}.Run(module.get()));
  ASSERT_TRUE(simplified_loop);
  TF_ASSERT_OK(
      HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/true)
          .Run(module.get())
          .status());

  // Keep the instruction range in a local so the iterator returned by find_if
  // can be compared against end() before it is dereferenced.
  const auto entry_instructions = module->entry_computation()->instructions();
  const auto while_it =
      std::find_if(entry_instructions.begin(), entry_instructions.end(),
                   [](HloInstruction* instruction) -> bool {
                     return Value(instruction, op::While());
                   });
  ASSERT_TRUE(while_it != entry_instructions.end())
      << "no while instruction found in the entry computation";
  HloInstruction* transformed_while = *while_it;
  ASSERT_THAT(transformed_while, NotNull());
  EXPECT_THAT(transformed_while->while_body()->instructions(),
              Each(Not(op::AllReduce())));

  // The user of the loop result must now read through a tuple whose
  // accumulation element (index 3) is an add.
  HloInstruction* new_root = module->entry_computation()->root_instruction();
  ASSERT_THAT(new_root, op::Multiply());
  ASSERT_THAT(new_root->operand(0), op::GetTupleElement());
  ASSERT_THAT(new_root->operand(0)->operand(0), op::Tuple());
  EXPECT_THAT(new_root->operand(0)->operand(0)->operand(3), op::Add());
}
// Negative case: the all-reduce result feeds two adds in the loop body
// (%accumulation and %add.0), and the test expects the pass to report that it
// changed nothing.
TEST_F(WhileLoopAllReduceCodeMotionTest, RepeatedlyAccumulatedAllReduce) {
  constexpr absl::string_view kHloText = R"(
HloModule accumulated_all_reduce
%reduction {
%x = f32[] parameter(0)
%y = f32[] parameter(1)
ROOT %add = f32[] add(f32[] %x, f32[] %y)
}
%while_condition {
%param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
%gte.0 = s32[] get-tuple-element(%param), index=0
%gte.1 = s32[] get-tuple-element(%param), index=1
ROOT result = pred[] compare(%gte.0, %gte.1), direction=LT
}
%while_body {
%param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
%gte.0 = s32[] get-tuple-element(%param), index=0
%gte.1 = s32[] get-tuple-element(%param), index=1
%gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2
%gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3
%all-reduce = f32[1024, 1024] all-reduce(f32[1024, 1024] %gte.2), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%reduction
%accumulation = f32[1024, 1024] add(f32[1024, 1024] %all-reduce, f32[1024, 1024] %gte.3)
%add.0 = f32[1024, 1024] add(f32[1024, 1024] %all-reduce, f32[1024, 1024] %accumulation)
%constant = s32[] constant(1)
%increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant)
ROOT %loop_result = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %add.0)
}
ENTRY accumulated_all_reduce {
%param.0 = s32[] parameter(0)
%param.1 = f32[1024, 1024] parameter(1)
%constant.0 = s32[] constant(1)
%accumulation_buffer_init = f32[] constant(0)
%accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
%while_init = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, s32[] %param.0, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer)
ROOT %while = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body
}
)";
  // Parse (and verify) the module, then run the pass over it.
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> parsed_module,
                          ParseAndReturnVerifiedModule(kHloText));
  WhileLoopAllReduceCodeMotion code_motion;
  TF_ASSERT_OK_AND_ASSIGN(bool module_changed,
                          code_motion.Run(parsed_module.get()));

  // The pass must leave the module untouched.
  EXPECT_FALSE(module_changed);
}
// The loop body converts the f32 input to bf16, all-reduces in bf16, converts
// back to f32 and accumulates. The test expects the all-reduce to leave the
// loop body, the hoisted all-reduce to keep its bf16 shape, and the first add
// found in the entry computation to be f32 on its result and both operands.
TEST_F(WhileLoopAllReduceCodeMotionTest, TypeCastAllReduceAccumulate) {
  constexpr absl::string_view kHloModule = R"(
HloModule accumulated_all_reduce
%reduction {
%x = bf16[] parameter(0)
%y = bf16[] parameter(1)
ROOT %add = bf16[] add(bf16[] %x, bf16[] %y)
}
%while_condition {
%param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
%gte.0 = s32[] get-tuple-element(%param), index=0
%gte.1 = s32[] get-tuple-element(%param), index=1
ROOT result = pred[] compare(%gte.0, %gte.1), direction=LT
}
%while_body {
%param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
%gte.0 = s32[] get-tuple-element(%param), index=0
%gte.1 = s32[] get-tuple-element(%param), index=1
%gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2
%gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3
%convert.0 = bf16[1024, 1024] convert(f32[1024, 1024] %gte.2)
%all-reduce = bf16[1024, 1024] all-reduce(bf16[1024, 1024] %convert.0), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%reduction
%convert.1 = f32[1024, 1024] convert(bf16[1024, 1024] %all-reduce)
%accumulation = f32[1024, 1024] add(f32[1024, 1024] %convert.1, f32[1024, 1024] %gte.3)
%constant = s32[] constant(1)
%increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant)
ROOT %loop_result = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation)
}
ENTRY accumulated_all_reduce {
%param.0 = s32[] parameter(0)
%param.1 = f32[1024, 1024] parameter(1)
%constant.0 = s32[] constant(1)
%accumulation_buffer_init = f32[] constant(0)
%accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
%while_init = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, s32[] %param.0, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer)
ROOT %while = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kHloModule));
  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
                          WhileLoopAllReduceCodeMotion{}.Run(module.get()));
  ASSERT_TRUE(simplified_loop);
  TF_ASSERT_OK(
      HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/true)
          .Run(module.get())
          .status());

  // Keep the instruction range in a local so every find_if result below can be
  // compared against end() before it is dereferenced.
  const auto entry_instructions = module->entry_computation()->instructions();

  const auto while_it =
      std::find_if(entry_instructions.begin(), entry_instructions.end(),
                   [](HloInstruction* instruction) -> bool {
                     return Value(instruction, op::While());
                   });
  ASSERT_TRUE(while_it != entry_instructions.end())
      << "no while instruction found in the entry computation";
  HloInstruction* transformed_while = *while_it;
  ASSERT_THAT(transformed_while, NotNull());
  EXPECT_THAT(transformed_while->while_body()->instructions(),
              Each(Not(op::AllReduce())));

  // The accumulation slot of the while init (tuple index 3) must be a constant.
  HloInstruction* accumulation_buffer =
      transformed_while->mutable_operand(0)->mutable_operand(3);
  EXPECT_THAT(accumulation_buffer, op::Constant());

  const auto all_reduce_it =
      std::find_if(entry_instructions.begin(), entry_instructions.end(),
                   [](HloInstruction* instruction) -> bool {
                     return Value(instruction, op::AllReduce());
                   });
  ASSERT_TRUE(all_reduce_it != entry_instructions.end())
      << "no all-reduce instruction found in the entry computation";
  HloAllReduceInstruction* moved_all_reduce =
      DynCast<HloAllReduceInstruction>(*all_reduce_it);
  // DynCast yields nullptr on a type mismatch; fail here instead of crashing
  // on the shape() call below.
  ASSERT_THAT(moved_all_reduce, NotNull());
  EXPECT_TRUE(ShapeUtil::Equal(moved_all_reduce->shape(),
                               ShapeUtil::MakeShape(BF16, {1024, 1024})));

  const auto add_it =
      std::find_if(entry_instructions.begin(), entry_instructions.end(),
                   [](HloInstruction* instruction) -> bool {
                     return Value(instruction, op::Add());
                   });
  ASSERT_TRUE(add_it != entry_instructions.end())
      << "no add instruction found in the entry computation";
  HloInstruction* add_delta_to_old_buffer = *add_it;
  ASSERT_THAT(add_delta_to_old_buffer, NotNull());
  EXPECT_TRUE(ShapeUtil::Equal(add_delta_to_old_buffer->shape(),
                               ShapeUtil::MakeShape(F32, {1024, 1024})));
  EXPECT_TRUE(ShapeUtil::Equal(add_delta_to_old_buffer->operand(0)->shape(),
                               ShapeUtil::MakeShape(F32, {1024, 1024})));
  EXPECT_TRUE(ShapeUtil::Equal(add_delta_to_old_buffer->operand(1)->shape(),
                               ShapeUtil::MakeShape(F32, {1024, 1024})));
}
// Two while instructions in the entry computation share the same condition and
// body. The test expects the pass to report a change, the module to still
// verify, the entry computation to hold two whiles and two all-reduces, and the
// first while's body to contain no all-reduce.
//
// NOTE(review): in the HLO below %while.1 consumes %while_init.0, so
// %while_init.1 (and the %gte.3 feeding it) are unused. Possibly %while_init.1
// was the intended operand -- confirm before changing the HLO text.
TEST_F(WhileLoopAllReduceCodeMotionTest, MultipleLoopCalls) {
  constexpr absl::string_view kHloModule = R"(
HloModule accumulated_all_reduce
%reduction {
%x = bf16[] parameter(0)
%y = bf16[] parameter(1)
ROOT %add = bf16[] add(bf16[] %x, bf16[] %y)
}
%while_condition {
%param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
%gte.0 = s32[] get-tuple-element(%param), index=0
%gte.1 = s32[] get-tuple-element(%param), index=1
ROOT result = pred[] compare(%gte.0, %gte.1), direction=LT
}
%while_body {
%param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
%gte.0 = s32[] get-tuple-element(%param), index=0
%gte.1 = s32[] get-tuple-element(%param), index=1
%gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2
%gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3
%convert.0 = bf16[1024, 1024] convert(f32[1024, 1024] %gte.2)
%all-reduce = bf16[1024, 1024] all-reduce(bf16[1024, 1024] %convert.0), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%reduction
%convert.1 = f32[1024, 1024] convert(bf16[1024, 1024] %all-reduce)
%accumulation = f32[1024, 1024] add(f32[1024, 1024] %convert.1, f32[1024, 1024] %gte.3)
%constant = s32[] constant(1)
%increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant)
ROOT %loop_result = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation)
}
ENTRY accumulated_all_reduce {
%param.0 = s32[] parameter(0)
%param.1 = f32[1024, 1024] parameter(1)
%constant.0 = s32[] constant(1)
%accumulation_buffer_init = f32[] constant(0)
%accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
%while_init.0 = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, s32[] %param.0, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer)
%while.0 = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) while(%while_init.0), condition=%while_condition, body=%while_body
%gte.3 = f32[1024, 1024] get-tuple-element(%while.0), index=3
%while_init.1 = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, s32[] %param.0, f32[1024, 1024] %param.1, f32[1024, 1024] %gte.3)
%while.1 = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) while(%while_init.0), condition=%while_condition, body=%while_body
ROOT %gte.4 = f32[1024, 1024] get-tuple-element((s32[], s32[], f32[1024, 1024], f32[1024, 1024])%while.1), index=3
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kHloModule));
  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
                          WhileLoopAllReduceCodeMotion{}.Run(module.get()));
  ASSERT_TRUE(simplified_loop);
  TF_ASSERT_OK(
      HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/true)
          .Run(module.get())
          .status());
  EXPECT_EQ(absl::c_count_if(module->entry_computation()->instructions(),
                             Matches(op::While())),
            2);
  EXPECT_EQ(absl::c_count_if(module->entry_computation()->instructions(),
                             Matches(op::AllReduce())),
            2);

  // Keep the instruction range in a local so the iterator returned by find_if
  // can be compared against end() before it is dereferenced.
  const auto entry_instructions = module->entry_computation()->instructions();
  const auto while_it =
      std::find_if(entry_instructions.begin(), entry_instructions.end(),
                   [](HloInstruction* instruction) -> bool {
                     return Value(instruction, op::While());
                   });
  ASSERT_TRUE(while_it != entry_instructions.end())
      << "no while instruction found in the entry computation";
  HloInstruction* transformed_while = *while_it;
  ASSERT_THAT(transformed_while, NotNull());
  EXPECT_THAT(transformed_while->while_body()->instructions(),
              Each(Not(op::AllReduce())));
}
// The loop body holds two independent all-reduce + accumulate chains, one in
// f32 and one in bf16. The test expects no all-reduce to remain in the while
// body and exactly two all-reduces to appear in the entry computation.
TEST_F(WhileLoopAllReduceCodeMotionTest, MultipleAllReduceAccumulate) {
  // The declared shape of %loop_result lists all six tuple elements so that it
  // agrees with its six operands and with the loop-state shape used by the
  // body parameter, the condition parameter and the while instruction.
  constexpr absl::string_view kHloModule = R"(
HloModule accumulated_all_reduce
%reduction.0 {
%x = f32[] parameter(0)
%y = f32[] parameter(1)
ROOT %add = f32[] add(f32[] %x, f32[] %y)
}
%reduction.1 {
%x = bf16[] parameter(0)
%y = bf16[] parameter(1)
ROOT %add = bf16[] add(bf16[] %x, bf16[] %y)
}
%while_condition {
%param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024], bf16[1024, 1024], bf16[1024, 1024]) parameter(0)
%gte.0 = s32[] get-tuple-element(%param), index=0
%gte.1 = s32[] get-tuple-element(%param), index=1
ROOT result = pred[] compare(%gte.0, %gte.1), direction=LT
}
%while_body {
%param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024], bf16[1024, 1024], bf16[1024, 1024]) parameter(0)
%gte.0 = s32[] get-tuple-element(%param), index=0
%gte.1 = s32[] get-tuple-element(%param), index=1
%gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2
%gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3
%gte.4 = bf16[1024, 1024] get-tuple-element(%param), index=4
%gte.5 = bf16[1024, 1024] get-tuple-element(%param), index=5
%all-reduce.0 = f32[1024, 1024] all-reduce(f32[1024, 1024] %gte.2), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%reduction.0
%accumulation.0 = f32[1024, 1024] add(f32[1024, 1024] %all-reduce.0, f32[1024, 1024] %gte.3)
%all-reduce.1 = bf16[1024, 1024] all-reduce(bf16[1024, 1024] %gte.4), channel_id=2, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%reduction.1
%accumulation.1 = bf16[1024, 1024] add(bf16[1024, 1024] %all-reduce.1, bf16[1024, 1024] %gte.5)
%constant = s32[] constant(1)
%increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant)
ROOT %loop_result = (s32[], s32[], f32[1024, 1024], f32[1024, 1024], bf16[1024, 1024], bf16[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation.0, %gte.4, %accumulation.1)
}
ENTRY accumulated_all_reduce {
%param.0 = s32[] parameter(0)
%param.1 = f32[1024, 1024] parameter(1)
%param.2 = bf16[1024, 1024] parameter(2)
%constant.0 = s32[] constant(1)
%accumulation_buffer.0 = f32[1024, 1024] constant({...})
%accumulation_buffer.1 = bf16[1024, 1024] constant({...})
%while_init = (s32[], s32[], f32[1024, 1024], f32[1024, 1024], bf16[1024, 1024], bf16[1024, 1024]) tuple(s32[] %constant.0, s32[] %param.0, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer.0, bf16[1024, 1024] %param.2, bf16[1024, 1024] %accumulation_buffer.1)
ROOT %while = (s32[], s32[], f32[1024, 1024], f32[1024, 1024], bf16[1024, 1024], bf16[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kHloModule));
  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
                          WhileLoopAllReduceCodeMotion{}.Run(module.get()));
  ASSERT_TRUE(simplified_loop);
  TF_ASSERT_OK(
      HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/true)
          .Run(module.get())
          .status());

  // Keep the instruction range in a local so the iterator returned by find_if
  // can be compared against end() before it is dereferenced.
  const auto entry_instructions = module->entry_computation()->instructions();
  const auto while_it =
      std::find_if(entry_instructions.begin(), entry_instructions.end(),
                   [](HloInstruction* instruction) -> bool {
                     return Value(instruction, op::While());
                   });
  ASSERT_TRUE(while_it != entry_instructions.end())
      << "no while instruction found in the entry computation";
  HloInstruction* transformed_while = *while_it;
  ASSERT_THAT(transformed_while, NotNull());
  // Both all-reduces should have been moved out of the loop body.
  EXPECT_THAT(transformed_while->while_body()->instructions(),
              Each(Not(op::AllReduce())));
  HloInstruction* accumulation_buffer =
      transformed_while->mutable_operand(0)->mutable_operand(3);
  EXPECT_THAT(accumulation_buffer, op::Constant());
  EXPECT_EQ(absl::c_count_if(module->entry_computation()->instructions(),
                             Matches(op::AllReduce())),
            2);
}
// The loop body holds two all-reduce + accumulate chains; the bf16 chain's
// accumulation additionally feeds %add.0 before reaching the loop result. The
// test expects exactly one all-reduce to stay in the while body and exactly one
// to appear in the entry computation.
TEST_F(WhileLoopAllReduceCodeMotionTest, MixMovableAllReduceWithNotMovable) {
  // The declared shape of %loop_result lists all six tuple elements so that it
  // agrees with its six operands and with the loop-state shape used by the
  // body parameter, the condition parameter and the while instruction.
  constexpr absl::string_view kHloModule = R"(
HloModule accumulated_all_reduce
%reduction.0 {
%x = f32[] parameter(0)
%y = f32[] parameter(1)
ROOT %add = f32[] add(f32[] %x, f32[] %y)
}
%reduction.1 {
%x = bf16[] parameter(0)
%y = bf16[] parameter(1)
ROOT %add = bf16[] add(bf16[] %x, bf16[] %y)
}
%while_condition {
%param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024], bf16[1024, 1024], bf16[1024, 1024]) parameter(0)
%gte.0 = s32[] get-tuple-element(%param), index=0
%gte.1 = s32[] get-tuple-element(%param), index=1
ROOT result = pred[] compare(%gte.0, %gte.1), direction=LT
}
%while_body {
%param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024], bf16[1024, 1024], bf16[1024, 1024]) parameter(0)
%gte.0 = s32[] get-tuple-element(%param), index=0
%gte.1 = s32[] get-tuple-element(%param), index=1
%gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2
%gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3
%gte.4 = bf16[1024, 1024] get-tuple-element(%param), index=4
%gte.5 = bf16[1024, 1024] get-tuple-element(%param), index=5
%all-reduce.0 = f32[1024, 1024] all-reduce(f32[1024, 1024] %gte.2), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%reduction.0
%accumulation.0 = f32[1024, 1024] add(f32[1024, 1024] %all-reduce.0, f32[1024, 1024] %gte.3)
%all-reduce.1 = bf16[1024, 1024] all-reduce(bf16[1024, 1024] %gte.4), channel_id=2, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%reduction.1
%accumulation.1 = bf16[1024, 1024] add(bf16[1024, 1024] %all-reduce.1, bf16[1024, 1024] %gte.5)
%add.0 = bf16[1024, 1024] add(bf16[1024, 1024] %accumulation.1, bf16[1024, 1024] %gte.4)
%constant = s32[] constant(1)
%increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant)
ROOT %loop_result = (s32[], s32[], f32[1024, 1024], f32[1024, 1024], bf16[1024, 1024], bf16[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation.0, %gte.4, %add.0)
}
ENTRY accumulated_all_reduce {
%param.0 = s32[] parameter(0)
%param.1 = f32[1024, 1024] parameter(1)
%param.2 = bf16[1024, 1024] parameter(2)
%constant.0 = s32[] constant(1)
%accumulation_buffer.0 = f32[1024, 1024] constant({...})
%accumulation_buffer.1 = bf16[1024, 1024] constant({...})
%while_init = (s32[], s32[], f32[1024, 1024], f32[1024, 1024], bf16[1024, 1024], bf16[1024, 1024]) tuple(s32[] %constant.0, s32[] %param.0, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer.0, bf16[1024, 1024] %param.2, bf16[1024, 1024] %accumulation_buffer.1)
ROOT %while = (s32[], s32[], f32[1024, 1024], f32[1024, 1024], bf16[1024, 1024], bf16[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(kHloModule));
  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
                          WhileLoopAllReduceCodeMotion{}.Run(module.get()));
  ASSERT_TRUE(simplified_loop);
  TF_ASSERT_OK(
      HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/true)
          .Run(module.get())
          .status());

  // Keep the instruction range in a local so the iterator returned by find_if
  // can be compared against end() before it is dereferenced.
  const auto entry_instructions = module->entry_computation()->instructions();
  const auto while_it =
      std::find_if(entry_instructions.begin(), entry_instructions.end(),
                   [](HloInstruction* instruction) -> bool {
                     return Value(instruction, op::While());
                   });
  ASSERT_TRUE(while_it != entry_instructions.end())
      << "no while instruction found in the entry computation";
  HloInstruction* transformed_while = *while_it;
  ASSERT_THAT(transformed_while, NotNull());
  // One all-reduce is movable and the other is not movable.
  EXPECT_EQ(absl::c_count_if(transformed_while->while_body()->instructions(),
                             Matches(op::AllReduce())),
            1);
  HloInstruction* accumulation_buffer =
      transformed_while->mutable_operand(0)->mutable_operand(3);
  EXPECT_THAT(accumulation_buffer, op::Constant());
  EXPECT_EQ(absl::c_count_if(module->entry_computation()->instructions(),
                             Matches(op::AllReduce())),
            1);
}
} // namespace
} // namespace xla