diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 502f0fa7927..98851fddd2d 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -3133,6 +3133,7 @@ cc_library( hdrs = ["hlo_dce.h"], deps = [ ":hlo", + ":hlo_casting_utils", ":hlo_pass", "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", diff --git a/tensorflow/compiler/xla/service/hlo_dce.cc b/tensorflow/compiler/xla/service/hlo_dce.cc index 23b27a5f67d..a573b621c88 100644 --- a/tensorflow/compiler/xla/service/hlo_dce.cc +++ b/tensorflow/compiler/xla/service/hlo_dce.cc @@ -21,8 +21,10 @@ limitations under the License. #include #include "absl/container/flat_hash_set.h" +#include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/status.h" @@ -35,7 +37,8 @@ limitations under the License. 
namespace xla { -StatusOr<bool> HloDCE::RunOnComputation(HloComputation* computation) { +StatusOr<bool> HloDCE::RunOnComputation( + HloComputation* computation, bool remove_cross_partition_collective_ops) { bool changed = false; VLOG(3) << "Before dce:"; XLA_VLOG_LINES(3, computation->ToString()); @@ -47,7 +50,12 @@ StatusOr<bool> HloDCE::RunOnComputation(HloComputation* computation) { if (instruction != computation->root_instruction() && instruction->user_count() == 0 && computation->IsSafelyRemovable(instruction) && - !instruction->HasSideEffect()) { + (!instruction->HasSideEffect() || + (remove_cross_partition_collective_ops && + ((instruction->opcode() == HloOpcode::kAllReduce && + !Cast<HloAllReduceInstruction>(instruction)->constrain_layout()) || + instruction->opcode() == HloOpcode::kCollectivePermute || + instruction->opcode() == HloOpcode::kAllToAll)))) { dead_roots.push_back(instruction); } } @@ -74,8 +82,9 @@ StatusOr<bool> HloDCE::Run(HloModule* module) { // Run DCE on each computation. for (auto* computation : module->MakeComputationPostOrder()) { - TF_ASSIGN_OR_RETURN(bool changed_for_computation, - RunOnComputation(computation)); + TF_ASSIGN_OR_RETURN( + bool changed_for_computation, + RunOnComputation(computation, remove_cross_partition_collective_ops_)); changed |= changed_for_computation; } diff --git a/tensorflow/compiler/xla/service/hlo_dce.h b/tensorflow/compiler/xla/service/hlo_dce.h index f22f98868ab..49bb2e3f139 100644 --- a/tensorflow/compiler/xla/service/hlo_dce.h +++ b/tensorflow/compiler/xla/service/hlo_dce.h @@ -35,15 +35,23 @@ namespace xla { // instructions cannot be deleted. class HloDCE : public HloModulePass { public: + HloDCE() : remove_cross_partition_collective_ops_(false) {} + explicit HloDCE(bool remove_cross_partition_collective_ops) + : remove_cross_partition_collective_ops_( + remove_cross_partition_collective_ops) {} ~HloDCE() override {} absl::string_view name() const override { return "dce"; } // Run DCE on a computation. 
- static StatusOr<bool> RunOnComputation(HloComputation* computation); + StatusOr<bool> RunOnComputation(HloComputation* computation, + bool remove_cross_partition_collective_ops); // Run the pass on the given module. Returns whether the module was changed // (instructions were removed). StatusOr<bool> Run(HloModule* module) override; + + private: + bool remove_cross_partition_collective_ops_; }; } // namespace xla