[XLA:CPU] Switch parallel_task_assignment to a whitelist so it doesn't parallelize HLOs it doesn't know about

The remaining list is roughly identical to what can go into a loop fusion. Add
a test that we don't parallelize allreduce.

PiperOrigin-RevId: 313157848
Change-Id: I5e7c85c11d78ba8b9b8a75a15c80eb67cd151064
This commit is contained in:
Benjamin Kramer 2020-05-26 02:54:20 -07:00 committed by TensorFlower Gardener
parent 64a37a9028
commit be46769cee
2 changed files with 41 additions and 15 deletions

View File

@@ -142,24 +142,29 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount(
// in-place will only touch the updated elements).
// TODO(b/27458679) Parallelize instructions which are skipped here.
auto opcode = instruction->opcode();
if (opcode == HloOpcode::kParameter || opcode == HloOpcode::kConstant ||
opcode == HloOpcode::kCall || opcode == HloOpcode::kCustomCall ||
opcode == HloOpcode::kDot || opcode == HloOpcode::kSelectAndScatter ||
opcode == HloOpcode::kGetTupleElement || opcode == HloOpcode::kBitcast ||
opcode == HloOpcode::kFft || opcode == HloOpcode::kInfeed ||
opcode == HloOpcode::kOutfeed || opcode == HloOpcode::kRng ||
opcode == HloOpcode::kSort ||
(opcode == HloOpcode::kConvolution &&
PotentiallyImplementedAsEigenConvolution(*instruction,
target_machine_features_)) ||
(opcode == HloOpcode::kFusion && !instruction->IsLoopFusion()) ||
llvm_ir::MayBeImplementedAsInPlaceDynamicUpdateSlice(instruction) ||
instruction->shape().IsTuple()) {
if (llvm_ir::MayBeImplementedAsInPlaceDynamicUpdateSlice(instruction) ||
instruction->shape().IsTuple() || opcode == HloOpcode::kRng) {
return 1;
}
// Consult 'cost_model_' to compute target parallel task count.
return cost_model_->GetParallelTaskCount(instruction);
// Only allow known good instructions.
if (instruction->IsElementwise() || instruction->IsLoopFusion() ||
opcode == HloOpcode::kBroadcast || opcode == HloOpcode::kConcatenate ||
opcode == HloOpcode::kDynamicSlice ||
opcode == HloOpcode::kDynamicUpdateSlice ||
opcode == HloOpcode::kGather || opcode == HloOpcode::kIota ||
opcode == HloOpcode::kPad || opcode == HloOpcode::kReduce ||
opcode == HloOpcode::kReduceWindow || opcode == HloOpcode::kReshape ||
opcode == HloOpcode::kReverse || opcode == HloOpcode::kSlice ||
opcode == HloOpcode::kTranspose ||
(opcode == HloOpcode::kConvolution &&
!PotentiallyImplementedAsEigenConvolution(*instruction,
target_machine_features_))) {
// Consult 'cost_model_' to compute target parallel task count.
return cost_model_->GetParallelTaskCount(instruction);
}
return 1;
}
StatusOr<bool> ParallelTaskAssigner::Run(HloModule* module) {

View File

@@ -170,5 +170,26 @@ TEST_F(ParallelTaskAssignmentTest, InPlaceDynamicUpdateSliceNotParallelized) {
EXPECT_FALSE(changed);
}
// Verifies that all-reduce is NOT parallelized: all-reduce is not among the
// instructions the pass allows (see the "Only allow known good instructions"
// check), so running ParallelTaskAssigner over this module must report no
// change.
TEST_F(ParallelTaskAssignmentTest, AllReduceNotParallelized) {
constexpr char hlo_string[] = R"(
HloModule TestTaskParallel_allreduce
add {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
ENTRY CRS {
input = f32[1234567] parameter(0)
ROOT crs = f32[1234567] all-reduce(input), replica_groups={}, to_apply=add
}
)";
// Parse and verify the HLO text into a module.
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
ParseAndReturnVerifiedModule(hlo_string));
// Run the pass; `changed` must be false since the all-reduce (and the large
// parameter feeding it) should be left untouched.
TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(m.get()));
EXPECT_FALSE(changed);
}
} // namespace
} // namespace xla