Rolling forward after fixing C++ version problems.
Original public description: Add a WhileLoopAllReduceCodeMotion pass, which finds the all-reduce-then-accumuate pattern inside while loops and moves the all-reduce to the outside of the while loop. PiperOrigin-RevId: 349594327 Change-Id: I6f2b6c67d9b89d257585085f97a7c022298d5d1e
This commit is contained in:
		
							parent
							
								
									64c87d65ff
								
							
						
					
					
						commit
						f8975c3e43
					
				@ -4665,6 +4665,48 @@ tf_cc_test(
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
    name = "while_loop_all_reduce_code_motion",
 | 
			
		||||
    srcs = ["while_loop_all_reduce_code_motion.cc"],
 | 
			
		||||
    hdrs = ["while_loop_all_reduce_code_motion.h"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":call_graph",
 | 
			
		||||
        ":hlo",
 | 
			
		||||
        ":hlo_casting_utils",
 | 
			
		||||
        ":hlo_pass",
 | 
			
		||||
        ":hlo_query",
 | 
			
		||||
        ":pattern_matcher",
 | 
			
		||||
        "//tensorflow/compiler/xla:literal_util",
 | 
			
		||||
        "//tensorflow/compiler/xla:status",
 | 
			
		||||
        "//tensorflow/compiler/xla:status_macros",
 | 
			
		||||
        "//tensorflow/compiler/xla:statusor",
 | 
			
		||||
        "//tensorflow/compiler/xla:util",
 | 
			
		||||
        "//tensorflow/compiler/xla:xla_data_proto_cc",
 | 
			
		||||
        "//tensorflow/core:lib_proto_parsing",
 | 
			
		||||
        "//tensorflow/core/platform:status",
 | 
			
		||||
        "@com_google_absl//absl/algorithm:container",
 | 
			
		||||
        "@com_google_absl//absl/types:span",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
tf_cc_test(
 | 
			
		||||
    name = "while_loop_all_reduce_code_motion_test",
 | 
			
		||||
    srcs = ["while_loop_all_reduce_code_motion_test.cc"],
 | 
			
		||||
    deps = [
 | 
			
		||||
        ":hlo",
 | 
			
		||||
        ":hlo_casting_utils",
 | 
			
		||||
        ":hlo_matchers",
 | 
			
		||||
        ":hlo_verifier",
 | 
			
		||||
        ":while_loop_all_reduce_code_motion",
 | 
			
		||||
        "//tensorflow/compiler/xla:xla_data_proto_cc",
 | 
			
		||||
        "//tensorflow/compiler/xla/tests:hlo_test_base",
 | 
			
		||||
        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
 | 
			
		||||
        "//tensorflow/core:test",
 | 
			
		||||
        "@com_google_absl//absl/algorithm:container",
 | 
			
		||||
        "@com_google_absl//absl/strings",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
cc_library(
 | 
			
		||||
    name = "while_loop_invariant_code_motion",
 | 
			
		||||
    srcs = ["while_loop_invariant_code_motion.cc"],
 | 
			
		||||
 | 
			
		||||
@ -0,0 +1,600 @@
 | 
			
		||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
#include "tensorflow/compiler/xla/service/while_loop_all_reduce_code_motion.h"
 | 
			
		||||
 | 
			
		||||
#include <iterator>
 | 
			
		||||
#include <stack>
 | 
			
		||||
#include <tuple>
 | 
			
		||||
 | 
			
		||||
#include "absl/algorithm/container.h"
 | 
			
		||||
#include "absl/types/span.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/literal_util.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/map_util.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/call_graph.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/hlo_query.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/status.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/status_macros.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/statusor.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/util.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
 | 
			
		||||
#include "tensorflow/core/lib/core/status.h"
 | 
			
		||||
#include "tensorflow/core/platform/status.h"
 | 
			
		||||
 | 
			
		||||
namespace xla {
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
struct AccumulationContext {
 | 
			
		||||
  HloInstruction* accumulation_instruction;
 | 
			
		||||
  HloInstruction* accumulation_buffer;
 | 
			
		||||
  std::vector<int> param_tuple_indices;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// Describes whether an all-reduce instruction can be sinked from a while body
 | 
			
		||||
// computation and all the accumulation uses of the all-reduce's result in the
 | 
			
		||||
// while body if movable.
 | 
			
		||||
struct MovableAllReduceContext {
 | 
			
		||||
  bool is_movable;
 | 
			
		||||
  // If movable, `accumulation_contexts` contains one accumulation
 | 
			
		||||
  // context for each accumulation in the while body that uses the all-reduce's
 | 
			
		||||
  // result. Otherwise, this field is undefined.
 | 
			
		||||
  std::vector<AccumulationContext> accumulation_contexts;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// Checks if an all-reduce instruction is eligible for sinking and finds all of
 | 
			
		||||
// the all-reduce's accumulation uses inside the while body if eligible.
 | 
			
		||||
MovableAllReduceContext IsAllReduceMovable(HloInstruction* all_reduce,
 | 
			
		||||
                                           HloComputation* while_body) {
 | 
			
		||||
  auto all_reduce_is_summation = [](HloInstruction* all_reduce) -> bool {
 | 
			
		||||
    HloInstruction* to_apply_root = all_reduce->to_apply()->root_instruction();
 | 
			
		||||
    if (all_reduce->to_apply()->num_parameters() != 2) {
 | 
			
		||||
      return false;
 | 
			
		||||
    }
 | 
			
		||||
    return Match(to_apply_root,
 | 
			
		||||
                 match::AddAnyOrder(match::Parameter(0), match::Parameter(1)));
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  // We only support numerical types.
 | 
			
		||||
  const absl::InlinedVector<PrimitiveType, 12> kSupportedTypes{
 | 
			
		||||
      BF16, F16, F32, F64, S8, S16, S32, S64, U8, U16, U32, U64};
 | 
			
		||||
 | 
			
		||||
  if (!absl::c_linear_search(kSupportedTypes,
 | 
			
		||||
                             all_reduce->shape().element_type()) ||
 | 
			
		||||
      !all_reduce_is_summation(all_reduce)) {
 | 
			
		||||
    return MovableAllReduceContext{/*is_movable=*/false,
 | 
			
		||||
                                   /*accumulation_contexts=*/{}};
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  struct BufferTupleIndex {
 | 
			
		||||
    bool unsupported_operation{false};
 | 
			
		||||
    std::vector<int> tuple_index;
 | 
			
		||||
    bool returned_from_computation{false};
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  // If the instruction is a buffer forwarded from a tuple element of the
 | 
			
		||||
  // computation's parameter, returns the indices of the buffer in the parameter
 | 
			
		||||
  // tuple. The returned_from_computation field in the result is unused.
 | 
			
		||||
  auto get_origin_tuple_index =
 | 
			
		||||
      [](HloInstruction* instruction) -> BufferTupleIndex {
 | 
			
		||||
    // The returned_from_computation is never touched in this function.
 | 
			
		||||
    BufferTupleIndex result;
 | 
			
		||||
    while (!result.unsupported_operation) {
 | 
			
		||||
      switch (instruction->opcode()) {
 | 
			
		||||
        default:
 | 
			
		||||
          result.unsupported_operation = true;
 | 
			
		||||
          break;
 | 
			
		||||
        case HloOpcode::kBitcast:
 | 
			
		||||
        case HloOpcode::kConvert:
 | 
			
		||||
        case HloOpcode::kReshape:
 | 
			
		||||
        case HloOpcode::kTranspose:
 | 
			
		||||
        case HloOpcode::kDynamicReshape:
 | 
			
		||||
          instruction = instruction->mutable_operand(0);
 | 
			
		||||
          break;
 | 
			
		||||
        case HloOpcode::kGetTupleElement: {
 | 
			
		||||
          if (!result.tuple_index.empty()) {
 | 
			
		||||
            // Note that we don't support nested tuples as of now.
 | 
			
		||||
            result.unsupported_operation = true;
 | 
			
		||||
          } else {
 | 
			
		||||
            result.tuple_index.push_back(
 | 
			
		||||
                Cast<HloGetTupleElementInstruction>(instruction)
 | 
			
		||||
                    ->tuple_index());
 | 
			
		||||
            instruction = instruction->mutable_operand(0);
 | 
			
		||||
          }
 | 
			
		||||
          break;
 | 
			
		||||
        }
 | 
			
		||||
        case HloOpcode::kParameter: {
 | 
			
		||||
          int parameter_number =
 | 
			
		||||
              Cast<HloParameterInstruction>(instruction)->parameter_number();
 | 
			
		||||
          CHECK_EQ(parameter_number, 0);
 | 
			
		||||
          break;
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
      if (instruction->opcode() == HloOpcode::kParameter) {
 | 
			
		||||
        break;
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
    return result;
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  // If the instruction's result is returned from its parent computation with
 | 
			
		||||
  // only forwarding operations, returns the index of the result buffer in the
 | 
			
		||||
  // output parameter tuple.
 | 
			
		||||
  auto get_output_tuple_index =
 | 
			
		||||
      [](HloInstruction* instruction,
 | 
			
		||||
         HloComputation* while_body) -> BufferTupleIndex {
 | 
			
		||||
    BufferTupleIndex result;
 | 
			
		||||
    std::stack<HloInstruction*> to_visit;
 | 
			
		||||
    to_visit.push(instruction);
 | 
			
		||||
    while (!to_visit.empty() && !result.unsupported_operation) {
 | 
			
		||||
      HloInstruction* instruction = to_visit.top();
 | 
			
		||||
      to_visit.pop();
 | 
			
		||||
      for (HloInstruction* user : instruction->users()) {
 | 
			
		||||
        switch (user->opcode()) {
 | 
			
		||||
          case HloOpcode::kBitcast:
 | 
			
		||||
          case HloOpcode::kConvert:
 | 
			
		||||
          case HloOpcode::kReshape:
 | 
			
		||||
          case HloOpcode::kGetTupleElement:
 | 
			
		||||
          case HloOpcode::kTranspose:
 | 
			
		||||
          case HloOpcode::kSlice: {
 | 
			
		||||
            to_visit.push(user);
 | 
			
		||||
            break;
 | 
			
		||||
          }
 | 
			
		||||
          case HloOpcode::kDynamicSlice: {
 | 
			
		||||
            if (user->operand_index(instruction) == 0) {
 | 
			
		||||
              to_visit.push(user);
 | 
			
		||||
            } else {
 | 
			
		||||
              result.unsupported_operation = true;
 | 
			
		||||
            }
 | 
			
		||||
            break;
 | 
			
		||||
          }
 | 
			
		||||
          case HloOpcode::kTuple: {
 | 
			
		||||
            if (!result.tuple_index.empty()) {
 | 
			
		||||
              // Note that we don't support nested tuples as of now.
 | 
			
		||||
              result.unsupported_operation = true;
 | 
			
		||||
            } else {
 | 
			
		||||
              result.tuple_index.push_back(user->operand_index(instruction));
 | 
			
		||||
              if (while_body->root_instruction() == user) {
 | 
			
		||||
                if (result.returned_from_computation) {
 | 
			
		||||
                  result.unsupported_operation = true;
 | 
			
		||||
                }
 | 
			
		||||
                result.returned_from_computation = true;
 | 
			
		||||
              } else {
 | 
			
		||||
                to_visit.push(user);
 | 
			
		||||
              }
 | 
			
		||||
            }
 | 
			
		||||
            break;
 | 
			
		||||
          }
 | 
			
		||||
          default:
 | 
			
		||||
            result.unsupported_operation = true;
 | 
			
		||||
        }
 | 
			
		||||
        if (result.unsupported_operation) {
 | 
			
		||||
          break;
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
    return result;
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  // Checks whether any buffer in the list of accumulation contexts is used in
 | 
			
		||||
  // the parent computation except for forwarding uses.
 | 
			
		||||
  auto is_buffer_used =
 | 
			
		||||
      [](absl::Span<const AccumulationContext> accumulation_contexts,
 | 
			
		||||
         HloComputation* while_body_computation) -> bool {
 | 
			
		||||
    std::vector<HloInstruction*> parameter_instructions;
 | 
			
		||||
    absl::c_copy_if(while_body_computation->instructions(),
 | 
			
		||||
                    std::back_inserter(parameter_instructions),
 | 
			
		||||
                    [](HloInstruction* instruction) -> bool {
 | 
			
		||||
                      return instruction->opcode() == HloOpcode::kParameter;
 | 
			
		||||
                    });
 | 
			
		||||
    for (const auto& accumulation : accumulation_contexts) {
 | 
			
		||||
      HloInstruction* accumulation_instruction =
 | 
			
		||||
          accumulation.accumulation_instruction;
 | 
			
		||||
      int tuple_index = accumulation.param_tuple_indices[0];
 | 
			
		||||
      std::stack<HloInstruction*> to_visit;
 | 
			
		||||
      // TODO(b/176437845): simplify the logic below by using
 | 
			
		||||
      // TuplePointsToAnalysis.
 | 
			
		||||
      for (HloInstruction* parameter_instruction : parameter_instructions) {
 | 
			
		||||
        // Iterate over all users of the while body parameter and find all
 | 
			
		||||
        // instructions that use the accumulation buffer, as specified by
 | 
			
		||||
        // tuple_index.
 | 
			
		||||
        // This logic could be simplied by using TuplePointsToAnalysis, which
 | 
			
		||||
        // we leave to a future CL (see TODO above).
 | 
			
		||||
        for (HloInstruction* user : parameter_instruction->users()) {
 | 
			
		||||
          if (auto* gte = DynCast<HloGetTupleElementInstruction>(user)) {
 | 
			
		||||
            if (gte->tuple_index() == tuple_index) {
 | 
			
		||||
              to_visit.push(user);
 | 
			
		||||
            }
 | 
			
		||||
          } else {
 | 
			
		||||
            return true;
 | 
			
		||||
          }
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      while (!to_visit.empty()) {
 | 
			
		||||
        HloInstruction* instruction = to_visit.top();
 | 
			
		||||
        to_visit.pop();
 | 
			
		||||
        for (HloInstruction* user : instruction->users()) {
 | 
			
		||||
          switch (user->opcode()) {
 | 
			
		||||
            case HloOpcode::kBitcast:
 | 
			
		||||
            case HloOpcode::kConvert:
 | 
			
		||||
            case HloOpcode::kReshape:
 | 
			
		||||
            case HloOpcode::kTranspose:
 | 
			
		||||
              to_visit.push(user);
 | 
			
		||||
              break;
 | 
			
		||||
            case HloOpcode::kDynamicReshape: {
 | 
			
		||||
              if (instruction == user->operand(0)) {
 | 
			
		||||
                to_visit.push(user);
 | 
			
		||||
              } else {
 | 
			
		||||
                return true;
 | 
			
		||||
              }
 | 
			
		||||
              break;
 | 
			
		||||
            }
 | 
			
		||||
            case HloOpcode::kAdd: {
 | 
			
		||||
              if (user != accumulation_instruction) {
 | 
			
		||||
                return true;
 | 
			
		||||
              }
 | 
			
		||||
              break;
 | 
			
		||||
            }
 | 
			
		||||
            default:
 | 
			
		||||
              return true;
 | 
			
		||||
          }
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
    return false;
 | 
			
		||||
  };
 | 
			
		||||
 | 
			
		||||
  // Finds all accumulation contexts of the given all-reduce instruction if it
 | 
			
		||||
  // is movable.
 | 
			
		||||
  auto get_accumulation_contexts =
 | 
			
		||||
      [&get_origin_tuple_index, &get_output_tuple_index, &is_buffer_used](
 | 
			
		||||
          HloInstruction* all_reduce,
 | 
			
		||||
          HloComputation* while_body) -> MovableAllReduceContext {
 | 
			
		||||
    std::vector<AccumulationContext> accumulation_contexts;
 | 
			
		||||
    // DFS starting from the all-reduce instruction and stops at the first
 | 
			
		||||
    // non-triival uses of the all-reduce result or finds all accmululations
 | 
			
		||||
    // of the all-reduce result.
 | 
			
		||||
    std::stack<HloInstruction*> to_visit;
 | 
			
		||||
    // By default movable unless we find that it's not.
 | 
			
		||||
    bool is_all_reduce_movable = true;
 | 
			
		||||
    to_visit.push(all_reduce);
 | 
			
		||||
 | 
			
		||||
    while (!to_visit.empty() && is_all_reduce_movable) {
 | 
			
		||||
      HloInstruction* instruction = to_visit.top();
 | 
			
		||||
      to_visit.pop();
 | 
			
		||||
      for (HloInstruction* user : instruction->users()) {
 | 
			
		||||
        switch (user->opcode()) {
 | 
			
		||||
          case HloOpcode::kBitcast:
 | 
			
		||||
          case HloOpcode::kConvert:
 | 
			
		||||
          case HloOpcode::kReshape:
 | 
			
		||||
          case HloOpcode::kGetTupleElement:
 | 
			
		||||
          case HloOpcode::kTranspose:
 | 
			
		||||
          case HloOpcode::kSlice: {
 | 
			
		||||
            to_visit.push(user);
 | 
			
		||||
            break;
 | 
			
		||||
          }
 | 
			
		||||
          case HloOpcode::kDynamicSlice: {
 | 
			
		||||
            if (user->operand_index(instruction) == 0) {
 | 
			
		||||
              to_visit.push(user);
 | 
			
		||||
            } else {
 | 
			
		||||
              is_all_reduce_movable = false;
 | 
			
		||||
              break;
 | 
			
		||||
            }
 | 
			
		||||
            break;
 | 
			
		||||
          }
 | 
			
		||||
          case HloOpcode::kAdd: {
 | 
			
		||||
            int64 buffer_index = 1 - user->operand_index(instruction);
 | 
			
		||||
            HloInstruction* accumulation_buffer =
 | 
			
		||||
                user->mutable_operand(buffer_index);
 | 
			
		||||
 | 
			
		||||
            auto origin_buffer_tuple_index =
 | 
			
		||||
                get_origin_tuple_index(accumulation_buffer);
 | 
			
		||||
            if (origin_buffer_tuple_index.unsupported_operation) {
 | 
			
		||||
              is_all_reduce_movable = false;
 | 
			
		||||
              break;
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            auto output_buffer_tuple_index =
 | 
			
		||||
                get_output_tuple_index(user, while_body);
 | 
			
		||||
            if (!output_buffer_tuple_index.unsupported_operation &&
 | 
			
		||||
                output_buffer_tuple_index.returned_from_computation &&
 | 
			
		||||
                !origin_buffer_tuple_index.tuple_index.empty() &&
 | 
			
		||||
                ContainersEqual(origin_buffer_tuple_index.tuple_index,
 | 
			
		||||
                                output_buffer_tuple_index.tuple_index)) {
 | 
			
		||||
              accumulation_contexts.push_back(AccumulationContext{
 | 
			
		||||
                  user, accumulation_buffer,
 | 
			
		||||
                  std::move(output_buffer_tuple_index.tuple_index)});
 | 
			
		||||
            } else {
 | 
			
		||||
              is_all_reduce_movable = false;
 | 
			
		||||
            }
 | 
			
		||||
            break;
 | 
			
		||||
          }
 | 
			
		||||
          default:
 | 
			
		||||
            is_all_reduce_movable = false;
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
    if (is_buffer_used(accumulation_contexts, while_body)) {
 | 
			
		||||
      is_all_reduce_movable = false;
 | 
			
		||||
    }
 | 
			
		||||
    return MovableAllReduceContext{is_all_reduce_movable,
 | 
			
		||||
                                   accumulation_contexts};
 | 
			
		||||
  };
 | 
			
		||||
  return get_accumulation_contexts(all_reduce, while_body);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
struct WhileInitContext {
 | 
			
		||||
  HloInstruction* while_init{nullptr};
 | 
			
		||||
  absl::flat_hash_map<int, HloInstruction*> tuple_index_to_old_buffer;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
// Creates a new while init instruction, which replaces each accumulation buffer
 | 
			
		||||
// in the given accumulation contexts with a zero-initialized buffer. In other
 | 
			
		||||
// words, we are accumulating all the deltas in the while loop with a zero
 | 
			
		||||
// initial value.
 | 
			
		||||
WhileInitContext CreateNewWhileInit(
 | 
			
		||||
    HloInstruction* old_while_instruction,
 | 
			
		||||
    const HloInstructionMap<std::vector<AccumulationContext>>&
 | 
			
		||||
        all_reduce_to_accumulations) {
 | 
			
		||||
  HloInstruction* old_while_init = old_while_instruction->mutable_operand(0);
 | 
			
		||||
  HloComputation* while_parent = old_while_instruction->parent();
 | 
			
		||||
  std::vector<HloInstruction*> new_while_init_elements(
 | 
			
		||||
      old_while_init->operand_count(), nullptr);
 | 
			
		||||
  for (const auto& all_reduce_and_accumulations_pair :
 | 
			
		||||
       all_reduce_to_accumulations) {
 | 
			
		||||
    const std::vector<AccumulationContext>& accumulations =
 | 
			
		||||
        all_reduce_and_accumulations_pair.second;
 | 
			
		||||
    for (auto& accumulation_context : accumulations) {
 | 
			
		||||
      CHECK_EQ(accumulation_context.param_tuple_indices.size(), 1);
 | 
			
		||||
      int tuple_index = accumulation_context.param_tuple_indices[0];
 | 
			
		||||
      HloInstruction* old_buffer = old_while_init->mutable_operand(tuple_index);
 | 
			
		||||
      HloInstruction* new_buffer = while_parent->AddInstruction(
 | 
			
		||||
          HloInstruction::CreateConstant(LiteralUtil::CreateFromDimensions(
 | 
			
		||||
              old_buffer->shape().element_type(),
 | 
			
		||||
              old_buffer->shape().dimensions())));
 | 
			
		||||
      new_while_init_elements[tuple_index] = new_buffer;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  absl::flat_hash_map<int, HloInstruction*> tuple_index_to_old_buffer;
 | 
			
		||||
  for (int i = 0; i < old_while_init->operand_count(); i++) {
 | 
			
		||||
    if (!new_while_init_elements[i]) {
 | 
			
		||||
      new_while_init_elements[i] = old_while_init->mutable_operand(i);
 | 
			
		||||
    } else {
 | 
			
		||||
      tuple_index_to_old_buffer[i] = old_while_init->mutable_operand(i);
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  HloInstruction* new_while_init = while_parent->AddInstruction(
 | 
			
		||||
      HloInstruction::CreateTuple(new_while_init_elements));
 | 
			
		||||
  return WhileInitContext{new_while_init, tuple_index_to_old_buffer};
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Creates all the sinked all-reduce instructions in the while instruction's
 | 
			
		||||
// parent computation. Returns a map that maps a tuple index of an accumulation
 | 
			
		||||
// buffer to it's corresponding all-reduce.
 | 
			
		||||
absl::flat_hash_map<int, HloInstruction*> CreateSinkedAllReduces(
 | 
			
		||||
    HloInstruction* new_while_instruction,
 | 
			
		||||
    const HloInstructionMap<std::vector<AccumulationContext>>&
 | 
			
		||||
        all_reduce_to_accumulations,
 | 
			
		||||
    const absl::flat_hash_map<int, HloInstruction*>&
 | 
			
		||||
        tuple_index_to_old_buffer) {
 | 
			
		||||
  HloComputation* while_parent = new_while_instruction->parent();
 | 
			
		||||
  absl::flat_hash_map<int, HloInstruction*> tuple_index_to_new_buffer;
 | 
			
		||||
  for (const auto& all_reduce_and_accumulations_pair :
 | 
			
		||||
       all_reduce_to_accumulations) {
 | 
			
		||||
    HloInstruction* loop_all_reduce = all_reduce_and_accumulations_pair.first;
 | 
			
		||||
    const std::vector<AccumulationContext>& accumulations =
 | 
			
		||||
        all_reduce_and_accumulations_pair.second;
 | 
			
		||||
    for (const auto& accumulation_context : accumulations) {
 | 
			
		||||
      CHECK_EQ(accumulation_context.param_tuple_indices.size(), 1);
 | 
			
		||||
      int tuple_index = accumulation_context.param_tuple_indices[0];
 | 
			
		||||
      const Shape& accumulation_buffer_shape =
 | 
			
		||||
          new_while_instruction->shape().tuple_shapes(tuple_index);
 | 
			
		||||
      HloInstruction* accumulation_buffer =
 | 
			
		||||
          while_parent->AddInstruction(HloInstruction::CreateGetTupleElement(
 | 
			
		||||
              accumulation_buffer_shape, new_while_instruction, tuple_index));
 | 
			
		||||
      HloAllReduceInstruction* old_all_reduce =
 | 
			
		||||
          Cast<HloAllReduceInstruction>(loop_all_reduce);
 | 
			
		||||
      HloInstruction* all_reduce_operand = accumulation_buffer;
 | 
			
		||||
      if (!ShapeUtil::SameElementType(old_all_reduce->shape(),
 | 
			
		||||
                                      accumulation_buffer_shape)) {
 | 
			
		||||
        Shape all_reduce_shape =
 | 
			
		||||
            ShapeUtil::MakeShape(old_all_reduce->shape().element_type(),
 | 
			
		||||
                                 accumulation_buffer_shape.dimensions());
 | 
			
		||||
        all_reduce_operand =
 | 
			
		||||
            while_parent->AddInstruction(HloInstruction::CreateConvert(
 | 
			
		||||
                all_reduce_shape, accumulation_buffer));
 | 
			
		||||
      }
 | 
			
		||||
      HloInstruction* new_all_reduce =
 | 
			
		||||
          while_parent->AddInstruction(HloInstruction::CreateAllReduce(
 | 
			
		||||
              all_reduce_operand->shape(), {all_reduce_operand},
 | 
			
		||||
              old_all_reduce->called_computations()[0],
 | 
			
		||||
              old_all_reduce->replica_groups(),
 | 
			
		||||
              old_all_reduce->constrain_layout(),
 | 
			
		||||
              hlo_query::NextChannelId(*(while_parent->parent())),
 | 
			
		||||
              old_all_reduce->use_global_device_ids()));
 | 
			
		||||
      HloInstruction* all_reduced_delta = new_all_reduce;
 | 
			
		||||
      if (!ShapeUtil::SameElementType(all_reduced_delta->shape(),
 | 
			
		||||
                                      accumulation_buffer_shape)) {
 | 
			
		||||
        all_reduced_delta =
 | 
			
		||||
            while_parent->AddInstruction(HloInstruction::CreateConvert(
 | 
			
		||||
                accumulation_buffer_shape, all_reduced_delta));
 | 
			
		||||
      }
 | 
			
		||||
      CHECK(ContainsKey(tuple_index_to_old_buffer, tuple_index));
 | 
			
		||||
      HloInstruction* old_buffer = tuple_index_to_old_buffer.at(tuple_index);
 | 
			
		||||
      CHECK(ShapeUtil::Equal(old_buffer->shape(), all_reduced_delta->shape()));
 | 
			
		||||
      HloInstruction* add_to_old_buffer =
 | 
			
		||||
          while_parent->AddInstruction(HloInstruction::CreateBinary(
 | 
			
		||||
              all_reduced_delta->shape(), HloOpcode::kAdd, old_buffer,
 | 
			
		||||
              all_reduced_delta));
 | 
			
		||||
      tuple_index_to_new_buffer[tuple_index] = add_to_old_buffer;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  return tuple_index_to_new_buffer;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Creates a tuple which is equivalent to the original while instruction's
 | 
			
		||||
// output.
 | 
			
		||||
HloInstruction* CreateNewWhileResult(
 | 
			
		||||
    HloInstruction* new_while_instruction,
 | 
			
		||||
    const absl::flat_hash_map<int, HloInstruction*>&
 | 
			
		||||
        tuple_index_to_new_buffer) {
 | 
			
		||||
  HloComputation* while_parent = new_while_instruction->parent();
 | 
			
		||||
  CHECK(new_while_instruction->shape().IsTuple());
 | 
			
		||||
  std::vector<HloInstruction*> new_while_result_elements(
 | 
			
		||||
      new_while_instruction->shape().tuple_shapes_size(), nullptr);
 | 
			
		||||
  for (int i = 0; i < new_while_result_elements.size(); i++) {
 | 
			
		||||
    if (ContainsKey(tuple_index_to_new_buffer, i)) {
 | 
			
		||||
      new_while_result_elements[i] = tuple_index_to_new_buffer.at(i);
 | 
			
		||||
    } else {
 | 
			
		||||
      HloInstruction* gte =
 | 
			
		||||
          while_parent->AddInstruction(HloInstruction::CreateGetTupleElement(
 | 
			
		||||
              new_while_instruction->shape().tuple_shapes(i),
 | 
			
		||||
              new_while_instruction, i));
 | 
			
		||||
      new_while_result_elements[i] = gte;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  HloInstruction* new_while_result = while_parent->AddInstruction(
 | 
			
		||||
      HloInstruction::CreateTuple(new_while_result_elements));
 | 
			
		||||
  return new_while_result;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Creates the sinked all-reduce instructions for all accumulation buffers. The
 | 
			
		||||
// all-reduce outputs are then added to the original accumulation buffers.
 | 
			
		||||
// Creates a tuple that groups the while loop output and the accumulated
 | 
			
		||||
// buffers and replaces all uses of the old while with this new tuple.
 | 
			
		||||
Status AddSinkedAllReducesAndReplaceWhile(
 | 
			
		||||
    HloInstruction* while_instruction,
 | 
			
		||||
    const HloInstructionMap<std::vector<AccumulationContext>>&
 | 
			
		||||
        all_reduce_to_accumulations) {
 | 
			
		||||
  // Note that we create all instructions before replacing and removing any old
 | 
			
		||||
  // instruction. This ensures that we do not accidentally access any deleted
 | 
			
		||||
  // instruction when creating new instructions.
 | 
			
		||||
 | 
			
		||||
  // Step 1) create the new while init instruction, which uses zero-initialized
 | 
			
		||||
  // tensors as the accumulation buffers for the all-reduce.
 | 
			
		||||
  auto new_while_init_context =
 | 
			
		||||
      CreateNewWhileInit(while_instruction, all_reduce_to_accumulations);
 | 
			
		||||
  // Step 2) create the new while instruction.
 | 
			
		||||
  HloInstruction* new_while_instruction =
 | 
			
		||||
      while_instruction->parent()->AddInstruction(HloInstruction::CreateWhile(
 | 
			
		||||
          new_while_init_context.while_init->shape(),
 | 
			
		||||
          while_instruction->while_condition(), while_instruction->while_body(),
 | 
			
		||||
          new_while_init_context.while_init));
 | 
			
		||||
  // Step 3) create the new all-reduce instructions after the while loop.
 | 
			
		||||
  absl::flat_hash_map<int, HloInstruction*> tuple_index_to_new_buffer =
 | 
			
		||||
      CreateSinkedAllReduces(new_while_instruction, all_reduce_to_accumulations,
 | 
			
		||||
                             new_while_init_context.tuple_index_to_old_buffer);
 | 
			
		||||
  // Step 4) create the tuple and replace the old while instruction for all of
 | 
			
		||||
  // its uses.
 | 
			
		||||
  HloInstruction* new_while_result =
 | 
			
		||||
      CreateNewWhileResult(new_while_instruction, tuple_index_to_new_buffer);
 | 
			
		||||
  TF_RETURN_IF_ERROR(while_instruction->parent()->ReplaceInstruction(
 | 
			
		||||
      while_instruction, new_while_result));
 | 
			
		||||
  return Status::OK();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace
 | 
			
		||||
 | 
			
		||||
StatusOr<bool> WhileLoopAllReduceCodeMotion::Run(HloModule* module) {
 | 
			
		||||
  bool is_changed = false;
 | 
			
		||||
  bool run_next_pass = true;
 | 
			
		||||
  // In case of MPMD, all-reduces might be cross-module and should preserve
 | 
			
		||||
  // their channel ID. Do not move all-reduces in this case since the channel
 | 
			
		||||
  // ID might be changed.
 | 
			
		||||
  if (module->config().num_partitions() > 1 &&
 | 
			
		||||
      !module->config().use_spmd_partitioning()) {
 | 
			
		||||
    return false;
 | 
			
		||||
  }
 | 
			
		||||
  // The while instruction's parent could be a while body for another while
 | 
			
		||||
  // loop. We recursively sink the all-reduce through nested while loops if
 | 
			
		||||
  // applicable by repeating this process.
 | 
			
		||||
  while (run_next_pass) {
 | 
			
		||||
    run_next_pass = false;
 | 
			
		||||
    std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
 | 
			
		||||
    // A computation could be the while body of multiple while instructions,
 | 
			
		||||
    // so we start from the computation and find all of its callers that is a
 | 
			
		||||
    // kWhile if there is any.
 | 
			
		||||
    for (HloComputation* computation : module->computations()) {
 | 
			
		||||
      std::vector<HloInstruction*> computation_callers =
 | 
			
		||||
          call_graph->GetComputationCallers(computation);
 | 
			
		||||
      std::vector<HloInstruction*> while_caller_instructions;
 | 
			
		||||
      for (HloInstruction* caller_instruction : computation_callers) {
 | 
			
		||||
        // For simplicity, we only support while instructions whose shape is
 | 
			
		||||
        // tuple.
 | 
			
		||||
        if (caller_instruction->opcode() == HloOpcode::kWhile &&
 | 
			
		||||
            caller_instruction->shape().IsTuple() &&
 | 
			
		||||
            caller_instruction->while_body() == computation) {
 | 
			
		||||
          while_caller_instructions.push_back(caller_instruction);
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
      // Skip to next computation if this computation is not the while body of
 | 
			
		||||
      // any while instruction.
 | 
			
		||||
      if (while_caller_instructions.empty()) {
 | 
			
		||||
        continue;
 | 
			
		||||
      }
 | 
			
		||||
      std::vector<HloInstruction*> while_body_all_reduces;
 | 
			
		||||
      for (HloInstruction* while_body_instruction :
 | 
			
		||||
           computation->MakeInstructionPostOrder()) {
 | 
			
		||||
        if (auto* all_reduce_instruction =
 | 
			
		||||
                DynCast<HloAllReduceInstruction>(while_body_instruction)) {
 | 
			
		||||
          if (all_reduce_instruction->constrain_layout()) {
 | 
			
		||||
            return false;
 | 
			
		||||
          } else {
 | 
			
		||||
            while_body_all_reduces.push_back(all_reduce_instruction);
 | 
			
		||||
          }
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
      HloInstructionMap<std::vector<AccumulationContext>>
 | 
			
		||||
          all_reduce_to_accumulations;
 | 
			
		||||
      for (HloInstruction* all_reduce : while_body_all_reduces) {
 | 
			
		||||
        auto movable_all_reduce_context =
 | 
			
		||||
            IsAllReduceMovable(all_reduce, computation);
 | 
			
		||||
        if (movable_all_reduce_context.is_movable) {
 | 
			
		||||
          all_reduce_to_accumulations[all_reduce] =
 | 
			
		||||
              std::move(movable_all_reduce_context.accumulation_contexts);
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
      if (all_reduce_to_accumulations.empty()) {
 | 
			
		||||
        continue;
 | 
			
		||||
      }
 | 
			
		||||
      // For each while instruction calling this computation, create the
 | 
			
		||||
      // corresponding all-reduces after the while loop.
 | 
			
		||||
      for (HloInstruction* while_instruction : while_caller_instructions) {
 | 
			
		||||
        TF_RETURN_IF_ERROR(AddSinkedAllReducesAndReplaceWhile(
 | 
			
		||||
            while_instruction, all_reduce_to_accumulations));
 | 
			
		||||
        is_changed = true;
 | 
			
		||||
        run_next_pass = true;
 | 
			
		||||
      }
 | 
			
		||||
      // At last, remove the old all-reduce instructions in the while body.
 | 
			
		||||
      for (const auto& all_reduce_accumulations_pair :
 | 
			
		||||
           all_reduce_to_accumulations) {
 | 
			
		||||
        HloInstruction* all_reduce = all_reduce_accumulations_pair.first;
 | 
			
		||||
        TF_RETURN_IF_ERROR(computation->ReplaceInstruction(
 | 
			
		||||
            all_reduce, all_reduce->mutable_operand(0)));
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  return is_changed;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace xla
 | 
			
		||||
@ -0,0 +1,59 @@
 | 
			
		||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ALL_REDUCE_CODE_MOTION_H_
 | 
			
		||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ALL_REDUCE_CODE_MOTION_H_
 | 
			
		||||
 | 
			
		||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/statusor.h"
 | 
			
		||||
 | 
			
		||||
namespace xla {
 | 
			
		||||
 | 
			
		||||
// HLO pass that rewrites while loops to sink all-reduces that are only
 | 
			
		||||
// accumulated into a buffer and not otherwise used in the loop body.
 | 
			
		||||
// An all-reduce instruction can be sinked if its result is only added
 | 
			
		||||
// to a number of accumulation buffers, and the accumulation buffers are not
 | 
			
		||||
// used inside the loop.
 | 
			
		||||
//
 | 
			
		||||
// Pattern before this pass:
 | 
			
		||||
// a = ...
 | 
			
		||||
// while:
 | 
			
		||||
//   b = ...
 | 
			
		||||
//   c = all-reduce(b)
 | 
			
		||||
//   a += c
 | 
			
		||||
// Pattern after this pass:
 | 
			
		||||
// a = ...
 | 
			
		||||
// d = 0
 | 
			
		||||
// while:
 | 
			
		||||
//   b = ...
 | 
			
		||||
//   d += b
 | 
			
		||||
// e = all-reduce(d)
 | 
			
		||||
// a += e
 | 
			
		||||
class WhileLoopAllReduceCodeMotion : public HloModulePass {
 | 
			
		||||
 public:
 | 
			
		||||
  explicit WhileLoopAllReduceCodeMotion() {}
 | 
			
		||||
  ~WhileLoopAllReduceCodeMotion() override = default;
 | 
			
		||||
 | 
			
		||||
  absl::string_view name() const override {
 | 
			
		||||
    static constexpr absl::string_view kName =
 | 
			
		||||
        "while-loop-all-reduce-code-motion";
 | 
			
		||||
    return kName;
 | 
			
		||||
  }
 | 
			
		||||
  StatusOr<bool> Run(HloModule* module) override;
 | 
			
		||||
};
 | 
			
		||||
}  // namespace xla
 | 
			
		||||
 | 
			
		||||
#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_ALL_REDUCE_CODE_MOTION_H_
 | 
			
		||||
@ -0,0 +1,658 @@
 | 
			
		||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
 | 
			
		||||
 | 
			
		||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
you may not use this file except in compliance with the License.
 | 
			
		||||
You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
See the License for the specific language governing permissions and
 | 
			
		||||
limitations under the License.
 | 
			
		||||
==============================================================================*/
 | 
			
		||||
 | 
			
		||||
#include "tensorflow/compiler/xla/service/while_loop_all_reduce_code_motion.h"
 | 
			
		||||
 | 
			
		||||
#include <algorithm>
 | 
			
		||||
#include <iterator>
 | 
			
		||||
 | 
			
		||||
#include "absl/algorithm/container.h"
 | 
			
		||||
#include "absl/strings/string_view.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/service/hlo_verifier.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
 | 
			
		||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
 | 
			
		||||
#include "tensorflow/core/lib/core/status_test_util.h"
 | 
			
		||||
 | 
			
		||||
namespace xla {
 | 
			
		||||
namespace {
 | 
			
		||||
 | 
			
		||||
namespace op = ::xla::testing::opcode_matchers;
 | 
			
		||||
using ::testing::Ne;
 | 
			
		||||
using ::testing::NotNull;
 | 
			
		||||
using ::testing::Property;
 | 
			
		||||
using ::testing::SizeIs;
 | 
			
		||||
 | 
			
		||||
class WhileLoopAllReduceCodeMotionTest : public HloTestBase {};
 | 
			
		||||
 | 
			
		||||
TEST_F(WhileLoopAllReduceCodeMotionTest, AllReduceAccumulate) {
 | 
			
		||||
  constexpr absl::string_view kHloModule = R"(
 | 
			
		||||
    HloModule accumulated_all_reduce
 | 
			
		||||
 | 
			
		||||
    %reduction {
 | 
			
		||||
      %x = f32[] parameter(0)
 | 
			
		||||
      %y = f32[] parameter(1)
 | 
			
		||||
      ROOT %add = f32[] add(f32[] %x, f32[] %y)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    %while_condition {
 | 
			
		||||
      %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
 | 
			
		||||
      %gte.0 = s32[] get-tuple-element(%param), index=0
 | 
			
		||||
      %gte.1 = s32[] get-tuple-element(%param), index=1
 | 
			
		||||
      ROOT result = pred[] compare(%gte.0, %gte.1), direction=LT
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    %while_body {
 | 
			
		||||
      %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
 | 
			
		||||
      %gte.0 = s32[] get-tuple-element(%param), index=0
 | 
			
		||||
      %gte.1 = s32[] get-tuple-element(%param), index=1
 | 
			
		||||
      %gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2
 | 
			
		||||
      %gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3
 | 
			
		||||
      %all-reduce = f32[1024, 1024] all-reduce(f32[1024, 1024] %gte.2), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%reduction
 | 
			
		||||
      %accumulation = f32[1024, 1024] add(f32[1024, 1024] %all-reduce, f32[1024, 1024] %gte.3)
 | 
			
		||||
      %constant = s32[] constant(1)
 | 
			
		||||
      %increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant)
 | 
			
		||||
      ROOT %loop_result = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ENTRY accumulated_all_reduce {
 | 
			
		||||
      %param.0 = s32[] parameter(0)
 | 
			
		||||
      %param.1 = f32[1024, 1024] parameter(1)
 | 
			
		||||
      %constant.0 = s32[] constant(1)
 | 
			
		||||
      %accumulation_buffer_init = f32[] constant(0)
 | 
			
		||||
      %accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
 | 
			
		||||
      %while_init = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, s32[] %param.0, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer)
 | 
			
		||||
      ROOT %while = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body
 | 
			
		||||
    }
 | 
			
		||||
  )";
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
 | 
			
		||||
                          ParseAndReturnVerifiedModule(kHloModule));
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
 | 
			
		||||
                          WhileLoopAllReduceCodeMotion{}.Run(module.get()));
 | 
			
		||||
  ASSERT_TRUE(simplified_loop);
 | 
			
		||||
  TF_ASSERT_OK(
 | 
			
		||||
      HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/true)
 | 
			
		||||
          .Run(module.get())
 | 
			
		||||
          .status());
 | 
			
		||||
  HloInstruction* transformed_while =
 | 
			
		||||
      *(std::find_if(module->entry_computation()->instructions().begin(),
 | 
			
		||||
                     module->entry_computation()->instructions().end(),
 | 
			
		||||
                     [](HloInstruction* instruction) -> bool {
 | 
			
		||||
                       return Value(instruction, op::While());
 | 
			
		||||
                     }));
 | 
			
		||||
 | 
			
		||||
  ASSERT_THAT(transformed_while, NotNull());
 | 
			
		||||
  EXPECT_THAT(transformed_while->while_body()->instructions(),
 | 
			
		||||
              Each(Not(op::AllReduce())));
 | 
			
		||||
  HloInstruction* accumulation_buffer =
 | 
			
		||||
      transformed_while->mutable_operand(0)->mutable_operand(3);
 | 
			
		||||
  EXPECT_THAT(accumulation_buffer, op::Constant());
 | 
			
		||||
  HloAllReduceInstruction* moved_all_reduce = DynCast<HloAllReduceInstruction>(
 | 
			
		||||
      *(std::find_if(module->entry_computation()->instructions().begin(),
 | 
			
		||||
                     module->entry_computation()->instructions().end(),
 | 
			
		||||
                     [](HloInstruction* instruction) {
 | 
			
		||||
                       return Value(instruction, op::AllReduce());
 | 
			
		||||
                     })));
 | 
			
		||||
  ASSERT_THAT(moved_all_reduce, NotNull());
 | 
			
		||||
  EXPECT_THAT(moved_all_reduce->operand(0), op::GetTupleElement());
 | 
			
		||||
  EXPECT_EQ(DynCast<HloGetTupleElementInstruction>(
 | 
			
		||||
                moved_all_reduce->mutable_operand(0))
 | 
			
		||||
                ->tuple_index(),
 | 
			
		||||
            3);
 | 
			
		||||
  EXPECT_THAT(moved_all_reduce->replica_groups(), SizeIs(1));
 | 
			
		||||
  EXPECT_TRUE(
 | 
			
		||||
      std::equal(moved_all_reduce->replica_groups()[0].replica_ids().begin(),
 | 
			
		||||
                 moved_all_reduce->replica_groups()[0].replica_ids().end(),
 | 
			
		||||
                 std::vector<int>{0, 1, 2, 3}.begin()));
 | 
			
		||||
  EXPECT_FALSE(moved_all_reduce->constrain_layout());
 | 
			
		||||
  EXPECT_TRUE(moved_all_reduce->use_global_device_ids());
 | 
			
		||||
  HloComputation* reduction_computation =
 | 
			
		||||
      module->GetComputationWithName("reduction");
 | 
			
		||||
  ASSERT_THAT(reduction_computation, NotNull());
 | 
			
		||||
  EXPECT_EQ(moved_all_reduce->called_computations()[0], reduction_computation);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(WhileLoopAllReduceCodeMotionTest, AllReduceSliceAccumulate) {
 | 
			
		||||
  constexpr absl::string_view kHloModule = R"(
 | 
			
		||||
    HloModule accumulated_all_reduce
 | 
			
		||||
 | 
			
		||||
    %reduction {
 | 
			
		||||
      %x = f32[] parameter(0)
 | 
			
		||||
      %y = f32[] parameter(1)
 | 
			
		||||
      ROOT %add = f32[] add(f32[] %x, f32[] %y)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    %while_condition {
 | 
			
		||||
      %param = (s32[], s32[], f32[3, 1024, 1024], f32[1024, 1024], f32[1024, 1024], f32[1024, 1024]) parameter(0)
 | 
			
		||||
      %gte.0 = s32[] get-tuple-element(%param), index=0
 | 
			
		||||
      %gte.1 = s32[] get-tuple-element(%param), index=1
 | 
			
		||||
      ROOT result = pred[] compare(%gte.0, %gte.1), direction=LT
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    %while_body {
 | 
			
		||||
      %param = (s32[], s32[], f32[3, 1024, 1024], f32[1024, 1024], f32[1024, 1024], f32[1024, 1024]) parameter(0)
 | 
			
		||||
      %gte.0 = s32[] get-tuple-element(%param), index=0
 | 
			
		||||
      %gte.1 = s32[] get-tuple-element(%param), index=1
 | 
			
		||||
      %gte.2 = f32[3, 1024, 1024] get-tuple-element(%param), index=2
 | 
			
		||||
      %gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3
 | 
			
		||||
      %gte.4 = f32[1024, 1024] get-tuple-element(%param), index=4
 | 
			
		||||
      %gte.5 = f32[1024, 1024] get-tuple-element(%param), index=5
 | 
			
		||||
      %all-reduce = f32[3, 1024, 1024] all-reduce(f32[3, 1024, 1024] %gte.2), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%reduction
 | 
			
		||||
      %slice.0 = f32[1, 1024, 1024] slice(f32[3, 1024, 1024] %all-reduce), slice={[0:1], [0:1024], [0:1024]}
 | 
			
		||||
      %reshape.0 = f32[1024, 1024] reshape(f32[1, 1024, 1024] %slice.0)
 | 
			
		||||
      %slice.1 = f32[1, 1024, 1024] slice(f32[3, 1024, 1024] %all-reduce), slice={[1:2], [0:1024], [0:1024]}
 | 
			
		||||
      %reshape.1 = f32[1024, 1024] reshape(f32[1, 1024, 1024] %slice.1)
 | 
			
		||||
      %slice.2 = f32[1, 1024, 1024] slice(f32[3, 1024, 1024] %all-reduce), slice={[2:3], [0:1024], [0:1024]}
 | 
			
		||||
      %reshape.2 = f32[1024, 1024] reshape(f32[1, 1024, 1024] %slice.2)
 | 
			
		||||
      %accumulation.0 = f32[1024, 1024] add(f32[1024, 1024] %reshape.0, f32[1024, 1024] %gte.3)
 | 
			
		||||
      %accumulation.1 = f32[1024, 1024] add(f32[1024, 1024] %reshape.1, f32[1024, 1024] %gte.4)
 | 
			
		||||
      %accumulation.2 = f32[1024, 1024] add(f32[1024, 1024] %reshape.2, f32[1024, 1024] %gte.5)
 | 
			
		||||
      %constant = s32[] constant(1)
 | 
			
		||||
      %increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant)
 | 
			
		||||
      ROOT %loop_result = (s32[], s32[], f32[3, 1024, 1024], f32[1024, 1024], f32[1024, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation.0, %accumulation.1, %accumulation.2)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ENTRY accumulated_all_reduce {
 | 
			
		||||
      %param.0 = s32[] parameter(0)
 | 
			
		||||
      %param.1 = f32[3, 1024, 1024] parameter(1)
 | 
			
		||||
      %constant.0 = s32[] constant(1)
 | 
			
		||||
      %accumulation_buffer_init = f32[] constant(0)
 | 
			
		||||
      %accumulation_buffer.0 = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
 | 
			
		||||
      %accumulation_buffer.1 = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
 | 
			
		||||
      %accumulation_buffer.2 = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
 | 
			
		||||
      %while_init = (s32[], s32[], f32[3, 1024, 1024], f32[1024, 1024], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, s32[] %param.0, f32[3, 1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer.0, f32[1024, 1024] %accumulation_buffer.1, f32[1024, 1024] %accumulation_buffer.2)
 | 
			
		||||
      ROOT %while = (s32[], s32[], f32[3, 1024, 1024], f32[1024, 1024], f32[1024, 1024], f32[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body
 | 
			
		||||
    }
 | 
			
		||||
  )";
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
 | 
			
		||||
                          ParseAndReturnVerifiedModule(kHloModule));
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
 | 
			
		||||
                          WhileLoopAllReduceCodeMotion{}.Run(module.get()));
 | 
			
		||||
  ASSERT_TRUE(simplified_loop);
 | 
			
		||||
  TF_ASSERT_OK(
 | 
			
		||||
      HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/true)
 | 
			
		||||
          .Run(module.get())
 | 
			
		||||
          .status());
 | 
			
		||||
  HloInstruction* transformed_while =
 | 
			
		||||
      *(std::find_if(module->entry_computation()->instructions().begin(),
 | 
			
		||||
                     module->entry_computation()->instructions().end(),
 | 
			
		||||
                     [](HloInstruction* instruction) -> bool {
 | 
			
		||||
                       return Value(instruction, op::While());
 | 
			
		||||
                     }));
 | 
			
		||||
 | 
			
		||||
  ASSERT_THAT(transformed_while, NotNull());
 | 
			
		||||
  EXPECT_THAT(transformed_while->while_body()->instructions(),
 | 
			
		||||
              Each(Not(op::AllReduce())));
 | 
			
		||||
  std::vector<HloInstruction*> hoisted_all_reduces;
 | 
			
		||||
  absl::c_copy_if(module->entry_computation()->instructions(),
 | 
			
		||||
                  std::back_inserter(hoisted_all_reduces),
 | 
			
		||||
                  [](HloInstruction* instruction) {
 | 
			
		||||
                    return Value(instruction, op::AllReduce());
 | 
			
		||||
                  });
 | 
			
		||||
  EXPECT_THAT(hoisted_all_reduces, SizeIs(3));
 | 
			
		||||
  ASSERT_THAT(
 | 
			
		||||
      hoisted_all_reduces,
 | 
			
		||||
      Each(Pointee(Property(&HloInstruction::channel_id, Ne(absl::nullopt)))));
 | 
			
		||||
  // Check if added all-reduces have distinct channel IDs.
 | 
			
		||||
  absl::flat_hash_set<int> unique_channel_ids = {
 | 
			
		||||
      hoisted_all_reduces[0]->channel_id().value(),
 | 
			
		||||
      hoisted_all_reduces[1]->channel_id().value(),
 | 
			
		||||
      hoisted_all_reduces[2]->channel_id().value()};
 | 
			
		||||
  EXPECT_THAT(unique_channel_ids, SizeIs(3));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(WhileLoopAllReduceCodeMotionTest, AllReduceAccumulateUse) {
 | 
			
		||||
  constexpr absl::string_view kHloModule = R"(
 | 
			
		||||
    HloModule accumulated_all_reduce
 | 
			
		||||
 | 
			
		||||
    %reduction {
 | 
			
		||||
      %x = f32[] parameter(0)
 | 
			
		||||
      %y = f32[] parameter(1)
 | 
			
		||||
      ROOT %add = f32[] add(f32[] %x, f32[] %y)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    %while_condition {
 | 
			
		||||
      %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
 | 
			
		||||
      %gte.0 = s32[] get-tuple-element(%param), index=0
 | 
			
		||||
      %gte.1 = s32[] get-tuple-element(%param), index=1
 | 
			
		||||
      ROOT result = pred[] compare(%gte.0, %gte.1), direction=LT
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    %while_body {
 | 
			
		||||
      %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
 | 
			
		||||
      %gte.0 = s32[] get-tuple-element(%param), index=0
 | 
			
		||||
      %gte.1 = s32[] get-tuple-element(%param), index=1
 | 
			
		||||
      %gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2
 | 
			
		||||
      %gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3
 | 
			
		||||
      %all-reduce = f32[1024, 1024] all-reduce(f32[1024, 1024] %gte.2), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%reduction
 | 
			
		||||
      %accumulation = f32[1024, 1024] add(f32[1024, 1024] %all-reduce, f32[1024, 1024] %gte.3)
 | 
			
		||||
      %constant = s32[] constant(1)
 | 
			
		||||
      %increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant)
 | 
			
		||||
      ROOT %loop_result = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ENTRY accumulated_all_reduce {
 | 
			
		||||
      %param.0 = s32[] parameter(0)
 | 
			
		||||
      %param.1 = f32[1024, 1024] parameter(1)
 | 
			
		||||
      %constant.0 = s32[] constant(1)
 | 
			
		||||
      %accumulation_buffer_init = f32[] constant(0)
 | 
			
		||||
      %accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
 | 
			
		||||
      %while_init = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, s32[] %param.0, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer)
 | 
			
		||||
      %while = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body
 | 
			
		||||
      %gte_while = f32[1024, 1024] get-tuple-element((s32[], s32[], f32[1024, 1024], f32[1024, 1024]) %while), index=3
 | 
			
		||||
      ROOT %multiply = f32[1024, 1024] multiply(f32[1024, 1024] %gte_while, f32[1024, 1024] %param.1)
 | 
			
		||||
    }
 | 
			
		||||
  )";
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
 | 
			
		||||
                          ParseAndReturnVerifiedModule(kHloModule));
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
 | 
			
		||||
                          WhileLoopAllReduceCodeMotion{}.Run(module.get()));
 | 
			
		||||
  ASSERT_TRUE(simplified_loop);
 | 
			
		||||
  TF_ASSERT_OK(
 | 
			
		||||
      HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/true)
 | 
			
		||||
          .Run(module.get())
 | 
			
		||||
          .status());
 | 
			
		||||
  HloInstruction* transformed_while =
 | 
			
		||||
      *(std::find_if(module->entry_computation()->instructions().begin(),
 | 
			
		||||
                     module->entry_computation()->instructions().end(),
 | 
			
		||||
                     [](HloInstruction* instruction) -> bool {
 | 
			
		||||
                       return Value(instruction, op::While());
 | 
			
		||||
                     }));
 | 
			
		||||
 | 
			
		||||
  ASSERT_THAT(transformed_while, NotNull());
 | 
			
		||||
  EXPECT_THAT(transformed_while->while_body()->instructions(),
 | 
			
		||||
              Each(Not(op::AllReduce())));
 | 
			
		||||
  HloInstruction* new_root = module->entry_computation()->root_instruction();
 | 
			
		||||
  ASSERT_THAT(new_root, op::Multiply());
 | 
			
		||||
  ASSERT_THAT(new_root->operand(0), op::GetTupleElement());
 | 
			
		||||
  ASSERT_THAT(new_root->operand(0)->operand(0), op::Tuple());
 | 
			
		||||
  EXPECT_THAT(new_root->operand(0)->operand(0)->operand(3), op::Add());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(WhileLoopAllReduceCodeMotionTest, RepeatedlyAccumulatedAllReduce) {
 | 
			
		||||
  constexpr absl::string_view kHloModule = R"(
 | 
			
		||||
    HloModule accumulated_all_reduce
 | 
			
		||||
 | 
			
		||||
    %reduction {
 | 
			
		||||
      %x = f32[] parameter(0)
 | 
			
		||||
      %y = f32[] parameter(1)
 | 
			
		||||
      ROOT %add = f32[] add(f32[] %x, f32[] %y)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    %while_condition {
 | 
			
		||||
      %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
 | 
			
		||||
      %gte.0 = s32[] get-tuple-element(%param), index=0
 | 
			
		||||
      %gte.1 = s32[] get-tuple-element(%param), index=1
 | 
			
		||||
      ROOT result = pred[] compare(%gte.0, %gte.1), direction=LT
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    %while_body {
 | 
			
		||||
      %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
 | 
			
		||||
      %gte.0 = s32[] get-tuple-element(%param), index=0
 | 
			
		||||
      %gte.1 = s32[] get-tuple-element(%param), index=1
 | 
			
		||||
      %gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2
 | 
			
		||||
      %gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3
 | 
			
		||||
      %all-reduce = f32[1024, 1024] all-reduce(f32[1024, 1024] %gte.2), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%reduction
 | 
			
		||||
      %accumulation = f32[1024, 1024] add(f32[1024, 1024] %all-reduce, f32[1024, 1024] %gte.3)
 | 
			
		||||
      %add.0 = f32[1024, 1024] add(f32[1024, 1024] %all-reduce, f32[1024, 1024] %accumulation)
 | 
			
		||||
      %constant = s32[] constant(1)
 | 
			
		||||
      %increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant)
 | 
			
		||||
      ROOT %loop_result = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %add.0)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ENTRY accumulated_all_reduce {
 | 
			
		||||
      %param.0 = s32[] parameter(0)
 | 
			
		||||
      %param.1 = f32[1024, 1024] parameter(1)
 | 
			
		||||
      %constant.0 = s32[] constant(1)
 | 
			
		||||
      %accumulation_buffer_init = f32[] constant(0)
 | 
			
		||||
      %accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
 | 
			
		||||
      %while_init = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, s32[] %param.0, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer)
 | 
			
		||||
      ROOT %while = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body
 | 
			
		||||
    }
 | 
			
		||||
  )";
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
 | 
			
		||||
                          ParseAndReturnVerifiedModule(kHloModule));
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
 | 
			
		||||
                          WhileLoopAllReduceCodeMotion{}.Run(module.get()));
 | 
			
		||||
  EXPECT_FALSE(simplified_loop);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(WhileLoopAllReduceCodeMotionTest, TypeCastAllReduceAccumulate) {
 | 
			
		||||
  constexpr absl::string_view kHloModule = R"(
 | 
			
		||||
    HloModule accumulated_all_reduce
 | 
			
		||||
 | 
			
		||||
    %reduction {
 | 
			
		||||
      %x = bf16[] parameter(0)
 | 
			
		||||
      %y = bf16[] parameter(1)
 | 
			
		||||
      ROOT %add = bf16[] add(bf16[] %x, bf16[] %y)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    %while_condition {
 | 
			
		||||
      %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
 | 
			
		||||
      %gte.0 = s32[] get-tuple-element(%param), index=0
 | 
			
		||||
      %gte.1 = s32[] get-tuple-element(%param), index=1
 | 
			
		||||
      ROOT result = pred[] compare(%gte.0, %gte.1), direction=LT
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    %while_body {
 | 
			
		||||
      %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
 | 
			
		||||
      %gte.0 = s32[] get-tuple-element(%param), index=0
 | 
			
		||||
      %gte.1 = s32[] get-tuple-element(%param), index=1
 | 
			
		||||
      %gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2
 | 
			
		||||
      %gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3
 | 
			
		||||
      %convert.0 = bf16[1024, 1024] convert(f32[1024, 1024] %gte.2)
 | 
			
		||||
      %all-reduce = bf16[1024, 1024] all-reduce(bf16[1024, 1024] %convert.0), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%reduction
 | 
			
		||||
      %convert.1 = f32[1024, 1024] convert(bf16[1024, 1024] %all-reduce)
 | 
			
		||||
      %accumulation = f32[1024, 1024] add(f32[1024, 1024] %convert.1, f32[1024, 1024] %gte.3)
 | 
			
		||||
      %constant = s32[] constant(1)
 | 
			
		||||
      %increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant)
 | 
			
		||||
      ROOT %loop_result = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ENTRY accumulated_all_reduce {
 | 
			
		||||
      %param.0 = s32[] parameter(0)
 | 
			
		||||
      %param.1 = f32[1024, 1024] parameter(1)
 | 
			
		||||
      %constant.0 = s32[] constant(1)
 | 
			
		||||
      %accumulation_buffer_init = f32[] constant(0)
 | 
			
		||||
      %accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
 | 
			
		||||
      %while_init = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, s32[] %param.0, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer)
 | 
			
		||||
      ROOT %while = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body
 | 
			
		||||
    }
 | 
			
		||||
  )";
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
 | 
			
		||||
                          ParseAndReturnVerifiedModule(kHloModule));
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
 | 
			
		||||
                          WhileLoopAllReduceCodeMotion{}.Run(module.get()));
 | 
			
		||||
  ASSERT_TRUE(simplified_loop);
 | 
			
		||||
  TF_ASSERT_OK(
 | 
			
		||||
      HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/true)
 | 
			
		||||
          .Run(module.get())
 | 
			
		||||
          .status());
 | 
			
		||||
  HloInstruction* transformed_while =
 | 
			
		||||
      *(std::find_if(module->entry_computation()->instructions().begin(),
 | 
			
		||||
                     module->entry_computation()->instructions().end(),
 | 
			
		||||
                     [](HloInstruction* instruction) -> bool {
 | 
			
		||||
                       return Value(instruction, op::While());
 | 
			
		||||
                     }));
 | 
			
		||||
 | 
			
		||||
  ASSERT_THAT(transformed_while, NotNull());
 | 
			
		||||
  EXPECT_THAT(transformed_while->while_body()->instructions(),
 | 
			
		||||
              Each(Not(op::AllReduce())));
 | 
			
		||||
  HloInstruction* accumulation_buffer =
 | 
			
		||||
      transformed_while->mutable_operand(0)->mutable_operand(3);
 | 
			
		||||
  EXPECT_THAT(accumulation_buffer, op::Constant());
 | 
			
		||||
  HloAllReduceInstruction* moved_all_reduce = DynCast<HloAllReduceInstruction>(
 | 
			
		||||
      *(std::find_if(module->entry_computation()->instructions().begin(),
 | 
			
		||||
                     module->entry_computation()->instructions().end(),
 | 
			
		||||
                     [](HloInstruction* instruction) {
 | 
			
		||||
                       return Value(instruction, op::AllReduce());
 | 
			
		||||
                     })));
 | 
			
		||||
  EXPECT_TRUE(ShapeUtil::Equal(moved_all_reduce->shape(),
 | 
			
		||||
                               ShapeUtil::MakeShape(BF16, {1024, 1024})));
 | 
			
		||||
 | 
			
		||||
  HloInstruction* add_delta_to_old_buffer =
 | 
			
		||||
      *(std::find_if(module->entry_computation()->instructions().begin(),
 | 
			
		||||
                     module->entry_computation()->instructions().end(),
 | 
			
		||||
                     [](HloInstruction* instruction) -> bool {
 | 
			
		||||
                       return Value(instruction, op::Add());
 | 
			
		||||
                     }));
 | 
			
		||||
  ASSERT_THAT(add_delta_to_old_buffer, NotNull());
 | 
			
		||||
  EXPECT_TRUE(ShapeUtil::Equal(add_delta_to_old_buffer->shape(),
 | 
			
		||||
                               ShapeUtil::MakeShape(F32, {1024, 1024})));
 | 
			
		||||
  EXPECT_TRUE(ShapeUtil::Equal(add_delta_to_old_buffer->operand(0)->shape(),
 | 
			
		||||
                               ShapeUtil::MakeShape(F32, {1024, 1024})));
 | 
			
		||||
  EXPECT_TRUE(ShapeUtil::Equal(add_delta_to_old_buffer->operand(1)->shape(),
 | 
			
		||||
                               ShapeUtil::MakeShape(F32, {1024, 1024})));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(WhileLoopAllReduceCodeMotionTest, MultipleLoopCalls) {
 | 
			
		||||
  constexpr absl::string_view kHloModule = R"(
 | 
			
		||||
    HloModule accumulated_all_reduce
 | 
			
		||||
 | 
			
		||||
    %reduction {
 | 
			
		||||
      %x = bf16[] parameter(0)
 | 
			
		||||
      %y = bf16[] parameter(1)
 | 
			
		||||
      ROOT %add = bf16[] add(bf16[] %x, bf16[] %y)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    %while_condition {
 | 
			
		||||
      %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
 | 
			
		||||
      %gte.0 = s32[] get-tuple-element(%param), index=0
 | 
			
		||||
      %gte.1 = s32[] get-tuple-element(%param), index=1
 | 
			
		||||
      ROOT result = pred[] compare(%gte.0, %gte.1), direction=LT
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    %while_body {
 | 
			
		||||
      %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) parameter(0)
 | 
			
		||||
      %gte.0 = s32[] get-tuple-element(%param), index=0
 | 
			
		||||
      %gte.1 = s32[] get-tuple-element(%param), index=1
 | 
			
		||||
      %gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2
 | 
			
		||||
      %gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3
 | 
			
		||||
      %convert.0 = bf16[1024, 1024] convert(f32[1024, 1024] %gte.2)
 | 
			
		||||
      %all-reduce = bf16[1024, 1024] all-reduce(bf16[1024, 1024] %convert.0), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%reduction
 | 
			
		||||
      %convert.1 = f32[1024, 1024] convert(bf16[1024, 1024] %all-reduce)
 | 
			
		||||
      %accumulation = f32[1024, 1024] add(f32[1024, 1024] %convert.1, f32[1024, 1024] %gte.3)
 | 
			
		||||
      %constant = s32[] constant(1)
 | 
			
		||||
      %increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant)
 | 
			
		||||
      ROOT %loop_result = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ENTRY accumulated_all_reduce {
 | 
			
		||||
      %param.0 = s32[] parameter(0)
 | 
			
		||||
      %param.1 = f32[1024, 1024] parameter(1)
 | 
			
		||||
      %constant.0 = s32[] constant(1)
 | 
			
		||||
      %accumulation_buffer_init = f32[] constant(0)
 | 
			
		||||
      %accumulation_buffer = f32[1024, 1024] broadcast(f32[] %accumulation_buffer_init), dimensions={}
 | 
			
		||||
      %while_init.0 = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, s32[] %param.0, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer)
 | 
			
		||||
      %while.0 = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) while(%while_init.0), condition=%while_condition, body=%while_body
 | 
			
		||||
      %gte.3 = f32[1024, 1024] get-tuple-element(%while.0), index=3
 | 
			
		||||
      %while_init.1 = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(s32[] %constant.0, s32[] %param.0, f32[1024, 1024] %param.1, f32[1024, 1024] %gte.3)
 | 
			
		||||
      %while.1 = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) while(%while_init.0), condition=%while_condition, body=%while_body
 | 
			
		||||
      ROOT %gte.4 = f32[1024, 1024] get-tuple-element((s32[], s32[], f32[1024, 1024], f32[1024, 1024])%while.1), index=3
 | 
			
		||||
    }
 | 
			
		||||
  )";
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
 | 
			
		||||
                          ParseAndReturnVerifiedModule(kHloModule));
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
 | 
			
		||||
                          WhileLoopAllReduceCodeMotion{}.Run(module.get()));
 | 
			
		||||
  ASSERT_TRUE(simplified_loop);
 | 
			
		||||
  TF_ASSERT_OK(
 | 
			
		||||
      HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/true)
 | 
			
		||||
          .Run(module.get())
 | 
			
		||||
          .status());
 | 
			
		||||
  EXPECT_EQ(absl::c_count_if(module->entry_computation()->instructions(),
 | 
			
		||||
                             Matches(op::While())),
 | 
			
		||||
            2);
 | 
			
		||||
  EXPECT_EQ(absl::c_count_if(module->entry_computation()->instructions(),
 | 
			
		||||
                             Matches(op::AllReduce())),
 | 
			
		||||
            2);
 | 
			
		||||
  HloInstruction* transformed_while =
 | 
			
		||||
      *(std::find_if(module->entry_computation()->instructions().begin(),
 | 
			
		||||
                     module->entry_computation()->instructions().end(),
 | 
			
		||||
                     [](HloInstruction* instruction) -> bool {
 | 
			
		||||
                       return Value(instruction, op::While());
 | 
			
		||||
                     }));
 | 
			
		||||
 | 
			
		||||
  ASSERT_THAT(transformed_while, NotNull());
 | 
			
		||||
  EXPECT_THAT(transformed_while->while_body()->instructions(),
 | 
			
		||||
              Each(Not(op::AllReduce())));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(WhileLoopAllReduceCodeMotionTest, MultipleAllReduceAccumulate) {
 | 
			
		||||
  constexpr absl::string_view kHloModule = R"(
 | 
			
		||||
    HloModule accumulated_all_reduce
 | 
			
		||||
 | 
			
		||||
    %reduction.0 {
 | 
			
		||||
      %x = f32[] parameter(0)
 | 
			
		||||
      %y = f32[] parameter(1)
 | 
			
		||||
      ROOT %add = f32[] add(f32[] %x, f32[] %y)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    %reduction.1 {
 | 
			
		||||
      %x = bf16[] parameter(0)
 | 
			
		||||
      %y = bf16[] parameter(1)
 | 
			
		||||
      ROOT %add = bf16[] add(bf16[] %x, bf16[] %y)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    %while_condition {
 | 
			
		||||
      %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024], bf16[1024, 1024], bf16[1024, 1024]) parameter(0)
 | 
			
		||||
      %gte.0 = s32[] get-tuple-element(%param), index=0
 | 
			
		||||
      %gte.1 = s32[] get-tuple-element(%param), index=1
 | 
			
		||||
      ROOT result = pred[] compare(%gte.0, %gte.1), direction=LT
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    %while_body {
 | 
			
		||||
      %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024], bf16[1024, 1024], bf16[1024, 1024]) parameter(0)
 | 
			
		||||
      %gte.0 = s32[] get-tuple-element(%param), index=0
 | 
			
		||||
      %gte.1 = s32[] get-tuple-element(%param), index=1
 | 
			
		||||
      %gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2
 | 
			
		||||
      %gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3
 | 
			
		||||
      %gte.4 = bf16[1024, 1024] get-tuple-element(%param), index=4
 | 
			
		||||
      %gte.5 = bf16[1024, 1024] get-tuple-element(%param), index=5
 | 
			
		||||
      %all-reduce.0 = f32[1024, 1024] all-reduce(f32[1024, 1024] %gte.2), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%reduction.0
 | 
			
		||||
      %accumulation.0 = f32[1024, 1024] add(f32[1024, 1024] %all-reduce.0, f32[1024, 1024] %gte.3)
 | 
			
		||||
      %all-reduce.1 = bf16[1024, 1024] all-reduce(bf16[1024, 1024] %gte.4), channel_id=2, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%reduction.1
 | 
			
		||||
      %accumulation.1 = bf16[1024, 1024] add(bf16[1024, 1024] %all-reduce.1, bf16[1024, 1024] %gte.5)
 | 
			
		||||
      %constant = s32[] constant(1)
 | 
			
		||||
      %increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant)
 | 
			
		||||
      ROOT %loop_result = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation.0, %gte.4, %accumulation.1)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ENTRY accumulated_all_reduce {
 | 
			
		||||
      %param.0 = s32[] parameter(0)
 | 
			
		||||
      %param.1 = f32[1024, 1024] parameter(1)
 | 
			
		||||
      %param.2 = bf16[1024, 1024] parameter(2)
 | 
			
		||||
      %constant.0 = s32[] constant(1)
 | 
			
		||||
      %accumulation_buffer.0 = f32[1024, 1024] constant({...})
 | 
			
		||||
      %accumulation_buffer.1 = bf16[1024, 1024] constant({...})
 | 
			
		||||
      %while_init = (s32[], s32[], f32[1024, 1024], f32[1024, 1024], bf16[1024, 1024], bf16[1024, 1024]) tuple(s32[] %constant.0, s32[] %param.0, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer.0, bf16[1024, 1024] %param.2, bf16[1024, 1024] %accumulation_buffer.1)
 | 
			
		||||
      ROOT %while = (s32[], s32[], f32[1024, 1024], f32[1024, 1024], bf16[1024, 1024], bf16[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body
 | 
			
		||||
    }
 | 
			
		||||
  )";
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
 | 
			
		||||
                          ParseAndReturnVerifiedModule(kHloModule));
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
 | 
			
		||||
                          WhileLoopAllReduceCodeMotion{}.Run(module.get()));
 | 
			
		||||
  ASSERT_TRUE(simplified_loop);
 | 
			
		||||
  TF_ASSERT_OK(
 | 
			
		||||
      HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/true)
 | 
			
		||||
          .Run(module.get())
 | 
			
		||||
          .status());
 | 
			
		||||
  HloInstruction* transformed_while =
 | 
			
		||||
      *(std::find_if(module->entry_computation()->instructions().begin(),
 | 
			
		||||
                     module->entry_computation()->instructions().end(),
 | 
			
		||||
                     [](HloInstruction* instruction) -> bool {
 | 
			
		||||
                       return Value(instruction, op::While());
 | 
			
		||||
                     }));
 | 
			
		||||
 | 
			
		||||
  ASSERT_THAT(transformed_while, NotNull());
 | 
			
		||||
  // Both all-reduces should have been sinked.
 | 
			
		||||
  EXPECT_THAT(transformed_while->while_body()->instructions(),
 | 
			
		||||
              Each(Not(op::AllReduce())));
 | 
			
		||||
  HloInstruction* accumulation_buffer =
 | 
			
		||||
      transformed_while->mutable_operand(0)->mutable_operand(3);
 | 
			
		||||
  EXPECT_THAT(accumulation_buffer, op::Constant());
 | 
			
		||||
  EXPECT_EQ(absl::c_count_if(module->entry_computation()->instructions(),
 | 
			
		||||
                             Matches(op::AllReduce())),
 | 
			
		||||
            2);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(WhileLoopAllReduceCodeMotionTest, MixMovableAllReduceWithNotMovable) {
 | 
			
		||||
  constexpr absl::string_view kHloModule = R"(
 | 
			
		||||
    HloModule accumulated_all_reduce
 | 
			
		||||
 | 
			
		||||
    %reduction.0 {
 | 
			
		||||
      %x = f32[] parameter(0)
 | 
			
		||||
      %y = f32[] parameter(1)
 | 
			
		||||
      ROOT %add = f32[] add(f32[] %x, f32[] %y)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    %reduction.1 {
 | 
			
		||||
      %x = bf16[] parameter(0)
 | 
			
		||||
      %y = bf16[] parameter(1)
 | 
			
		||||
      ROOT %add = bf16[] add(bf16[] %x, bf16[] %y)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    %while_condition {
 | 
			
		||||
      %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024], bf16[1024, 1024], bf16[1024, 1024]) parameter(0)
 | 
			
		||||
      %gte.0 = s32[] get-tuple-element(%param), index=0
 | 
			
		||||
      %gte.1 = s32[] get-tuple-element(%param), index=1
 | 
			
		||||
      ROOT result = pred[] compare(%gte.0, %gte.1), direction=LT
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    %while_body {
 | 
			
		||||
      %param = (s32[], s32[], f32[1024, 1024], f32[1024, 1024], bf16[1024, 1024], bf16[1024, 1024]) parameter(0)
 | 
			
		||||
      %gte.0 = s32[] get-tuple-element(%param), index=0
 | 
			
		||||
      %gte.1 = s32[] get-tuple-element(%param), index=1
 | 
			
		||||
      %gte.2 = f32[1024, 1024] get-tuple-element(%param), index=2
 | 
			
		||||
      %gte.3 = f32[1024, 1024] get-tuple-element(%param), index=3
 | 
			
		||||
      %gte.4 = bf16[1024, 1024] get-tuple-element(%param), index=4
 | 
			
		||||
      %gte.5 = bf16[1024, 1024] get-tuple-element(%param), index=5
 | 
			
		||||
      %all-reduce.0 = f32[1024, 1024] all-reduce(f32[1024, 1024] %gte.2), channel_id=1, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%reduction.0
 | 
			
		||||
      %accumulation.0 = f32[1024, 1024] add(f32[1024, 1024] %all-reduce.0, f32[1024, 1024] %gte.3)
 | 
			
		||||
      %all-reduce.1 = bf16[1024, 1024] all-reduce(bf16[1024, 1024] %gte.4), channel_id=2, replica_groups={{0,1,2,3}}, use_global_device_ids=true, to_apply=%reduction.1
 | 
			
		||||
      %accumulation.1 = bf16[1024, 1024] add(bf16[1024, 1024] %all-reduce.1, bf16[1024, 1024] %gte.5)
 | 
			
		||||
      %add.0 = bf16[1024, 1024] add(bf16[1024, 1024] %accumulation.1, bf16[1024, 1024] %gte.4)
 | 
			
		||||
      %constant = s32[] constant(1)
 | 
			
		||||
      %increment_iteration = s32[] add(s32[] %gte.0, s32[] %constant)
 | 
			
		||||
      ROOT %loop_result = (s32[], s32[], f32[1024, 1024], f32[1024, 1024]) tuple(%increment_iteration, %gte.1, %gte.2, %accumulation.0, %gte.4, %add.0)
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ENTRY accumulated_all_reduce {
 | 
			
		||||
      %param.0 = s32[] parameter(0)
 | 
			
		||||
      %param.1 = f32[1024, 1024] parameter(1)
 | 
			
		||||
      %param.2 = bf16[1024, 1024] parameter(2)
 | 
			
		||||
      %constant.0 = s32[] constant(1)
 | 
			
		||||
      %accumulation_buffer.0 = f32[1024, 1024] constant({...})
 | 
			
		||||
      %accumulation_buffer.1 = bf16[1024, 1024] constant({...})
 | 
			
		||||
      %while_init = (s32[], s32[], f32[1024, 1024], f32[1024, 1024], bf16[1024, 1024], bf16[1024, 1024]) tuple(s32[] %constant.0, s32[] %param.0, f32[1024, 1024] %param.1, f32[1024, 1024] %accumulation_buffer.0, bf16[1024, 1024] %param.2, bf16[1024, 1024] %accumulation_buffer.1)
 | 
			
		||||
      ROOT %while = (s32[], s32[], f32[1024, 1024], f32[1024, 1024], bf16[1024, 1024], bf16[1024, 1024]) while(%while_init), condition=%while_condition, body=%while_body
 | 
			
		||||
    }
 | 
			
		||||
  )";
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
 | 
			
		||||
                          ParseAndReturnVerifiedModule(kHloModule));
 | 
			
		||||
  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
 | 
			
		||||
                          WhileLoopAllReduceCodeMotion{}.Run(module.get()));
 | 
			
		||||
  ASSERT_TRUE(simplified_loop);
 | 
			
		||||
  TF_ASSERT_OK(
 | 
			
		||||
      HloVerifier(/*layout_sensitive=*/false, /*allow_mixed_precision=*/true)
 | 
			
		||||
          .Run(module.get())
 | 
			
		||||
          .status());
 | 
			
		||||
  HloInstruction* transformed_while =
 | 
			
		||||
      *(std::find_if(module->entry_computation()->instructions().begin(),
 | 
			
		||||
                     module->entry_computation()->instructions().end(),
 | 
			
		||||
                     [](HloInstruction* instruction) -> bool {
 | 
			
		||||
                       return Value(instruction, op::While());
 | 
			
		||||
                     }));
 | 
			
		||||
 | 
			
		||||
  ASSERT_THAT(transformed_while, NotNull());
 | 
			
		||||
  // One all-reduce is movable and the other is not movable.
 | 
			
		||||
  EXPECT_EQ(absl::c_count_if(transformed_while->while_body()->instructions(),
 | 
			
		||||
                             Matches(op::AllReduce())),
 | 
			
		||||
            1);
 | 
			
		||||
  HloInstruction* accumulation_buffer =
 | 
			
		||||
      transformed_while->mutable_operand(0)->mutable_operand(3);
 | 
			
		||||
  EXPECT_THAT(accumulation_buffer, op::Constant());
 | 
			
		||||
  EXPECT_EQ(absl::c_count_if(module->entry_computation()->instructions(),
 | 
			
		||||
                             Matches(op::AllReduce())),
 | 
			
		||||
            1);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
}  // namespace
 | 
			
		||||
}  // namespace xla
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user