[XLA:CPU] Switch parallel_task_assignment to a whitelist so it doesn't parallelize HLOs it doesn't know about

The remaining list is roughly identical to what can go into a loop fusion. Add
a test that we don't parallelize allreduce.

PiperOrigin-RevId: 313157848
Change-Id: I5e7c85c11d78ba8b9b8a75a15c80eb67cd151064
Benjamin Kramer, 2020-05-26 02:54:20 -07:00, committed by TensorFlower Gardener
parent 64a37a9028
commit be46769cee
2 changed files with 41 additions and 15 deletions


@@ -142,24 +142,29 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount(
   // in-place will only touch the updated elements).
   // TODO(b/27458679) Parallelize instructions which are skipped here.
   auto opcode = instruction->opcode();
-  if (opcode == HloOpcode::kParameter || opcode == HloOpcode::kConstant ||
-      opcode == HloOpcode::kCall || opcode == HloOpcode::kCustomCall ||
-      opcode == HloOpcode::kDot || opcode == HloOpcode::kSelectAndScatter ||
-      opcode == HloOpcode::kGetTupleElement || opcode == HloOpcode::kBitcast ||
-      opcode == HloOpcode::kFft || opcode == HloOpcode::kInfeed ||
-      opcode == HloOpcode::kOutfeed || opcode == HloOpcode::kRng ||
-      opcode == HloOpcode::kSort ||
-      (opcode == HloOpcode::kConvolution &&
-       PotentiallyImplementedAsEigenConvolution(*instruction,
-                                                target_machine_features_)) ||
-      (opcode == HloOpcode::kFusion && !instruction->IsLoopFusion()) ||
-      llvm_ir::MayBeImplementedAsInPlaceDynamicUpdateSlice(instruction) ||
-      instruction->shape().IsTuple()) {
+  if (llvm_ir::MayBeImplementedAsInPlaceDynamicUpdateSlice(instruction) ||
+      instruction->shape().IsTuple() || opcode == HloOpcode::kRng) {
     return 1;
   }
-  // Consult 'cost_model_' to compute target parallel task count.
-  return cost_model_->GetParallelTaskCount(instruction);
+
+  // Only allow known good instructions.
+  if (instruction->IsElementwise() || instruction->IsLoopFusion() ||
+      opcode == HloOpcode::kBroadcast || opcode == HloOpcode::kConcatenate ||
+      opcode == HloOpcode::kDynamicSlice ||
+      opcode == HloOpcode::kDynamicUpdateSlice ||
+      opcode == HloOpcode::kGather || opcode == HloOpcode::kIota ||
+      opcode == HloOpcode::kPad || opcode == HloOpcode::kReduce ||
+      opcode == HloOpcode::kReduceWindow || opcode == HloOpcode::kReshape ||
+      opcode == HloOpcode::kReverse || opcode == HloOpcode::kSlice ||
+      opcode == HloOpcode::kTranspose ||
+      (opcode == HloOpcode::kConvolution &&
+       !PotentiallyImplementedAsEigenConvolution(*instruction,
+                                                 target_machine_features_))) {
+    // Consult 'cost_model_' to compute target parallel task count.
+    return cost_model_->GetParallelTaskCount(instruction);
+  }
+  return 1;
 }
 
 StatusOr<bool> ParallelTaskAssigner::Run(HloModule* module) {
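
For readers outside the XLA tree, here is a minimal, self-contained sketch of the allowlist pattern the hunk above adopts: only opcodes the pass explicitly knows about consult the cost model, and anything unknown falls back to the "do not parallelize" default of 1. The Opcode enum and the simplified GetTargetParallelTaskCount signature below are hypothetical stand-ins for illustration, not the actual XLA types.

// Hypothetical, simplified stand-ins for illustration only (not XLA types).
#include <cassert>

enum class Opcode { kAdd, kBroadcast, kReduce, kAllReduce, kCustomCall };

// Returns the number of parallel tasks to use for an op; 1 means "leave it
// sequential". Only opcodes on the explicit allowlist defer to the cost
// model, so anything the pass does not know about is never parallelized.
int GetTargetParallelTaskCount(Opcode op, int cost_model_count) {
  switch (op) {
    case Opcode::kAdd:        // elementwise
    case Opcode::kBroadcast:
    case Opcode::kReduce:
      return cost_model_count;  // known good: use the cost model's answer
    default:
      return 1;  // unknown (e.g. kAllReduce, kCustomCall): stay sequential
  }
}

int main() {
  assert(GetTargetParallelTaskCount(Opcode::kBroadcast, 8) == 8);
  assert(GetTargetParallelTaskCount(Opcode::kAllReduce, 8) == 1);
  return 0;
}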


@@ -170,5 +170,26 @@ TEST_F(ParallelTaskAssignmentTest, InPlaceDynamicUpdateSliceNotParallelized) {
   EXPECT_FALSE(changed);
 }
 
+TEST_F(ParallelTaskAssignmentTest, AllReduceNotParallelized) {
+  constexpr char hlo_string[] = R"(
+  HloModule TestTaskParallel_allreduce
+  add {
+    lhs = f32[] parameter(0)
+    rhs = f32[] parameter(1)
+    ROOT add = f32[] add(lhs, rhs)
+  }
+
+  ENTRY CRS {
+    input = f32[1234567] parameter(0)
+    ROOT crs = f32[1234567] all-reduce(input), replica_groups={}, to_apply=add
+  }
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(m.get()));
+  EXPECT_FALSE(changed);
+}
+
 } // namespace
 } // namespace xla