diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 3b0fe1190f8..ec38edbee8d 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -4665,6 +4665,48 @@ tf_cc_test(
     ],
 )
 
+cc_library(
+    name = "while_loop_all_reduce_code_motion",
+    srcs = ["while_loop_all_reduce_code_motion.cc"],
+    hdrs = ["while_loop_all_reduce_code_motion.h"],
+    deps = [
+        ":call_graph",
+        ":hlo",
+        ":hlo_casting_utils",
+        ":hlo_pass",
+        ":hlo_query",
+        ":pattern_matcher",
+        "//tensorflow/compiler/xla:literal_util",
+        "//tensorflow/compiler/xla:status",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla:xla_data_proto_cc",
+        "//tensorflow/core:lib_proto_parsing",
+        "//tensorflow/core/platform:status",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/types:span",
+    ],
+)
+
+tf_cc_test(
+    name = "while_loop_all_reduce_code_motion_test",
+    srcs = ["while_loop_all_reduce_code_motion_test.cc"],
+    deps = [
+        ":hlo",
+        ":hlo_casting_utils",
+        ":hlo_matchers",
+        ":hlo_verifier",
+        ":while_loop_all_reduce_code_motion",
+        "//tensorflow/compiler/xla:xla_data_proto_cc",
+        "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "//tensorflow/core:test",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/strings",
+    ],
+)
+
 cc_library(
     name = "while_loop_invariant_code_motion",
     srcs = ["while_loop_invariant_code_motion.cc"],
diff --git a/tensorflow/compiler/xla/service/while_loop_all_reduce_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_all_reduce_code_motion.cc
new file mode 100644
index 00000000000..a52a4650c6c
--- /dev/null
+++ b/tensorflow/compiler/xla/service/while_loop_all_reduce_code_motion.cc
@@ -0,0 +1,600 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/while_loop_all_reduce_code_motion.h"
+
+#include <memory>
+#include <stack>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/types/span.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/service/call_graph.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_query.h"
+#include "tensorflow/compiler/xla/service/pattern_matcher.h"
+#include "tensorflow/compiler/xla/status.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/status.h"
+
+namespace xla {
+
+namespace {
+
+struct AccumulationContext {
+  HloInstruction* accumulation_instruction;
+  HloInstruction* accumulation_buffer;
+  std::vector<int64> param_tuple_indices;
+};
+
+// Describes whether an all-reduce instruction can be sunk from a while body
+// computation and, if movable, records all the accumulation uses of the
+// all-reduce's result in the while body.
+struct MovableAllReduceContext {
+  bool is_movable;
+  // If movable, `accumulation_contexts` contains one accumulation
+  // context for each accumulation in the while body that uses the all-reduce's
+  // result. Otherwise, this field is undefined.
+  std::vector<AccumulationContext> accumulation_contexts;
+};
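+
+// Sinking is sound because the all-reduce here is restricted to summation,
+// and summation commutes with the loop accumulation. As an illustrative
+// sketch, for per-iteration deltas b_1, ..., b_n:
+//
+//   all-reduce(b_1) + ... + all-reduce(b_n) == all-reduce(b_1 + ... + b_n)
+//
+// so a single all-reduce over the accumulated deltas after the loop replaces
+// n all-reduces inside the loop.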
+
+// Checks whether an all-reduce instruction is eligible for sinking and, if
+// eligible, finds all of the all-reduce's accumulation uses inside the while
+// body.
+MovableAllReduceContext IsAllReduceMovable(HloInstruction* all_reduce,
+                                           HloComputation* while_body) {
+  auto all_reduce_is_summation = [](HloInstruction* all_reduce) -> bool {
+    HloInstruction* to_apply_root = all_reduce->to_apply()->root_instruction();
+    if (all_reduce->to_apply()->num_parameters() != 2) {
+      return false;
+    }
+    return Match(to_apply_root,
+                 match::AddAnyOrder(match::Parameter(0), match::Parameter(1)));
+  };
+
+  // We only support numerical types.
+  const absl::InlinedVector<PrimitiveType, 12> kSupportedTypes{
+      BF16, F16, F32, F64, S8, S16, S32, S64, U8, U16, U32, U64};
+
+  if (!absl::c_linear_search(kSupportedTypes,
+                             all_reduce->shape().element_type()) ||
+      !all_reduce_is_summation(all_reduce)) {
+    return MovableAllReduceContext{/*is_movable=*/false,
+                                   /*accumulation_contexts=*/{}};
+  }
+
+  struct BufferTupleIndex {
+    bool unsupported_operation{false};
+    std::vector<int64> tuple_index;
+    bool returned_from_computation{false};
+  };
+
+  // If the instruction is a buffer forwarded from a tuple element of the
+  // computation's parameter, returns the indices of the buffer in the
+  // parameter tuple. The returned_from_computation field in the result is
+  // never touched in this function.
+  auto get_origin_tuple_index =
+      [](HloInstruction* instruction) -> BufferTupleIndex {
+    BufferTupleIndex result;
+    while (!result.unsupported_operation) {
+      switch (instruction->opcode()) {
+        default:
+          result.unsupported_operation = true;
+          break;
+        case HloOpcode::kBitcast:
+        case HloOpcode::kConvert:
+        case HloOpcode::kReshape:
+        case HloOpcode::kTranspose:
+        case HloOpcode::kDynamicReshape:
+          instruction = instruction->mutable_operand(0);
+          break;
+        case HloOpcode::kGetTupleElement: {
+          if (!result.tuple_index.empty()) {
+            // Note that we don't support nested tuples as of now.
+            result.unsupported_operation = true;
+          } else {
+            result.tuple_index.push_back(
+                Cast<HloGetTupleElementInstruction>(instruction)
+                    ->tuple_index());
+            instruction = instruction->mutable_operand(0);
+          }
+          break;
+        }
+        case HloOpcode::kParameter: {
+          int parameter_number =
+              Cast<HloParameterInstruction>(instruction)->parameter_number();
+          CHECK_EQ(parameter_number, 0);
+          break;
+        }
+      }
+      if (instruction->opcode() == HloOpcode::kParameter) {
+        break;
+      }
+    }
+    return result;
+  };
+
+  // If the instruction's result is returned from its parent computation with
+  // only forwarding operations, returns the index of the result buffer in the
+  // output parameter tuple.
+  auto get_output_tuple_index =
+      [](HloInstruction* instruction,
+         HloComputation* while_body) -> BufferTupleIndex {
+    BufferTupleIndex result;
+    std::stack<HloInstruction*> to_visit;
+    to_visit.push(instruction);
+    while (!to_visit.empty() && !result.unsupported_operation) {
+      HloInstruction* instruction = to_visit.top();
+      to_visit.pop();
+      for (HloInstruction* user : instruction->users()) {
+        switch (user->opcode()) {
+          case HloOpcode::kBitcast:
+          case HloOpcode::kConvert:
+          case HloOpcode::kReshape:
+          case HloOpcode::kGetTupleElement:
+          case HloOpcode::kTranspose:
+          case HloOpcode::kSlice: {
+            to_visit.push(user);
+            break;
+          }
+          case HloOpcode::kDynamicSlice: {
+            if (user->operand_index(instruction) == 0) {
+              to_visit.push(user);
+            } else {
+              result.unsupported_operation = true;
+            }
+            break;
+          }
+          case HloOpcode::kTuple: {
+            if (!result.tuple_index.empty()) {
+              // Note that we don't support nested tuples as of now.
+              result.unsupported_operation = true;
+            } else {
+              result.tuple_index.push_back(user->operand_index(instruction));
+              if (while_body->root_instruction() == user) {
+                if (result.returned_from_computation) {
+                  result.unsupported_operation = true;
+                }
+                result.returned_from_computation = true;
+              } else {
+                to_visit.push(user);
+              }
+            }
+            break;
+          }
+          default:
+            result.unsupported_operation = true;
+        }
+        if (result.unsupported_operation) {
+          break;
+        }
+      }
+    }
+    return result;
+  };
+
+  // Checks whether any buffer in the list of accumulation contexts is used in
+  // the parent computation except for forwarding uses.
+  auto is_buffer_used =
+      [](absl::Span<const AccumulationContext> accumulation_contexts,
+         HloComputation* while_body_computation) -> bool {
+    std::vector<HloInstruction*> parameter_instructions;
+    absl::c_copy_if(while_body_computation->instructions(),
+                    std::back_inserter(parameter_instructions),
+                    [](HloInstruction* instruction) -> bool {
+                      return instruction->opcode() == HloOpcode::kParameter;
+                    });
+    for (const auto& accumulation : accumulation_contexts) {
+      HloInstruction* accumulation_instruction =
+          accumulation.accumulation_instruction;
+      int tuple_index = accumulation.param_tuple_indices[0];
+      std::stack<HloInstruction*> to_visit;
+      // TODO(b/176437845): simplify the logic below by using
+      // TuplePointsToAnalysis.
+      for (HloInstruction* parameter_instruction : parameter_instructions) {
+        // Iterate over all users of the while body parameter and find all
+        // instructions that use the accumulation buffer, as specified by
+        // tuple_index.
+        // This logic could be simplified by using TuplePointsToAnalysis,
+        // which we leave to a future CL (see TODO above).
+        for (HloInstruction* user : parameter_instruction->users()) {
+          if (auto* gte = DynCast<HloGetTupleElementInstruction>(user)) {
+            if (gte->tuple_index() == tuple_index) {
+              to_visit.push(user);
+            }
+          } else {
+            return true;
+          }
+        }
+      }
+
+      while (!to_visit.empty()) {
+        HloInstruction* instruction = to_visit.top();
+        to_visit.pop();
+        for (HloInstruction* user : instruction->users()) {
+          switch (user->opcode()) {
+            case HloOpcode::kBitcast:
+            case HloOpcode::kConvert:
+            case HloOpcode::kReshape:
+            case HloOpcode::kTranspose:
+              to_visit.push(user);
+              break;
+            case HloOpcode::kDynamicReshape: {
+              if (instruction == user->operand(0)) {
+                to_visit.push(user);
+              } else {
+                return true;
+              }
+              break;
+            }
+            case HloOpcode::kAdd: {
+              if (user != accumulation_instruction) {
+                return true;
+              }
+              break;
+            }
+            default:
+              return true;
+          }
+        }
+      }
+    }
+    return false;
+  };
+
+  // Finds all accumulation contexts of the given all-reduce instruction if it
+  // is movable.
+  auto get_accumulation_contexts =
+      [&get_origin_tuple_index, &get_output_tuple_index, &is_buffer_used](
+          HloInstruction* all_reduce,
+          HloComputation* while_body) -> MovableAllReduceContext {
+    std::vector<AccumulationContext> accumulation_contexts;
+    // A DFS starting from the all-reduce instruction that stops at the first
+    // non-trivial use of the all-reduce result, or finds all accumulations
+    // of the all-reduce result.
+    std::stack<HloInstruction*> to_visit;
+    // Movable by default unless we find that it's not.
+    bool is_all_reduce_movable = true;
+    to_visit.push(all_reduce);
+
+    while (!to_visit.empty() && is_all_reduce_movable) {
+      HloInstruction* instruction = to_visit.top();
+      to_visit.pop();
+      for (HloInstruction* user : instruction->users()) {
+        switch (user->opcode()) {
+          case HloOpcode::kBitcast:
+          case HloOpcode::kConvert:
+          case HloOpcode::kReshape:
+          case HloOpcode::kGetTupleElement:
+          case HloOpcode::kTranspose:
+          case HloOpcode::kSlice: {
+            to_visit.push(user);
+            break;
+          }
+          case HloOpcode::kDynamicSlice: {
+            if (user->operand_index(instruction) == 0) {
+              to_visit.push(user);
+            } else {
+              is_all_reduce_movable = false;
+              break;
+            }
+            break;
+          }
+          case HloOpcode::kAdd: {
+            int64 buffer_index = 1 - user->operand_index(instruction);
+            HloInstruction* accumulation_buffer =
+                user->mutable_operand(buffer_index);
+
+            auto origin_buffer_tuple_index =
+                get_origin_tuple_index(accumulation_buffer);
+            if (origin_buffer_tuple_index.unsupported_operation) {
+              is_all_reduce_movable = false;
+              break;
+            }
+
+            auto output_buffer_tuple_index =
+                get_output_tuple_index(user, while_body);
+            if (!output_buffer_tuple_index.unsupported_operation &&
+                output_buffer_tuple_index.returned_from_computation &&
+                !origin_buffer_tuple_index.tuple_index.empty() &&
+                ContainersEqual(origin_buffer_tuple_index.tuple_index,
+                                output_buffer_tuple_index.tuple_index)) {
+              accumulation_contexts.push_back(AccumulationContext{
+                  user, accumulation_buffer,
+                  std::move(output_buffer_tuple_index.tuple_index)});
+            } else {
+              is_all_reduce_movable = false;
+            }
+            break;
+          }
+          default:
+            is_all_reduce_movable = false;
+        }
+      }
+    }
+    if (is_buffer_used(accumulation_contexts, while_body)) {
+      is_all_reduce_movable = false;
+    }
+    return MovableAllReduceContext{is_all_reduce_movable,
+                                   accumulation_contexts};
+  };
+  return get_accumulation_contexts(all_reduce, while_body);
+}
+
+struct WhileInitContext {
+  HloInstruction* while_init{nullptr};
+  absl::flat_hash_map<int, HloInstruction*> tuple_index_to_old_buffer;
+};
+
+// Creates a new while init instruction, which replaces each accumulation
+// buffer in the given accumulation contexts with a zero-initialized buffer.
+// In other words, we accumulate only the deltas in the while loop, starting
+// from a zero initial value.
+WhileInitContext CreateNewWhileInit(
+    HloInstruction* old_while_instruction,
+    const HloInstructionMap<std::vector<AccumulationContext>>&
+        all_reduce_to_accumulations) {
+  HloInstruction* old_while_init = old_while_instruction->mutable_operand(0);
+  HloComputation* while_parent = old_while_instruction->parent();
+  std::vector<HloInstruction*> new_while_init_elements(
+      old_while_init->operand_count(), nullptr);
+  for (const auto& all_reduce_and_accumulations_pair :
+       all_reduce_to_accumulations) {
+    const std::vector<AccumulationContext>& accumulations =
+        all_reduce_and_accumulations_pair.second;
+    for (auto& accumulation_context : accumulations) {
+      CHECK_EQ(accumulation_context.param_tuple_indices.size(), 1);
+      int tuple_index = accumulation_context.param_tuple_indices[0];
+      HloInstruction* old_buffer =
+          old_while_init->mutable_operand(tuple_index);
+      HloInstruction* new_buffer = while_parent->AddInstruction(
+          HloInstruction::CreateConstant(LiteralUtil::CreateFromDimensions(
+              old_buffer->shape().element_type(),
+              old_buffer->shape().dimensions())));
+      new_while_init_elements[tuple_index] = new_buffer;
+    }
+  }
+  absl::flat_hash_map<int, HloInstruction*> tuple_index_to_old_buffer;
+  for (int i = 0; i < old_while_init->operand_count(); i++) {
+    if (!new_while_init_elements[i]) {
+      new_while_init_elements[i] = old_while_init->mutable_operand(i);
+    } else {
+      tuple_index_to_old_buffer[i] = old_while_init->mutable_operand(i);
+    }
+  }
+  HloInstruction* new_while_init = while_parent->AddInstruction(
+      HloInstruction::CreateTuple(new_while_init_elements));
+  return WhileInitContext{new_while_init, tuple_index_to_old_buffer};
+}
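+
+// Note: zero-initializing the in-loop accumulation buffers is safe even when
+// the original init value is nonzero, because the original buffer is added
+// back to the all-reduced delta after the loop (see CreateSinkedAllReduces),
+// so the final value old_init + all-reduce(accumulated deltas) is unchanged.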
+
+// Creates all the sunk all-reduce instructions in the while instruction's
+// parent computation. Returns a map from the tuple index of an accumulation
+// buffer to its corresponding all-reduce.
+absl::flat_hash_map<int, HloInstruction*> CreateSinkedAllReduces(
+    HloInstruction* new_while_instruction,
+    const HloInstructionMap<std::vector<AccumulationContext>>&
+        all_reduce_to_accumulations,
+    const absl::flat_hash_map<int, HloInstruction*>&
+        tuple_index_to_old_buffer) {
+  HloComputation* while_parent = new_while_instruction->parent();
+  absl::flat_hash_map<int, HloInstruction*> tuple_index_to_new_buffer;
+  for (const auto& all_reduce_and_accumulations_pair :
+       all_reduce_to_accumulations) {
+    HloInstruction* loop_all_reduce = all_reduce_and_accumulations_pair.first;
+    const std::vector<AccumulationContext>& accumulations =
+        all_reduce_and_accumulations_pair.second;
+    for (const auto& accumulation_context : accumulations) {
+      CHECK_EQ(accumulation_context.param_tuple_indices.size(), 1);
+      int tuple_index = accumulation_context.param_tuple_indices[0];
+      const Shape& accumulation_buffer_shape =
+          new_while_instruction->shape().tuple_shapes(tuple_index);
+      HloInstruction* accumulation_buffer =
+          while_parent->AddInstruction(HloInstruction::CreateGetTupleElement(
+              accumulation_buffer_shape, new_while_instruction, tuple_index));
+      HloAllReduceInstruction* old_all_reduce =
+          Cast<HloAllReduceInstruction>(loop_all_reduce);
+      HloInstruction* all_reduce_operand = accumulation_buffer;
+      if (!ShapeUtil::SameElementType(old_all_reduce->shape(),
+                                      accumulation_buffer_shape)) {
+        Shape all_reduce_shape =
+            ShapeUtil::MakeShape(old_all_reduce->shape().element_type(),
+                                 accumulation_buffer_shape.dimensions());
+        all_reduce_operand =
+            while_parent->AddInstruction(HloInstruction::CreateConvert(
+                all_reduce_shape, accumulation_buffer));
+      }
+      HloInstruction* new_all_reduce =
+          while_parent->AddInstruction(HloInstruction::CreateAllReduce(
+              all_reduce_operand->shape(), {all_reduce_operand},
+              old_all_reduce->called_computations()[0],
+              old_all_reduce->replica_groups(),
+              old_all_reduce->constrain_layout(),
+              hlo_query::NextChannelId(*(while_parent->parent())),
+              old_all_reduce->use_global_device_ids()));
+      HloInstruction* all_reduced_delta = new_all_reduce;
+      if (!ShapeUtil::SameElementType(all_reduced_delta->shape(),
+                                      accumulation_buffer_shape)) {
+        all_reduced_delta =
+            while_parent->AddInstruction(HloInstruction::CreateConvert(
+                accumulation_buffer_shape, all_reduced_delta));
+      }
+      CHECK(ContainsKey(tuple_index_to_old_buffer, tuple_index));
+      HloInstruction* old_buffer = tuple_index_to_old_buffer.at(tuple_index);
+      CHECK(ShapeUtil::Equal(old_buffer->shape(), all_reduced_delta->shape()));
+      HloInstruction* add_to_old_buffer =
+          while_parent->AddInstruction(HloInstruction::CreateBinary(
+              all_reduced_delta->shape(), HloOpcode::kAdd, old_buffer,
+              all_reduced_delta));
+      tuple_index_to_new_buffer[tuple_index] = add_to_old_buffer;
+    }
+  }
+  return tuple_index_to_new_buffer;
+}
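+
+// Note on the converts above: when the in-loop all-reduce ran in a different
+// element type than the accumulation buffer (e.g. a bf16 all-reduce feeding
+// an f32 accumulator), the sunk all-reduce converts the accumulated buffer to
+// the all-reduce's element type and converts the result back, mirroring the
+// convert pair that surrounded the original in-loop all-reduce.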
+
+// Creates a tuple which is equivalent to the original while instruction's
+// output.
+HloInstruction* CreateNewWhileResult(
+    HloInstruction* new_while_instruction,
+    const absl::flat_hash_map<int, HloInstruction*>&
+        tuple_index_to_new_buffer) {
+  HloComputation* while_parent = new_while_instruction->parent();
+  CHECK(new_while_instruction->shape().IsTuple());
+  std::vector<HloInstruction*> new_while_result_elements(
+      new_while_instruction->shape().tuple_shapes_size(), nullptr);
+  for (int i = 0; i < new_while_result_elements.size(); i++) {
+    if (ContainsKey(tuple_index_to_new_buffer, i)) {
+      new_while_result_elements[i] = tuple_index_to_new_buffer.at(i);
+    } else {
+      HloInstruction* gte =
+          while_parent->AddInstruction(HloInstruction::CreateGetTupleElement(
+              new_while_instruction->shape().tuple_shapes(i),
+              new_while_instruction, i));
+      new_while_result_elements[i] = gte;
+    }
+  }
+  HloInstruction* new_while_result = while_parent->AddInstruction(
+      HloInstruction::CreateTuple(new_while_result_elements));
+  return new_while_result;
+}
+
+// Creates the sunk all-reduce instructions for all accumulation buffers. The
+// all-reduce outputs are then added to the original accumulation buffers.
+// Creates a tuple that groups the while loop output with the accumulated
+// buffers and replaces all uses of the old while with this new tuple.
+Status AddSinkedAllReducesAndReplaceWhile(
+    HloInstruction* while_instruction,
+    const HloInstructionMap<std::vector<AccumulationContext>>&
+        all_reduce_to_accumulations) {
+  // Note that we create all new instructions before replacing and removing
+  // any old instruction. This ensures that we do not accidentally access any
+  // deleted instruction when creating new instructions.
+
+  // Step 1) Create the new while init instruction, which uses
+  // zero-initialized tensors as the accumulation buffers for the all-reduce.
+  auto new_while_init_context =
+      CreateNewWhileInit(while_instruction, all_reduce_to_accumulations);
+  // Step 2) Create the new while instruction.
+  HloInstruction* new_while_instruction =
+      while_instruction->parent()->AddInstruction(HloInstruction::CreateWhile(
+          new_while_init_context.while_init->shape(),
+          while_instruction->while_condition(),
+          while_instruction->while_body(),
+          new_while_init_context.while_init));
+  // Step 3) Create the new all-reduce instructions after the while loop.
+  absl::flat_hash_map<int, HloInstruction*> tuple_index_to_new_buffer =
+      CreateSinkedAllReduces(new_while_instruction,
+                             all_reduce_to_accumulations,
+                             new_while_init_context.tuple_index_to_old_buffer);
+  // Step 4) Create the tuple and replace all uses of the old while
+  // instruction with it.
+  HloInstruction* new_while_result =
+      CreateNewWhileResult(new_while_instruction, tuple_index_to_new_buffer);
+  TF_RETURN_IF_ERROR(while_instruction->parent()->ReplaceInstruction(
+      while_instruction, new_while_result));
+  return Status::OK();
+}
+
+}  // namespace
+
+StatusOr<bool> WhileLoopAllReduceCodeMotion::Run(HloModule* module) {
+  bool is_changed = false;
+  bool run_next_pass = true;
+  // In case of MPMD, all-reduces might be cross-module and should preserve
+  // their channel IDs. Do not move all-reduces in this case, since moving
+  // them would change their channel IDs.
+  if (module->config().num_partitions() > 1 &&
+      !module->config().use_spmd_partitioning()) {
+    return false;
+  }
+  // The while instruction's parent could itself be a while body for another
+  // while loop. We recursively sink the all-reduce through nested while
+  // loops, where applicable, by repeating this process.
+  while (run_next_pass) {
+    run_next_pass = false;
+    std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
+    // A computation could be the while body of multiple while instructions,
+    // so we start from the computation and find all of its callers that are
+    // kWhile instructions, if any.
+    for (HloComputation* computation : module->computations()) {
+      std::vector<HloInstruction*> computation_callers =
+          call_graph->GetComputationCallers(computation);
+      std::vector<HloInstruction*> while_caller_instructions;
+      for (HloInstruction* caller_instruction : computation_callers) {
+        // For simplicity, we only support while instructions whose shape is
+        // a tuple.
+        if (caller_instruction->opcode() == HloOpcode::kWhile &&
+            caller_instruction->shape().IsTuple() &&
+            caller_instruction->while_body() == computation) {
+          while_caller_instructions.push_back(caller_instruction);
+        }
+      }
+      // Skip to the next computation if this computation is not the while
+      // body of any while instruction.
+      if (while_caller_instructions.empty()) {
+        continue;
+      }
+      std::vector<HloInstruction*> while_body_all_reduces;
+      for (HloInstruction* while_body_instruction :
+           computation->MakeInstructionPostOrder()) {
+        if (auto* all_reduce_instruction =
+                DynCast<HloAllReduceInstruction>(while_body_instruction)) {
+          if (all_reduce_instruction->constrain_layout()) {
+            return is_changed;
+          } else {
+            while_body_all_reduces.push_back(all_reduce_instruction);
+          }
+        }
+      }
+      HloInstructionMap<std::vector<AccumulationContext>>
+          all_reduce_to_accumulations;
+      for (HloInstruction* all_reduce : while_body_all_reduces) {
+        auto movable_all_reduce_context =
+            IsAllReduceMovable(all_reduce, computation);
+        if (movable_all_reduce_context.is_movable) {
+          all_reduce_to_accumulations[all_reduce] =
+              std::move(movable_all_reduce_context.accumulation_contexts);
+        }
+      }
+      if (all_reduce_to_accumulations.empty()) {
+        continue;
+      }
+      // For each while instruction calling this computation, create the
+      // corresponding all-reduces after the while loop.
+      for (HloInstruction* while_instruction : while_caller_instructions) {
+        TF_RETURN_IF_ERROR(AddSinkedAllReducesAndReplaceWhile(
+            while_instruction, all_reduce_to_accumulations));
+        is_changed = true;
+        run_next_pass = true;
+      }
+      // Finally, remove the old all-reduce instructions in the while body.
+      for (const auto& all_reduce_accumulations_pair :
+           all_reduce_to_accumulations) {
+        HloInstruction* all_reduce = all_reduce_accumulations_pair.first;
+        TF_RETURN_IF_ERROR(computation->ReplaceInstruction(
+            all_reduce, all_reduce->mutable_operand(0)));
+      }
+    }
+  }
+  return is_changed;
+}
+
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/while_loop_all_reduce_code_motion.h b/tensorflow/compiler/xla/service/while_loop_all_reduce_code_motion.h
new file mode 100644
index 00000000000..824740a5ecc
--- /dev/null
+++ b/tensorflow/compiler/xla/service/while_loop_all_reduce_code_motion.h
@@ -0,0 +1,59 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ALL_REDUCE_CODE_MOTION_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ALL_REDUCE_CODE_MOTION_H_
+
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+#include "tensorflow/compiler/xla/statusor.h"
+
+namespace xla {
+
+// HLO pass that rewrites while loops to sink all-reduces that are only
+// accumulated into a buffer and not otherwise used in the loop body.
+// An all-reduce instruction can be sunk if its result is only added
+// to a number of accumulation buffers, and the accumulation buffers are not
+// used inside the loop.
+//
+// Pattern before this pass:
+// a = ...
+// while:
+//   b = ...
+//   c = all-reduce(b)
+//   a += c
+//
+// Pattern after this pass:
+// a = ...
+// d = 0
+// while:
+//   b = ...
+//   d += b
+// e = all-reduce(d)
+// a += e
+class WhileLoopAllReduceCodeMotion : public HloModulePass {
+ public:
+  WhileLoopAllReduceCodeMotion() = default;
+  ~WhileLoopAllReduceCodeMotion() override = default;
+
+  absl::string_view name() const override {
+    static constexpr absl::string_view kName =
+        "while-loop-all-reduce-code-motion";
+    return kName;
+  }
+  StatusOr<bool> Run(HloModule* module) override;
+};
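+
+// Example of wiring this pass into a pass pipeline. This is an illustrative
+// sketch only; the pipeline name below is hypothetical, and where the pass
+// runs is backend-specific:
+//
+//   HloPassPipeline pipeline("collective-optimizations");
+//   pipeline.AddPass<WhileLoopAllReduceCodeMotion>();
+//   TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));
+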
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ALL_REDUCE_CODE_MOTION_H_
diff --git a/tensorflow/compiler/xla/service/while_loop_all_reduce_code_motion_test.cc b/tensorflow/compiler/xla/service/while_loop_all_reduce_code_motion_test.cc
new file mode 100644
index 00000000000..a7e49293802
--- /dev/null
+++ b/tensorflow/compiler/xla/service/while_loop_all_reduce_code_motion_test.cc
@@ -0,0 +1,658 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/while_loop_all_reduce_code_motion.h"
+
+#include <algorithm>
+#include <memory>
+
+#include "absl/algorithm/container.h"
+#include "absl/strings/string_view.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
+#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_verifier.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace xla {
+namespace {
+
+namespace op = ::xla::testing::opcode_matchers;
+using ::testing::Each;
+using ::testing::Matches;
+using ::testing::Ne;
+using ::testing::Not;
+using ::testing::NotNull;
+using ::testing::Pointee;
+using ::testing::Property;
+using ::testing::SizeIs;
+using ::testing::Value;
+
+class WhileLoopAllReduceCodeMotionTest : public HloTestBase {};
+
+TEST_F(WhileLoopAllReduceCodeMotionTest, AllReduceAccumulate) {
+  constexpr absl::string_view kHloModule = R"(
+    HloModule accumulated_all_reduce
+
+    %reduction {
+      %x = f32[] parameter(0)
+      %y = f32[] parameter(1)
+      ROOT %add = f32[] add(f32[] %x, f32[] %y)
+    }
+
+    %while_condition {
+      %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
+      %gte.0 = s32[] get-tuple-element(%param), index=0
+      %gte.1 = s32[] get-tuple-element(%param), index=1
+      ROOT result = pred[] compare(%gte.0, %gte.1), direction=LT
+    }
+
+    %while_body {
+      %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
+      %gte.0 = s32[] get-tuple-element(%param), index=0
+      %gte.1 = s32[] get-tuple-element(%param), index=1
+      %gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2
+      %gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3
+      %all-reduce = f32[1024, 1024] all-reduce(f32[1024, 1024] %gte.2), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%reduction
+      %accumulation = f32[1024, 1024] add(f32[1024, 1024] %all-reduce, f32[1024, 1024] %gte.3)
+      %constant = s32[] constant(1)
+      %increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant)
+      ROOT %loop_result = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation)
+    }
+
+    ENTRY accumulated_all_reduce {
+      %param.0 = s32[] parameter(0)
+      %param.1 = f32[1024, 1024] parameter(1)
+      %constant.0 = s32[] constant(1)
+      %accumulation_buffer_init = f32[] constant(0)
+      %accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
+      %while_init = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, s32[] %param.0, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer)
+      ROOT %while = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(kHloModule));
+  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
+                          WhileLoopAllReduceCodeMotion{}.Run(module.get()));
+  ASSERT_TRUE(simplified_loop);
+  TF_ASSERT_OK(
+      HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/true)
+          .Run(module.get())
+          .status());
+  HloInstruction* transformed_while =
+      *(std::find_if(module->entry_computation()->instructions().begin(),
+                     module->entry_computation()->instructions().end(),
+                     [](HloInstruction* instruction) -> bool {
+                       return Value(instruction, op::While());
+                     }));
+
+  ASSERT_THAT(transformed_while, NotNull());
+  EXPECT_THAT(transformed_while->while_body()->instructions(),
+              Each(Not(op::AllReduce())));
+  HloInstruction* accumulation_buffer =
+      transformed_while->mutable_operand(0)->mutable_operand(3);
+  EXPECT_THAT(accumulation_buffer, op::Constant());
+  HloAllReduceInstruction* moved_all_reduce =
+      DynCast<HloAllReduceInstruction>(
+          *(std::find_if(module->entry_computation()->instructions().begin(),
+                         module->entry_computation()->instructions().end(),
+                         [](HloInstruction* instruction) {
+                           return Value(instruction, op::AllReduce());
+                         })));
+  ASSERT_THAT(moved_all_reduce, NotNull());
+  EXPECT_THAT(moved_all_reduce->operand(0), op::GetTupleElement());
+  EXPECT_EQ(DynCast<HloGetTupleElementInstruction>(
+                moved_all_reduce->mutable_operand(0))
+                ->tuple_index(),
+            3);
+  EXPECT_THAT(moved_all_reduce->replica_groups(), SizeIs(1));
+  EXPECT_TRUE(
+      std::equal(moved_all_reduce->replica_groups()[0].replica_ids().begin(),
+                 moved_all_reduce->replica_groups()[0].replica_ids().end(),
+                 std::vector<int>{0, 1, 2, 3}.begin()));
+  EXPECT_FALSE(moved_all_reduce->constrain_layout());
+  EXPECT_TRUE(moved_all_reduce->use_global_device_ids());
+  HloComputation* reduction_computation =
+      module->GetComputationWithName("reduction");
+  ASSERT_THAT(reduction_computation, NotNull());
+  EXPECT_EQ(moved_all_reduce->called_computations()[0],
+            reduction_computation);
+}
+
+TEST_F(WhileLoopAllReduceCodeMotionTest, AllReduceSliceAccumulate) {
+  constexpr absl::string_view kHloModule = R"(
+    HloModule accumulated_all_reduce
+
+    %reduction {
+      %x = f32[] parameter(0)
+      %y = f32[] parameter(1)
+      ROOT %add = f32[] add(f32[] %x, f32[] %y)
+    }
+
+    %while_condition {
+      %param = (s32[], s32[], f32[3, 1024, 1024], f32[1024, 1024], f32[1024, 1024], f32[1024, 1024]) parameter(0)
+      %gte.0 = s32[] get-tuple-element(%param), index=0
+      %gte.1 = s32[] get-tuple-element(%param), index=1
+      ROOT result = pred[] compare(%gte.0, %gte.1), direction=LT
+    }
+
+    %while_body {
+      %param = (s32[], s32[], f32[3, 1024, 1024], f32[1024, 1024], f32[1024, 1024], f32[1024, 1024]) parameter(0)
+      %gte.0 = s32[] get-tuple-element(%param), index=0
+      %gte.1 = s32[] get-tuple-element(%param), index=1
+      %gte.2 = f32[3, 1024, 1024] get-tuple-element(%param), index=2
+      %gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3
+      %gte.4 = f32[1024, 1024] get-tuple-element(%param), index=4
+      %gte.5 = f32[1024, 1024] get-tuple-element(%param), index=5
+      %all-reduce = f32[3, 1024, 1024] all-reduce(f32[3, 1024, 1024] %gte.2), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%reduction
+      %slice.0 = f32[1, 1024, 1024] slice(f32[3, 1024, 1024] %all-reduce), slice={[0:1], [0:1024], [0:1024]}
+      %reshape.0 = f32[1024, 1024] reshape(f32[1, 1024, 1024] %slice.0)
+      %slice.1 = f32[1, 1024, 1024] slice(f32[3, 1024, 1024] %all-reduce), slice={[1:2], [0:1024], [0:1024]}
+      %reshape.1 = f32[1024, 1024] reshape(f32[1, 1024, 1024] %slice.1)
+      %slice.2 = f32[1, 1024, 1024] slice(f32[3, 1024, 1024] %all-reduce), slice={[2:3], [0:1024], [0:1024]}
+      %reshape.2 = f32[1024, 1024] reshape(f32[1, 1024, 1024] %slice.2)
+      %accumulation.0 = f32[1024, 1024] add(f32[1024, 1024] %reshape.0, f32[1024, 1024] %gte.3)
+      %accumulation.1 = f32[1024, 1024] add(f32[1024, 1024] %reshape.1, f32[1024, 1024] %gte.4)
+      %accumulation.2 = f32[1024, 1024] add(f32[1024, 1024] %reshape.2, f32[1024, 1024] %gte.5)
+      %constant = s32[] constant(1)
+      %increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant)
+      ROOT %loop_result = (s32[], s32[], f32[3, 1024, 1024], f32[1024, 1024], f32[1024, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation.0, %accumulation.1, %accumulation.2)
+    }
+
+    ENTRY accumulated_all_reduce {
+      %param.0 = s32[] parameter(0)
+      %param.1 = f32[3, 1024, 1024] parameter(1)
+      %constant.0 = s32[] constant(1)
+      %accumulation_buffer_init = f32[] constant(0)
+      %accumulation_buffer.0 = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
+      %accumulation_buffer.1 = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
+      %accumulation_buffer.2 = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
+      %while_init = (s32[], s32[], f32[3, 1024, 1024], f32[1024, 1024], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, s32[] %param.0, f32[3, 1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer.0, f32[1024, 1024] %accumulation_buffer.1, f32[1024, 1024] %accumulation_buffer.2)
+      ROOT %while = (s32[], s32[], f32[3, 1024, 1024], f32[1024, 1024], f32[1024, 1024], f32[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(kHloModule));
+  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
+                          WhileLoopAllReduceCodeMotion{}.Run(module.get()));
+  ASSERT_TRUE(simplified_loop);
+  TF_ASSERT_OK(
+      HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/true)
+          .Run(module.get())
+          .status());
+  HloInstruction* transformed_while =
+      *(std::find_if(module->entry_computation()->instructions().begin(),
+                     module->entry_computation()->instructions().end(),
+                     [](HloInstruction* instruction) -> bool {
+                       return Value(instruction, op::While());
+                     }));
+
+  ASSERT_THAT(transformed_while, NotNull());
+  EXPECT_THAT(transformed_while->while_body()->instructions(),
+              Each(Not(op::AllReduce())));
+  std::vector<HloInstruction*> hoisted_all_reduces;
+  absl::c_copy_if(module->entry_computation()->instructions(),
+                  std::back_inserter(hoisted_all_reduces),
+                  [](HloInstruction* instruction) {
+                    return Value(instruction, op::AllReduce());
+                  });
+  EXPECT_THAT(hoisted_all_reduces, SizeIs(3));
+  ASSERT_THAT(
+      hoisted_all_reduces,
+      Each(Pointee(Property(&HloInstruction::channel_id, Ne(absl::nullopt)))));
+  // Check that the added all-reduces have distinct channel IDs.
+  absl::flat_hash_set<int64> unique_channel_ids = {
+      hoisted_all_reduces[0]->channel_id().value(),
+      hoisted_all_reduces[1]->channel_id().value(),
+      hoisted_all_reduces[2]->channel_id().value()};
+  EXPECT_THAT(unique_channel_ids, SizeIs(3));
+}
+
+TEST_F(WhileLoopAllReduceCodeMotionTest, AllReduceAccumulateUse) {
+  constexpr absl::string_view kHloModule = R"(
+    HloModule accumulated_all_reduce
+
+    %reduction {
+      %x = f32[] parameter(0)
+      %y = f32[] parameter(1)
+      ROOT %add = f32[] add(f32[] %x, f32[] %y)
+    }
+
+    %while_condition {
+      %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
+      %gte.0 = s32[] get-tuple-element(%param), index=0
+      %gte.1 = s32[] get-tuple-element(%param), index=1
+      ROOT result = pred[] compare(%gte.0, %gte.1), direction=LT
+    }
+
+    %while_body {
+      %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
+      %gte.0 = s32[] get-tuple-element(%param), index=0
+      %gte.1 = s32[] get-tuple-element(%param), index=1
+      %gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2
+      %gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3
+      %all-reduce = f32[1024, 1024] all-reduce(f32[1024, 1024] %gte.2), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%reduction
+      %accumulation = f32[1024, 1024] add(f32[1024, 1024] %all-reduce, f32[1024, 1024] %gte.3)
+      %constant = s32[] constant(1)
+      %increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant)
+      ROOT %loop_result = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation)
+    }
+
+    ENTRY accumulated_all_reduce {
+      %param.0 = s32[] parameter(0)
+      %param.1 = f32[1024, 1024] parameter(1)
+      %constant.0 = s32[] constant(1)
+      %accumulation_buffer_init = f32[] constant(0)
+      %accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
+      %while_init = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, s32[] %param.0, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer)
+      %while = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body
+      %gte_while = f32[1024, 1024] get-tuple-element((s32[], s32[], f32[1024, 1024], f32[1024, 1024]) %while), index=3
+      ROOT %multiply = f32[1024, 1024] multiply(f32[1024, 1024] %gte_while, f32[1024, 1024] %param.1)
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(kHloModule));
+  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
+                          WhileLoopAllReduceCodeMotion{}.Run(module.get()));
+  ASSERT_TRUE(simplified_loop);
+  TF_ASSERT_OK(
+      HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/true)
+          .Run(module.get())
+          .status());
+  HloInstruction* transformed_while =
+      *(std::find_if(module->entry_computation()->instructions().begin(),
+                     module->entry_computation()->instructions().end(),
+                     [](HloInstruction* instruction) -> bool {
+                       return Value(instruction, op::While());
+                     }));
+
+  ASSERT_THAT(transformed_while, NotNull());
+  EXPECT_THAT(transformed_while->while_body()->instructions(),
+              Each(Not(op::AllReduce())));
+  HloInstruction* new_root = module->entry_computation()->root_instruction();
+  ASSERT_THAT(new_root, op::Multiply());
+  ASSERT_THAT(new_root->operand(0), op::GetTupleElement());
+  ASSERT_THAT(new_root->operand(0)->operand(0), op::Tuple());
+  EXPECT_THAT(new_root->operand(0)->operand(0)->operand(3), op::Add());
+}
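+
+// Checks that an all-reduce is not sunk when its result has a use beyond the
+// single accumulation add: below, %all-reduce feeds both %accumulation and
+// %add.0, so the pass must leave the loop unchanged.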
+TEST_F(WhileLoopAllReduceCodeMotionTest, RepeatedlyAccumulatedAllReduce) {
+  constexpr absl::string_view kHloModule = R"(
+    HloModule accumulated_all_reduce
+
+    %reduction {
+      %x = f32[] parameter(0)
+      %y = f32[] parameter(1)
+      ROOT %add = f32[] add(f32[] %x, f32[] %y)
+    }
+
+    %while_condition {
+      %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
+      %gte.0 = s32[] get-tuple-element(%param), index=0
+      %gte.1 = s32[] get-tuple-element(%param), index=1
+      ROOT result = pred[] compare(%gte.0, %gte.1), direction=LT
+    }
+
+    %while_body {
+      %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
+      %gte.0 = s32[] get-tuple-element(%param), index=0
+      %gte.1 = s32[] get-tuple-element(%param), index=1
+      %gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2
+      %gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3
+      %all-reduce = f32[1024, 1024] all-reduce(f32[1024, 1024] %gte.2), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%reduction
+      %accumulation = f32[1024, 1024] add(f32[1024, 1024] %all-reduce, f32[1024, 1024] %gte.3)
+      %add.0 = f32[1024, 1024] add(f32[1024, 1024] %all-reduce, f32[1024, 1024] %accumulation)
+      %constant = s32[] constant(1)
+      %increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant)
+      ROOT %loop_result = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %add.0)
+    }
+
+    ENTRY accumulated_all_reduce {
+      %param.0 = s32[] parameter(0)
+      %param.1 = f32[1024, 1024] parameter(1)
+      %constant.0 = s32[] constant(1)
+      %accumulation_buffer_init = f32[] constant(0)
+      %accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
+      %while_init = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, s32[] %param.0, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer)
+      ROOT %while = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(kHloModule));
+  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
+                          WhileLoopAllReduceCodeMotion{}.Run(module.get()));
+  EXPECT_FALSE(simplified_loop);
+}
+
+TEST_F(WhileLoopAllReduceCodeMotionTest, TypeCastAllReduceAccumulate) {
+  constexpr absl::string_view kHloModule = R"(
+    HloModule accumulated_all_reduce
+
+    %reduction {
+      %x = bf16[] parameter(0)
+      %y = bf16[] parameter(1)
+      ROOT %add = bf16[] add(bf16[] %x, bf16[] %y)
+    }
+
+    %while_condition {
+      %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
+      %gte.0 = s32[] get-tuple-element(%param), index=0
+      %gte.1 = s32[] get-tuple-element(%param), index=1
+      ROOT result = pred[] compare(%gte.0, %gte.1), direction=LT
+    }
+
+    %while_body {
+      %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
+      %gte.0 = s32[] get-tuple-element(%param), index=0
+      %gte.1 = s32[] get-tuple-element(%param), index=1
+      %gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2
+      %gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3
+      %convert.0 = bf16[1024, 1024] convert(f32[1024, 1024] %gte.2)
+      %all-reduce = bf16[1024, 1024] all-reduce(bf16[1024, 1024] %convert.0), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%reduction
+      %convert.1 = f32[1024, 1024] convert(bf16[1024, 1024] %all-reduce)
+      %accumulation = f32[1024, 1024] add(f32[1024, 1024] %convert.1, f32[1024, 1024] %gte.3)
+      %constant = s32[] constant(1)
+      %increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant)
+      ROOT %loop_result = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation)
+    }
+
+    ENTRY accumulated_all_reduce {
+      %param.0 = s32[] parameter(0)
+      %param.1 = f32[1024, 1024] parameter(1)
+      %constant.0 = s32[] constant(1)
+      %accumulation_buffer_init = f32[] constant(0)
+      %accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
+      %while_init = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, s32[] %param.0, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer)
+      ROOT %while = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(kHloModule));
+  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
+                          WhileLoopAllReduceCodeMotion{}.Run(module.get()));
+  ASSERT_TRUE(simplified_loop);
+  TF_ASSERT_OK(
+      HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/true)
+          .Run(module.get())
+          .status());
+  HloInstruction* transformed_while =
+      *(std::find_if(module->entry_computation()->instructions().begin(),
+                     module->entry_computation()->instructions().end(),
+                     [](HloInstruction* instruction) -> bool {
+                       return Value(instruction, op::While());
+                     }));
+
+  ASSERT_THAT(transformed_while, NotNull());
+  EXPECT_THAT(transformed_while->while_body()->instructions(),
+              Each(Not(op::AllReduce())));
+  HloInstruction* accumulation_buffer =
+      transformed_while->mutable_operand(0)->mutable_operand(3);
+  EXPECT_THAT(accumulation_buffer, op::Constant());
+  HloAllReduceInstruction* moved_all_reduce =
+      DynCast<HloAllReduceInstruction>(
+          *(std::find_if(module->entry_computation()->instructions().begin(),
+                         module->entry_computation()->instructions().end(),
+                         [](HloInstruction* instruction) {
+                           return Value(instruction, op::AllReduce());
+                         })));
+  EXPECT_TRUE(ShapeUtil::Equal(moved_all_reduce->shape(),
+                               ShapeUtil::MakeShape(BF16, {1024, 1024})));
+
+  HloInstruction* add_delta_to_old_buffer =
+      *(std::find_if(module->entry_computation()->instructions().begin(),
+                     module->entry_computation()->instructions().end(),
+                     [](HloInstruction* instruction) -> bool {
+                       return Value(instruction, op::Add());
+                     }));
+  ASSERT_THAT(add_delta_to_old_buffer, NotNull());
+  EXPECT_TRUE(ShapeUtil::Equal(add_delta_to_old_buffer->shape(),
+                               ShapeUtil::MakeShape(F32, {1024, 1024})));
+  EXPECT_TRUE(ShapeUtil::Equal(add_delta_to_old_buffer->operand(0)->shape(),
+                               ShapeUtil::MakeShape(F32, {1024, 1024})));
+  EXPECT_TRUE(ShapeUtil::Equal(add_delta_to_old_buffer->operand(1)->shape(),
+                               ShapeUtil::MakeShape(F32, {1024, 1024})));
+}
+
+TEST_F(WhileLoopAllReduceCodeMotionTest, MultipleLoopCalls) {
+  constexpr absl::string_view kHloModule = R"(
+    HloModule accumulated_all_reduce
+
+    %reduction {
+      %x = bf16[] parameter(0)
+      %y = bf16[] parameter(1)
+      ROOT %add = bf16[] add(bf16[] %x, bf16[] %y)
+    }
+
+    %while_condition {
+      %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
+      %gte.0 = s32[] get-tuple-element(%param), index=0
+      %gte.1 = s32[] get-tuple-element(%param), index=1
+      ROOT result = pred[] compare(%gte.0, %gte.1), direction=LT
+    }
+
+    %while_body {
+      %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
+      %gte.0 = s32[] get-tuple-element(%param), index=0
+      %gte.1 = s32[] get-tuple-element(%param), index=1
+      %gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2
+      %gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3
+      %convert.0 = bf16[1024, 1024] convert(f32[1024, 1024] %gte.2)
+      %all-reduce = bf16[1024, 1024] all-reduce(bf16[1024, 1024] %convert.0), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%reduction
+      %convert.1 = f32[1024, 1024] convert(bf16[1024, 1024] %all-reduce)
+      %accumulation = f32[1024, 1024] add(f32[1024, 1024] %convert.1, f32[1024, 1024] %gte.3)
+      %constant = s32[] constant(1)
+      %increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant)
+      ROOT %loop_result = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation)
+    }
+
+    ENTRY accumulated_all_reduce {
+      %param.0 = s32[] parameter(0)
+      %param.1 = f32[1024, 1024] parameter(1)
+      %constant.0 = s32[] constant(1)
+      %accumulation_buffer_init = f32[] constant(0)
+      %accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
+      %while_init.0 = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, s32[] %param.0, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer)
+      %while.0 = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) while(%while_init.0), condition=%while_condition, body=%while_body
+      %gte.3 = f32[1024, 1024] get-tuple-element(%while.0), index=3
+      %while_init.1 = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, s32[] %param.0, f32[1024, 1024] %param.1, f32[1024, 1024] %gte.3)
+      %while.1 = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) while(%while_init.1), condition=%while_condition, body=%while_body
+      ROOT %gte.4 = f32[1024, 1024] get-tuple-element((s32[], s32[], f32[1024, 1024], f32[1024, 1024]) %while.1), index=3
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(kHloModule));
+  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
+                          WhileLoopAllReduceCodeMotion{}.Run(module.get()));
+  ASSERT_TRUE(simplified_loop);
+  TF_ASSERT_OK(
+      HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/true)
+          .Run(module.get())
+          .status());
+  EXPECT_EQ(absl::c_count_if(module->entry_computation()->instructions(),
+                             Matches(op::While())),
+            2);
+  EXPECT_EQ(absl::c_count_if(module->entry_computation()->instructions(),
+                             Matches(op::AllReduce())),
+            2);
+  HloInstruction* transformed_while =
+      *(std::find_if(module->entry_computation()->instructions().begin(),
+                     module->entry_computation()->instructions().end(),
+                     [](HloInstruction* instruction) -> bool {
+                       return Value(instruction, op::While());
+                     }));
+
+  ASSERT_THAT(transformed_while, NotNull());
+  EXPECT_THAT(transformed_while->while_body()->instructions(),
+              Each(Not(op::AllReduce())));
+}
+
+TEST_F(WhileLoopAllReduceCodeMotionTest, MultipleAllReduceAccumulate) {
+  constexpr absl::string_view kHloModule = R"(
+    HloModule accumulated_all_reduce
+
+    %reduction.0 {
+      %x = f32[] parameter(0)
+      %y = f32[] parameter(1)
+      ROOT %add = f32[] add(f32[] %x, f32[] %y)
+    }
+
+    %reduction.1 {
+      %x = bf16[] parameter(0)
+      %y = bf16[] parameter(1)
+      ROOT %add = bf16[] add(bf16[] %x, bf16[] %y)
+    }
+
+    %while_condition {
+      %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024], bf16[1024, 1024], bf16[1024, 1024]) parameter(0)
+      %gte.0 = s32[] get-tuple-element(%param), index=0
+      %gte.1 = s32[] get-tuple-element(%param), index=1
+      ROOT result = pred[] compare(%gte.0, %gte.1), direction=LT
+    }
+
+    %while_body {
+      %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024], bf16[1024, 1024], bf16[1024, 1024]) parameter(0)
+      %gte.0 = s32[] get-tuple-element(%param), index=0
+      %gte.1 = s32[] get-tuple-element(%param), index=1
+      %gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2
+      %gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3
+      %gte.4 = bf16[1024, 1024] get-tuple-element(%param), index=4
+      %gte.5 = bf16[1024, 1024] get-tuple-element(%param), index=5
+      %all-reduce.0 = f32[1024, 1024] all-reduce(f32[1024, 1024] %gte.2), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%reduction.0
+      %accumulation.0 = f32[1024, 1024] add(f32[1024, 1024] %all-reduce.0, f32[1024, 1024] %gte.3)
+      %all-reduce.1 = bf16[1024, 1024] all-reduce(bf16[1024, 1024] %gte.4), channel_id=2, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%reduction.1
+      %accumulation.1 = bf16[1024, 1024] add(bf16[1024, 1024] %all-reduce.1, bf16[1024, 1024] %gte.5)
+      %constant = s32[] constant(1)
+      %increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant)
+      ROOT %loop_result = (s32[], s32[], f32[1024, 1024], f32[1024, 1024], bf16[1024, 1024], bf16[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation.0, %gte.4, %accumulation.1)
+    }
+
+    ENTRY accumulated_all_reduce {
+      %param.0 = s32[] parameter(0)
+      %param.1 = f32[1024, 1024] parameter(1)
+      %param.2 = bf16[1024, 1024] parameter(2)
+      %constant.0 = s32[] constant(1)
+      %accumulation_buffer.0 = f32[1024, 1024] constant({...})
+      %accumulation_buffer.1 = bf16[1024, 1024] constant({...})
+      %while_init = (s32[], s32[], f32[1024, 1024], f32[1024, 1024], bf16[1024, 1024], bf16[1024, 1024]) tuple(s32[] %constant.0, s32[] %param.0, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer.0, bf16[1024, 1024] %param.2, bf16[1024, 1024] %accumulation_buffer.1)
+      ROOT %while = (s32[], s32[], f32[1024, 1024], f32[1024, 1024], bf16[1024, 1024], bf16[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(kHloModule));
+  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
+                          WhileLoopAllReduceCodeMotion{}.Run(module.get()));
+  ASSERT_TRUE(simplified_loop);
+  TF_ASSERT_OK(
+      HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/true)
+          .Run(module.get())
+          .status());
+  HloInstruction* transformed_while =
+      *(std::find_if(module->entry_computation()->instructions().begin(),
+                     module->entry_computation()->instructions().end(),
+                     [](HloInstruction* instruction) -> bool {
+                       return Value(instruction, op::While());
+                     }));
+
+  ASSERT_THAT(transformed_while, NotNull());
+  // Both all-reduces should have been sunk.
+  EXPECT_THAT(transformed_while->while_body()->instructions(),
+              Each(Not(op::AllReduce())));
+  HloInstruction* accumulation_buffer =
+      transformed_while->mutable_operand(0)->mutable_operand(3);
+  EXPECT_THAT(accumulation_buffer, op::Constant());
+  EXPECT_EQ(absl::c_count_if(module->entry_computation()->instructions(),
+                             Matches(op::AllReduce())),
+            2);
+}
+
+TEST_F(WhileLoopAllReduceCodeMotionTest, MixMovableAllReduceWithNotMovable) {
+  constexpr absl::string_view kHloModule = R"(
+    HloModule accumulated_all_reduce
+
+    %reduction.0 {
+      %x = f32[] parameter(0)
+      %y = f32[] parameter(1)
+      ROOT %add = f32[] add(f32[] %x, f32[] %y)
+    }
+
+    %reduction.1 {
+      %x = bf16[] parameter(0)
+      %y = bf16[] parameter(1)
+      ROOT %add = bf16[] add(bf16[] %x, bf16[] %y)
+    }
+
+    %while_condition {
+      %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024], bf16[1024, 1024], bf16[1024, 1024]) parameter(0)
+      %gte.0 = s32[] get-tuple-element(%param), index=0
+      %gte.1 = s32[] get-tuple-element(%param), index=1
+      ROOT result = pred[] compare(%gte.0, %gte.1), direction=LT
+    }
+
+    %while_body {
+      %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024], bf16[1024, 1024], bf16[1024, 1024]) parameter(0)
+      %gte.0 = s32[] get-tuple-element(%param), index=0
+      %gte.1 = s32[] get-tuple-element(%param), index=1
+      %gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2
+      %gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3
+      %gte.4 = bf16[1024, 1024] get-tuple-element(%param), index=4
+      %gte.5 = bf16[1024, 1024] get-tuple-element(%param), index=5
+      %all-reduce.0 = f32[1024, 1024] all-reduce(f32[1024, 1024] %gte.2), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%reduction.0
+      %accumulation.0 = f32[1024, 1024] add(f32[1024, 1024] %all-reduce.0, f32[1024, 1024] %gte.3)
+      %all-reduce.1 = bf16[1024, 1024] all-reduce(bf16[1024, 1024] %gte.4), channel_id=2, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%reduction.1
+      %accumulation.1 = bf16[1024, 1024] add(bf16[1024, 1024] %all-reduce.1, bf16[1024, 1024] %gte.5)
+      %add.0 = bf16[1024, 1024] add(bf16[1024, 1024] %accumulation.1, bf16[1024, 1024] %gte.4)
+      %constant = s32[] constant(1)
+      %increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant)
+      ROOT %loop_result = (s32[], s32[], f32[1024, 1024], f32[1024, 1024], bf16[1024, 1024], bf16[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation.0, %gte.4, %add.0)
+    }
+
+    ENTRY accumulated_all_reduce {
+      %param.0 = s32[] parameter(0)
+      %param.1 = f32[1024, 1024] parameter(1)
+      %param.2 = bf16[1024, 1024] parameter(2)
+      %constant.0 = s32[] constant(1)
+      %accumulation_buffer.0 = f32[1024, 1024] constant({...})
+      %accumulation_buffer.1 = bf16[1024, 1024] constant({...})
+      %while_init = (s32[], s32[], f32[1024, 1024], f32[1024, 1024], bf16[1024, 1024], bf16[1024, 1024]) tuple(s32[] %constant.0, s32[] %param.0, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer.0, bf16[1024, 1024] %param.2, bf16[1024, 1024] %accumulation_buffer.1)
+      ROOT %while = (s32[], s32[], f32[1024, 1024], f32[1024, 1024], bf16[1024, 1024], bf16[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseAndReturnVerifiedModule(kHloModule));
+  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
+                          WhileLoopAllReduceCodeMotion{}.Run(module.get()));
+  ASSERT_TRUE(simplified_loop);
+  TF_ASSERT_OK(
+      HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/true)
+          .Run(module.get())
+          .status());
+  HloInstruction* transformed_while =
+      *(std::find_if(module->entry_computation()->instructions().begin(),
+                     module->entry_computation()->instructions().end(),
+                     [](HloInstruction* instruction) -> bool {
+                       return Value(instruction, op::While());
+                     }));
+
+  ASSERT_THAT(transformed_while, NotNull());
+  // One all-reduce is movable and the other is not.
+  EXPECT_EQ(absl::c_count_if(transformed_while->while_body()->instructions(),
+                             Matches(op::AllReduce())),
+            1);
+  HloInstruction* accumulation_buffer =
+      transformed_while->mutable_operand(0)->mutable_operand(3);
+  EXPECT_THAT(accumulation_buffer, op::Constant());
+  EXPECT_EQ(absl::c_count_if(module->entry_computation()->instructions(),
+                             Matches(op::AllReduce())),
+            1);
+}
+
+}  // namespace
+}  // namespace xla