[XLA] Resubmit DCE change

PiperOrigin-RevId: 298734102
Change-Id: I50409c476af79cd110ddcc4a81a13a14b764c240
This commit is contained in:
Yuanzhong Xu 2020-03-03 17:38:51 -08:00 committed by TensorFlower Gardener
parent 7dcc0a9308
commit 31f7f9cf54
3 changed files with 23 additions and 5 deletions

View File

@ -3133,6 +3133,7 @@ cc_library(
hdrs = ["hlo_dce.h"], hdrs = ["hlo_dce.h"],
deps = [ deps = [
":hlo", ":hlo",
":hlo_casting_utils",
":hlo_pass", ":hlo_pass",
"//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:status_macros",

View File

@ -21,8 +21,10 @@ limitations under the License.
#include <vector> #include <vector>
#include "absl/container/flat_hash_set.h" #include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/status.h"
@ -35,7 +37,8 @@ limitations under the License.
namespace xla { namespace xla {
StatusOr<bool> HloDCE::RunOnComputation(HloComputation* computation) { StatusOr<bool> HloDCE::RunOnComputation(
HloComputation* computation, bool remove_cross_partition_collective_ops) {
bool changed = false; bool changed = false;
VLOG(3) << "Before dce:"; VLOG(3) << "Before dce:";
XLA_VLOG_LINES(3, computation->ToString()); XLA_VLOG_LINES(3, computation->ToString());
@ -47,7 +50,12 @@ StatusOr<bool> HloDCE::RunOnComputation(HloComputation* computation) {
if (instruction != computation->root_instruction() && if (instruction != computation->root_instruction() &&
instruction->user_count() == 0 && instruction->user_count() == 0 &&
computation->IsSafelyRemovable(instruction) && computation->IsSafelyRemovable(instruction) &&
!instruction->HasSideEffect()) { (!instruction->HasSideEffect() ||
(remove_cross_partition_collective_ops &&
((instruction->opcode() == HloOpcode::kAllReduce &&
!Cast<HloAllReduceInstruction>(instruction)->constrain_layout()) ||
instruction->opcode() == HloOpcode::kCollectivePermute ||
instruction->opcode() == HloOpcode::kAllToAll)))) {
dead_roots.push_back(instruction); dead_roots.push_back(instruction);
} }
} }
@ -74,8 +82,9 @@ StatusOr<bool> HloDCE::Run(HloModule* module) {
// Run DCE on each computation. // Run DCE on each computation.
for (auto* computation : module->MakeComputationPostOrder()) { for (auto* computation : module->MakeComputationPostOrder()) {
TF_ASSIGN_OR_RETURN(bool changed_for_computation, TF_ASSIGN_OR_RETURN(
RunOnComputation(computation)); bool changed_for_computation,
RunOnComputation(computation, remove_cross_partition_collective_ops_));
changed |= changed_for_computation; changed |= changed_for_computation;
} }

View File

@ -35,15 +35,23 @@ namespace xla {
// instructions cannot be deleted. // instructions cannot be deleted.
class HloDCE : public HloModulePass { class HloDCE : public HloModulePass {
public: public:
HloDCE() : remove_cross_partition_collective_ops_(false) {}
explicit HloDCE(bool remove_cross_partition_collective_ops)
: remove_cross_partition_collective_ops_(
remove_cross_partition_collective_ops) {}
~HloDCE() override {} ~HloDCE() override {}
absl::string_view name() const override { return "dce"; } absl::string_view name() const override { return "dce"; }
// Run DCE on a computation. // Run DCE on a computation.
static StatusOr<bool> RunOnComputation(HloComputation* computation); StatusOr<bool> RunOnComputation(HloComputation* computation,
bool remove_cross_partition_collective_ops);
// Run the pass on the given module. Returns whether the module was changed // Run the pass on the given module. Returns whether the module was changed
// (instructions were removed). // (instructions were removed).
StatusOr<bool> Run(HloModule* module) override; StatusOr<bool> Run(HloModule* module) override;
private:
bool remove_cross_partition_collective_ops_;
}; };
} // namespace xla } // namespace xla