diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
index 14afe770ede..225102e6ae6 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
+++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
@@ -142,24 +142,29 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount(
   // in-place will only touch the updated elements).
   // TODO(b/27458679) Parallelize instructions which are skipped here.
   auto opcode = instruction->opcode();
-  if (opcode == HloOpcode::kParameter || opcode == HloOpcode::kConstant ||
-      opcode == HloOpcode::kCall || opcode == HloOpcode::kCustomCall ||
-      opcode == HloOpcode::kDot || opcode == HloOpcode::kSelectAndScatter ||
-      opcode == HloOpcode::kGetTupleElement || opcode == HloOpcode::kBitcast ||
-      opcode == HloOpcode::kFft || opcode == HloOpcode::kInfeed ||
-      opcode == HloOpcode::kOutfeed || opcode == HloOpcode::kRng ||
-      opcode == HloOpcode::kSort ||
-      (opcode == HloOpcode::kConvolution &&
-       PotentiallyImplementedAsEigenConvolution(*instruction,
-                                                target_machine_features_)) ||
-      (opcode == HloOpcode::kFusion && !instruction->IsLoopFusion()) ||
-      llvm_ir::MayBeImplementedAsInPlaceDynamicUpdateSlice(instruction) ||
-      instruction->shape().IsTuple()) {
+  if (llvm_ir::MayBeImplementedAsInPlaceDynamicUpdateSlice(instruction) ||
+      instruction->shape().IsTuple() || opcode == HloOpcode::kRng) {
     return 1;
   }
-  // Consult 'cost_model_' to compute target parallel task count.
-  return cost_model_->GetParallelTaskCount(instruction);
+  // Only allow known good instructions.
+  if (instruction->IsElementwise() || instruction->IsLoopFusion() ||
+      opcode == HloOpcode::kBroadcast || opcode == HloOpcode::kConcatenate ||
+      opcode == HloOpcode::kDynamicSlice ||
+      opcode == HloOpcode::kDynamicUpdateSlice ||
+      opcode == HloOpcode::kGather || opcode == HloOpcode::kIota ||
+      opcode == HloOpcode::kPad || opcode == HloOpcode::kReduce ||
+      opcode == HloOpcode::kReduceWindow || opcode == HloOpcode::kReshape ||
+      opcode == HloOpcode::kReverse || opcode == HloOpcode::kSlice ||
+      opcode == HloOpcode::kTranspose ||
+      (opcode == HloOpcode::kConvolution &&
+       !PotentiallyImplementedAsEigenConvolution(*instruction,
+                                                 target_machine_features_))) {
+    // Consult 'cost_model_' to compute target parallel task count.
+    return cost_model_->GetParallelTaskCount(instruction);
+  }
+
+  return 1;
 }
 
 StatusOr<bool> ParallelTaskAssigner::Run(HloModule* module) {
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc
index e2c93568b74..e22210a61f2 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment_test.cc
@@ -170,5 +170,26 @@ TEST_F(ParallelTaskAssignmentTest, InPlaceDynamicUpdateSliceNotParallelized) {
   EXPECT_FALSE(changed);
 }
 
+TEST_F(ParallelTaskAssignmentTest, AllReduceNotParallelized) {
+  constexpr char hlo_string[] = R"(
+  HloModule TestTaskParallel_allreduce
+  add {
+    lhs = f32[] parameter(0)
+    rhs = f32[] parameter(1)
+    ROOT add = f32[] add(lhs, rhs)
+  }
+
+  ENTRY CRS {
+    input = f32[1234567] parameter(0)
+    ROOT crs = f32[1234567] all-reduce(input), replica_groups={}, to_apply=add
+  }
+  )";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> m,
+                          ParseAndReturnVerifiedModule(hlo_string));
+  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunParallelTaskAssigner(m.get()));
+  EXPECT_FALSE(changed);
+}
+
 }  // namespace
 }  // namespace xla