[XLA:CPU] Switch parallel_task_assignment to a whitelist so it doesn't parallelize HLOs it doesn't know about
The remaining list is roughly identical to what can go into a loop fusion. Add a test that we don't parallelize allreduce. PiperOrigin-RevId: 313157848 Change-Id: I5e7c85c11d78ba8b9b8a75a15c80eb67cd151064
This commit is contained in:
parent
64a37a9028
commit
be46769cee
@ -142,24 +142,29 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount(
|
||||
// in-place will only touch the updated elements).
|
||||
// TODO(b/27458679) Parallelize instructions which are skipped here.
|
||||
auto opcode = instruction->opcode();
|
||||
if (opcode == HloOpcode::kParameter || opcode == HloOpcode::kConstant ||
|
||||
opcode == HloOpcode::kCall || opcode == HloOpcode::kCustomCall ||
|
||||
opcode == HloOpcode::kDot || opcode == HloOpcode::kSelectAndScatter ||
|
||||
opcode == HloOpcode::kGetTupleElement || opcode == HloOpcode::kBitcast ||
|
||||
opcode == HloOpcode::kFft || opcode == HloOpcode::kInfeed ||
|
||||
opcode == HloOpcode::kOutfeed || opcode == HloOpcode::kRng ||
|
||||
opcode == HloOpcode::kSort ||
|
||||
(opcode == HloOpcode::kConvolution &&
|
||||
PotentiallyImplementedAsEigenConvolution(*instruction,
|
||||
target_machine_features_)) ||
|
||||
(opcode == HloOpcode::kFusion && !instruction->IsLoopFusion()) ||
|
||||
llvm_ir::MayBeImplementedAsInPlaceDynamicUpdateSlice(instruction) ||
|
||||
instruction->shape().IsTuple()) {
|
||||
if (llvm_ir::MayBeImplementedAsInPlaceDynamicUpdateSlice(instruction) ||
|
||||
instruction->shape().IsTuple() || opcode == HloOpcode::kRng) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Consult 'cost_model_' to compute target parallel task count.
|
||||
return cost_model_->GetParallelTaskCount(instruction);
|
||||
// Only allow known good instructions.
|
||||
if (instruction->IsElementwise() || instruction->IsLoopFusion() ||
|
||||
opcode == HloOpcode::kBroadcast || opcode == HloOpcode::kConcatenate ||
|
||||
opcode == HloOpcode::kDynamicSlice ||
|
||||
opcode == HloOpcode::kDynamicUpdateSlice ||
|
||||
opcode == HloOpcode::kGather || opcode == HloOpcode::kIota ||
|
||||
opcode == HloOpcode::kPad || opcode == HloOpcode::kReduce ||
|
||||
opcode == HloOpcode::kReduceWindow || opcode == HloOpcode::kReshape ||
|
||||
opcode == HloOpcode::kReverse || opcode == HloOpcode::kSlice ||
|
||||
opcode == HloOpcode::kTranspose ||
|
||||
(opcode == HloOpcode::kConvolution &&
|
||||
!PotentiallyImplementedAsEigenConvolution(*instruction,
|
||||
target_machine_features_))) {
|
||||
// Consult 'cost_model_' to compute target parallel task count.
|
||||
return cost_model_->GetParallelTaskCount(instruction);
|
||||
}
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
StatusOr<bool> ParallelTaskAssigner::Run(HloModule* module) {
|
||||
|
@ -170,5 +170,26 @@ TEST_F(ParallelTaskAssignmentTest, InPlaceDynamicUpdateSliceNotParallelized) {
|
||||
EXPECT_FALSE(changed);
|
||||
}
|
||||
|
||||
// Verifies that an all-reduce instruction is NOT parallelized: all-reduce is
// not in the allowlist of known-good opcodes consulted by
// GetTargetParallelTaskCount, so the assigner must leave the module unchanged
// (RunParallelTaskAssigner reports changed == false).
TEST_F(ParallelTaskAssignmentTest, AllReduceNotParallelized) {
  // HLO module with a scalar `add` reduction computation applied by a
  // large (1234567-element) cross-replica all-reduce at the entry root.
  constexpr char hlo_string[] = R"(
  HloModule TestTaskParallel_allreduce
  add {
    lhs = f32[] parameter(0)
    rhs = f32[] parameter(1)
    ROOT add = f32[] add(lhs, rhs)
  }

  ENTRY CRS {
    input = f32[1234567] parameter(0)
    ROOT crs = f32[1234567] all-reduce(input), replica_groups={}, to_apply=add
  }
  )";

  // Parse and verify the module, then run the parallel task assigner on it.
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
                          ParseAndReturnVerifiedModule(hlo_string));
  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(m.get()));
  // The pass must not rewrite the all-reduce into parallel tasks.
  EXPECT_FALSE(changed);
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
Loading…
Reference in New Issue
Block a user