HyoukJoong Lee 3c28370a9c Combine cross-replica / cross-partition AllReduce after SPMD partition
PiperOrigin-RevId: 283610192
Change-Id: I801097d159c39d8137457c55906d455e0ee7733d
2019-12-03 13:32:47 -08:00

197 lines
7.5 KiB
C++

/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_AR_CRS_COMBINER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_AR_CRS_COMBINER_H_
#include <memory>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/statusor.h"
namespace xla {
// When the HLO graph contains a cross-module AllReduce (CMAR) -- either N
// separate AllReduce ops that share the same channel_id for MPMD partitioning,
// or a single AllReduce op for SPMD partitioning -- followed by some simple
// linear operations, followed by a cross-replica AllReduce (CRAR, also known
// as cross-replica sum or CRS), we can combine the CMAR and the CRAR into a
// single AllReduce whose efficient implementation fully utilizes the
// interconnect bandwidth.
//
// Such sequences appear in spatially partitioned models (either MPMD or SPMD).
// This pass must run right after spatial partitioning, when the code is still
// in a single HLO module.
//
// The steps are:
// 1) Find CMARs followed by simple ops followed by CRARs.
// 2) Group CMARs by channel_id. They must all be rewritten. For SPMD
// partitioning, there will only be a single CMAR for each channel_id.
// 3) Prove that the CMAR patterns in each core produce the same result.
// 4) Eliminate the CMAR, and if it feeds an addition/subtraction, divide the
// other operand by the number of spatial partitions.
// 5) Turn the CRAR into an all-core AllReduce.
//
// The pass also handles the case where multiple CMARs lead to the same CRAR,
// and eliminates all CMARs. This graph:
//
//              Y
//              |
//     X      CMAR_2   Z
//     |         \    /
//   CMAR_1        +
//       \        /
//           +
//           |
//          CRAR
//
// gets rewritten to:
//
//              Z   num_partitions
//               \  /
//        Y      div
//         \    /
//     X     +
//      \   /
//        +
//        |
//   all-core AR
//
class ArCrsCombiner : public HloModulePass {
 public:
  // num_spatial_partitions: number of spatial partitions (MPMD or SPMD) each
  //   replica is split into.
  // num_replicas: number of replicas.
  // spmd_partition: true if the input module is SPMD partitioned, false if it
  //   is MPMD partitioned (see spmd_partition_ below).
  ArCrsCombiner(int num_spatial_partitions, int num_replicas,
                bool spmd_partition)
      : num_spatial_partitions_(num_spatial_partitions),
        num_replicas_(num_replicas),
        spmd_partition_(spmd_partition) {}
  absl::string_view name() const override { return "ar-crs-combiner"; }
  // Runs the pass on `module`. NOTE(review): per the HloPassInterface
  // convention the bool presumably reports whether the module changed --
  // confirm against the .cc implementation.
  StatusOr<bool> Run(HloModule* module) override;

  // Helper method to allow testing of InstructionsComputeSameValue.
  static bool TestInstructionsComputeSameValue(HloInstruction* i1,
                                               HloInstruction* i2);

 private:
  // One candidate AR -> ... -> CRS path. We use this struct because multiple
  // ARs could be paired with the same CRS. In that case we want to select the
  // AR that is furthest from the CRS, because that makes it easier to
  // eliminate all ARs during RewriteGraph.
  struct ArCrsPair {
    HloInstruction* ar;   // Non-owning.
    HloInstruction* crs;  // Non-owning.
    // The length of the path from AR to CRS in the HLO graph.
    int64 distance;
    ArCrsPair(HloInstruction* all_reduce, HloInstruction* cross_replica_sum,
              int64 dist)
        : ar(all_reduce), crs(cross_replica_sum), distance(dist) {}

