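[XLA] Resubmit DCE change
Adds an opt-in HloDCE option, remove_cross_partition_collective_ops (default false), that lets the pass also remove unused kAllReduce (only when constrain_layout() is false), kCollectivePermute, and kAllToAll instructions even though they report side effects.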
PiperOrigin-RevId: 298734102
Change-Id: I50409c476af79cd110ddcc4a81a13a14b764c240
This commit is contained in:
parent
7dcc0a9308
commit
31f7f9cf54
@ -3133,6 +3133,7 @@ cc_library(
|
|||||||
hdrs = ["hlo_dce.h"],
|
hdrs = ["hlo_dce.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":hlo",
|
":hlo",
|
||||||
|
":hlo_casting_utils",
|
||||||
":hlo_pass",
|
":hlo_pass",
|
||||||
"//tensorflow/compiler/xla:status",
|
"//tensorflow/compiler/xla:status",
|
||||||
"//tensorflow/compiler/xla:status_macros",
|
"//tensorflow/compiler/xla:status_macros",
|
||||||
|
@ -21,8 +21,10 @@ limitations under the License.
|
|||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "absl/container/flat_hash_set.h"
|
#include "absl/container/flat_hash_set.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||||
#include "tensorflow/compiler/xla/status.h"
|
#include "tensorflow/compiler/xla/status.h"
|
||||||
@ -35,7 +37,8 @@ limitations under the License.
|
|||||||
|
|
||||||
namespace xla {
|
namespace xla {
|
||||||
|
|
||||||
StatusOr<bool> HloDCE::RunOnComputation(HloComputation* computation) {
|
StatusOr<bool> HloDCE::RunOnComputation(
|
||||||
|
HloComputation* computation, bool remove_cross_partition_collective_ops) {
|
||||||
bool changed = false;
|
bool changed = false;
|
||||||
VLOG(3) << "Before dce:";
|
VLOG(3) << "Before dce:";
|
||||||
XLA_VLOG_LINES(3, computation->ToString());
|
XLA_VLOG_LINES(3, computation->ToString());
|
||||||
@ -47,7 +50,12 @@ StatusOr<bool> HloDCE::RunOnComputation(HloComputation* computation) {
|
|||||||
if (instruction != computation->root_instruction() &&
|
if (instruction != computation->root_instruction() &&
|
||||||
instruction->user_count() == 0 &&
|
instruction->user_count() == 0 &&
|
||||||
computation->IsSafelyRemovable(instruction) &&
|
computation->IsSafelyRemovable(instruction) &&
|
||||||
!instruction->HasSideEffect()) {
|
(!instruction->HasSideEffect() ||
|
||||||
|
(remove_cross_partition_collective_ops &&
|
||||||
|
((instruction->opcode() == HloOpcode::kAllReduce &&
|
||||||
|
!Cast<HloAllReduceInstruction>(instruction)->constrain_layout()) ||
|
||||||
|
instruction->opcode() == HloOpcode::kCollectivePermute ||
|
||||||
|
instruction->opcode() == HloOpcode::kAllToAll)))) {
|
||||||
dead_roots.push_back(instruction);
|
dead_roots.push_back(instruction);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -74,8 +82,9 @@ StatusOr<bool> HloDCE::Run(HloModule* module) {
|
|||||||
|
|
||||||
// Run DCE on each computation.
|
// Run DCE on each computation.
|
||||||
for (auto* computation : module->MakeComputationPostOrder()) {
|
for (auto* computation : module->MakeComputationPostOrder()) {
|
||||||
TF_ASSIGN_OR_RETURN(bool changed_for_computation,
|
TF_ASSIGN_OR_RETURN(
|
||||||
RunOnComputation(computation));
|
bool changed_for_computation,
|
||||||
|
RunOnComputation(computation, remove_cross_partition_collective_ops_));
|
||||||
changed |= changed_for_computation;
|
changed |= changed_for_computation;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -35,15 +35,23 @@ namespace xla {
|
|||||||
// instructions cannot be deleted.
|
// instructions cannot be deleted.
|
||||||
class HloDCE : public HloModulePass {
|
class HloDCE : public HloModulePass {
|
||||||
public:
|
public:
|
||||||
|
HloDCE() : remove_cross_partition_collective_ops_(false) {}
|
||||||
|
explicit HloDCE(bool remove_cross_partition_collective_ops)
|
||||||
|
: remove_cross_partition_collective_ops_(
|
||||||
|
remove_cross_partition_collective_ops) {}
|
||||||
~HloDCE() override {}
|
~HloDCE() override {}
|
||||||
absl::string_view name() const override { return "dce"; }
|
absl::string_view name() const override { return "dce"; }
|
||||||
|
|
||||||
// Run DCE on a computation.
|
// Run DCE on a computation.
|
||||||
static StatusOr<bool> RunOnComputation(HloComputation* computation);
|
StatusOr<bool> RunOnComputation(HloComputation* computation,
|
||||||
|
bool remove_cross_partition_collective_ops);
|
||||||
|
|
||||||
// Run the pass on the given module. Returns whether the module was changed
|
// Run the pass on the given module. Returns whether the module was changed
|
||||||
// (instructions were removed).
|
// (instructions were removed).
|
||||||
StatusOr<bool> Run(HloModule* module) override;
|
StatusOr<bool> Run(HloModule* module) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool remove_cross_partition_collective_ops_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
Loading…
x
Reference in New Issue
Block a user