[XLA] Resubmit DCE change

PiperOrigin-RevId: 298734102
Change-Id: I50409c476af79cd110ddcc4a81a13a14b764c240
This commit is contained in:
Yuanzhong Xu 2020-03-03 17:38:51 -08:00 committed by TensorFlower Gardener
parent 7dcc0a9308
commit 31f7f9cf54
3 changed files with 23 additions and 5 deletions

View File

@ -3133,6 +3133,7 @@ cc_library(
hdrs = ["hlo_dce.h"],
deps = [
":hlo",
":hlo_casting_utils",
":hlo_pass",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:status_macros",

View File

@ -21,8 +21,10 @@ limitations under the License.
#include <vector>
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/status.h"
@ -35,7 +37,8 @@ limitations under the License.
namespace xla {
StatusOr<bool> HloDCE::RunOnComputation(HloComputation* computation) {
StatusOr<bool> HloDCE::RunOnComputation(
HloComputation* computation, bool remove_cross_partition_collective_ops) {
bool changed = false;
VLOG(3) << "Before dce:";
XLA_VLOG_LINES(3, computation->ToString());
@ -47,7 +50,12 @@ StatusOr<bool> HloDCE::RunOnComputation(HloComputation* computation) {
if (instruction != computation->root_instruction() &&
instruction->user_count() == 0 &&
computation->IsSafelyRemovable(instruction) &&
!instruction->HasSideEffect()) {
(!instruction->HasSideEffect() ||
(remove_cross_partition_collective_ops &&
((instruction->opcode() == HloOpcode::kAllReduce &&
!Cast<HloAllReduceInstruction>(instruction)->constrain_layout()) ||
instruction->opcode() == HloOpcode::kCollectivePermute ||
instruction->opcode() == HloOpcode::kAllToAll)))) {
dead_roots.push_back(instruction);
}
}
@ -74,8 +82,9 @@ StatusOr<bool> HloDCE::Run(HloModule* module) {
// Run DCE on each computation.
for (auto* computation : module->MakeComputationPostOrder()) {
TF_ASSIGN_OR_RETURN(bool changed_for_computation,
RunOnComputation(computation));
TF_ASSIGN_OR_RETURN(
bool changed_for_computation,
RunOnComputation(computation, remove_cross_partition_collective_ops_));
changed |= changed_for_computation;
}

View File

@ -35,15 +35,23 @@ namespace xla {
// instructions cannot be deleted.
class HloDCE : public HloModulePass {
public:
HloDCE() : remove_cross_partition_collective_ops_(false) {}
explicit HloDCE(bool remove_cross_partition_collective_ops)
: remove_cross_partition_collective_ops_(
remove_cross_partition_collective_ops) {}
~HloDCE() override {}
absl::string_view name() const override { return "dce"; }
// Run DCE on a computation.
static StatusOr<bool> RunOnComputation(HloComputation* computation);
StatusOr<bool> RunOnComputation(HloComputation* computation,
bool remove_cross_partition_collective_ops);
// Run the pass on the given module. Returns whether the module was changed
// (instructions were removed).
StatusOr<bool> Run(HloModule* module) override;
private:
bool remove_cross_partition_collective_ops_;
};
} // namespace xla