[XLA:CPU] Switch parallel_task_assignment to an allowlist so it doesn't parallelize HLOs it doesn't know about
The list of ops that remain parallelizable is roughly identical to what can go into a loop fusion. Add a test that we don't parallelize all-reduce.

PiperOrigin-RevId: 313157848
Change-Id: I5e7c85c11d78ba8b9b8a75a15c80eb67cd151064
This commit is contained in:
parent
64a37a9028
commit
be46769cee
@ -142,24 +142,29 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount(
|
|||||||
// in-place will only touch the updated elements).
|
// in-place will only touch the updated elements).
|
||||||
// TODO(b/27458679) Parallelize instructions which are skipped here.
|
// TODO(b/27458679) Parallelize instructions which are skipped here.
|
||||||
auto opcode = instruction->opcode();
|
auto opcode = instruction->opcode();
|
||||||
if (opcode == HloOpcode::kParameter || opcode == HloOpcode::kConstant ||
|
if (llvm_ir::MayBeImplementedAsInPlaceDynamicUpdateSlice(instruction) ||
|
||||||
opcode == HloOpcode::kCall || opcode == HloOpcode::kCustomCall ||
|
instruction->shape().IsTuple() || opcode == HloOpcode::kRng) {
|
||||||
opcode == HloOpcode::kDot || opcode == HloOpcode::kSelectAndScatter ||
|
|
||||||
opcode == HloOpcode::kGetTupleElement || opcode == HloOpcode::kBitcast ||
|
|
||||||
opcode == HloOpcode::kFft || opcode == HloOpcode::kInfeed ||
|
|
||||||
opcode == HloOpcode::kOutfeed || opcode == HloOpcode::kRng ||
|
|
||||||
opcode == HloOpcode::kSort ||
|
|
||||||
(opcode == HloOpcode::kConvolution &&
|
|
||||||
PotentiallyImplementedAsEigenConvolution(*instruction,
|
|
||||||
target_machine_features_)) ||
|
|
||||||
(opcode == HloOpcode::kFusion && !instruction->IsLoopFusion()) ||
|
|
||||||
llvm_ir::MayBeImplementedAsInPlaceDynamicUpdateSlice(instruction) ||
|
|
||||||
instruction->shape().IsTuple()) {
|
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Consult 'cost_model_' to compute target parallel task count.
|
// Only allow known good instructions.
|
||||||
return cost_model_->GetParallelTaskCount(instruction);
|
if (instruction->IsElementwise() || instruction->IsLoopFusion() ||
|
||||||
|
opcode == HloOpcode::kBroadcast || opcode == HloOpcode::kConcatenate ||
|
||||||
|
opcode == HloOpcode::kDynamicSlice ||
|
||||||
|
opcode == HloOpcode::kDynamicUpdateSlice ||
|
||||||
|
opcode == HloOpcode::kGather || opcode == HloOpcode::kIota ||
|
||||||
|
opcode == HloOpcode::kPad || opcode == HloOpcode::kReduce ||
|
||||||
|
opcode == HloOpcode::kReduceWindow || opcode == HloOpcode::kReshape ||
|
||||||
|
opcode == HloOpcode::kReverse || opcode == HloOpcode::kSlice ||
|
||||||
|
opcode == HloOpcode::kTranspose ||
|
||||||
|
(opcode == HloOpcode::kConvolution &&
|
||||||
|
!PotentiallyImplementedAsEigenConvolution(*instruction,
|
||||||
|
target_machine_features_))) {
|
||||||
|
// Consult 'cost_model_' to compute target parallel task count.
|
||||||
|
return cost_model_->GetParallelTaskCount(instruction);
|
||||||
|
}
|
||||||
|
|
||||||
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<bool> ParallelTaskAssigner::Run(HloModule* module) {
|
StatusOr<bool> ParallelTaskAssigner::Run(HloModule* module) {
|
||||||
|
@ -170,5 +170,26 @@ TEST_F(ParallelTaskAssignmentTest, InPlaceDynamicUpdateSliceNotParallelized) {
|
|||||||
EXPECT_FALSE(changed);
|
EXPECT_FALSE(changed);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(ParallelTaskAssignmentTest, AllReduceNotParallelized) {
|
||||||
|
constexpr char hlo_string[] = R"(
|
||||||
|
HloModule TestTaskParallel_allreduce
|
||||||
|
add {
|
||||||
|
lhs = f32[] parameter(0)
|
||||||
|
rhs = f32[] parameter(1)
|
||||||
|
ROOT add = f32[] add(lhs, rhs)
|
||||||
|
}
|
||||||
|
|
||||||
|
ENTRY CRS {
|
||||||
|
input = f32[1234567] parameter(0)
|
||||||
|
ROOT crs = f32[1234567] all-reduce(input), replica_groups={}, to_apply=add
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
|
||||||
|
ParseAndReturnVerifiedModule(hlo_string));
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(m.get()));
|
||||||
|
EXPECT_FALSE(changed);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
Loading…
Reference in New Issue
Block a user