    // Returns a debug string of the form
    //   "(<ar>,<op>,...,<crs>)[id:<channel_id>,dist:<distance>]".
    // The path is recovered by repeatedly following the first user of each
    // instruction starting at `ar`. Preconditions: every instruction on the
    // path has at least one user and the walk reaches `crs`, and `ar` has a
    // channel_id (it is dereferenced unconditionally).
    string ToString() const {
      std::vector<string> path_names;
      for (const HloInstruction* instruction = ar; instruction != crs;
           instruction = instruction->users()[0]) {
        path_names.push_back(instruction->name());
      }
      path_names.push_back(crs->name());
      return "(" + absl::StrJoin(path_names, ",") + ")[id:" +
             std::to_string(*ar->channel_id()) +
             ",dist:" + std::to_string(distance) + "]";
    }
  };

  // If `instruction` starts an AR -> simple ops -> CRS pattern, returns the
  // matching pair; otherwise returns absl::nullopt.
  absl::optional<ArCrsCombiner::ArCrsPair> MatchesArCrsPattern(
      HloInstruction* instruction);

  // If the passed instruction is a while parameter, and the while body is only
  // called by a single while instruction, return the while instruction.
  absl::optional<HloInstruction*> WhileFromBodyParameter(
      HloInstruction* instruction);

  // If the passed instruction is a parameter in one of the branch computations,
  // and the branch body is only called by a single instruction, return the
  // conditional instruction.
  absl::optional<HloInstruction*> ConditionalFromBodyParameter(
      HloInstruction* instruction);

  // Returns a vector of tuple instructions.
  // If all instructions that flow to "instruction" are tuples, return them.
  // Otherwise, return absl::nullopt. Returns an empty vector if the instruction
  // is already in the visited set.
  absl::optional<std::vector<HloInstruction*>> GetAllTuples(
      HloInstruction* instruction,
      absl::flat_hash_set<HloInstruction*>* visited);

  // Checks whether two different elements in the same tuple compute the same
  // value.
  bool TupleElementsComputeSameValue(
      HloInstruction* tuple_shaped_instruction, int64 i1, int64 i2,
      absl::flat_hash_map<int64, int64>* visited_pairs);

  // Returns whether the instructions i1 and i2 can be shown to evaluate to the
  // same value. Handling WHILE requires recursion, which may cause us to visit
  // the same instruction again. To avoid infinite loops, we pass a cache of
  // visited instruction pairs.
  bool InstructionsComputeSameValue(
      HloInstruction* i1, HloInstruction* i2,
      absl::flat_hash_map<int64, int64>* visited_pairs);

  // Populates all_reduce_map_.
  void GroupAllReducesById(HloModule* module);

  // Looks at each AllReduce group in all_reduce_map_, and keeps only the
  // groups for which it's safe to move the AllReduce later in the HLO graph.
  Status KeepProvablyEqualInstructionGroupsMPMD();

  // Same as above, but runs on SPMD partitioned module instead of MPMD.
  Status KeepProvablyEqualInstructionGroupsSPMD(HloModule* module);

  // Performs the graph rewrite that eliminates the early AllReduce and turns
  // the later CRS into an AllReduce.
  StatusOr<bool> RewriteGraph();

  // Values passed to the constructor.
  int num_spatial_partitions_;
  int num_replicas_;

  // Run this combiner pass assuming the input module is an SPMD partitioned
  // module (as opposed to MPMD partitioned).
  //
  // The main difference between the two w.r.t. this pass is that there would be
  // N all-reduce ops for each channel in MPMD mode, whereas there is only 1
  // for each channel in SPMD mode. Also we use HloReplicationAnalysis for HLO
  // equivalence check in SPMD mode.
  bool spmd_partition_;

  // Map from all-reduce ids to the AR/CRS pairs.
  absl::flat_hash_map<int64, std::vector<ArCrsPair>> all_reduce_map_;

  // Map from a CRS instruction to the all-reduce ID of the AR paired with the
  // CRS. Sometimes, several ARs in the code could be paired with the same CRS.
  // We use this map to pick a single AR/CRS path to rewrite.
  absl::flat_hash_map<HloInstruction*, int64> crs_reserved_map_;

  // Call graph of the module being processed; owned by the pass.
  std::unique_ptr<CallGraph> call_graph_;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_AR_CRS_COMBINER_H_