[XLA] Resubmit DCE change
PiperOrigin-RevId: 298734102 Change-Id: I50409c476af79cd110ddcc4a81a13a14b764c240
This commit is contained in:
parent
7dcc0a9308
commit
31f7f9cf54
@ -3133,6 +3133,7 @@ cc_library(
|
||||
hdrs = ["hlo_dce.h"],
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_casting_utils",
|
||||
":hlo_pass",
|
||||
"//tensorflow/compiler/xla:status",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
|
@ -21,8 +21,10 @@ limitations under the License.
|
||||
#include <vector>
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/status.h"
|
||||
@ -35,7 +37,8 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
StatusOr<bool> HloDCE::RunOnComputation(HloComputation* computation) {
|
||||
StatusOr<bool> HloDCE::RunOnComputation(
|
||||
HloComputation* computation, bool remove_cross_partition_collective_ops) {
|
||||
bool changed = false;
|
||||
VLOG(3) << "Before dce:";
|
||||
XLA_VLOG_LINES(3, computation->ToString());
|
||||
@ -47,7 +50,12 @@ StatusOr<bool> HloDCE::RunOnComputation(HloComputation* computation) {
|
||||
if (instruction != computation->root_instruction() &&
|
||||
instruction->user_count() == 0 &&
|
||||
computation->IsSafelyRemovable(instruction) &&
|
||||
!instruction->HasSideEffect()) {
|
||||
(!instruction->HasSideEffect() ||
|
||||
(remove_cross_partition_collective_ops &&
|
||||
((instruction->opcode() == HloOpcode::kAllReduce &&
|
||||
!Cast<HloAllReduceInstruction>(instruction)->constrain_layout()) ||
|
||||
instruction->opcode() == HloOpcode::kCollectivePermute ||
|
||||
instruction->opcode() == HloOpcode::kAllToAll)))) {
|
||||
dead_roots.push_back(instruction);
|
||||
}
|
||||
}
|
||||
@ -74,8 +82,9 @@ StatusOr<bool> HloDCE::Run(HloModule* module) {
|
||||
|
||||
// Run DCE on each computation.
|
||||
for (auto* computation : module->MakeComputationPostOrder()) {
|
||||
TF_ASSIGN_OR_RETURN(bool changed_for_computation,
|
||||
RunOnComputation(computation));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
bool changed_for_computation,
|
||||
RunOnComputation(computation, remove_cross_partition_collective_ops_));
|
||||
changed |= changed_for_computation;
|
||||
}
|
||||
|
||||
|
@ -35,15 +35,23 @@ namespace xla {
|
||||
// instructions cannot be deleted.
|
||||
class HloDCE : public HloModulePass {
|
||||
public:
|
||||
HloDCE() : remove_cross_partition_collective_ops_(false) {}
|
||||
explicit HloDCE(bool remove_cross_partition_collective_ops)
|
||||
: remove_cross_partition_collective_ops_(
|
||||
remove_cross_partition_collective_ops) {}
|
||||
~HloDCE() override {}
|
||||
absl::string_view name() const override { return "dce"; }
|
||||
|
||||
// Run DCE on a computation.
|
||||
static StatusOr<bool> RunOnComputation(HloComputation* computation);
|
||||
StatusOr<bool> RunOnComputation(HloComputation* computation,
|
||||
bool remove_cross_partition_collective_ops);
|
||||
|
||||
// Run the pass on the given module. Returns whether the module was changed
|
||||
// (instructions were removed).
|
||||
StatusOr<bool> Run(HloModule* module) override;
|
||||
|
||||
private:
|
||||
bool remove_cross_partition_collective_ops_;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
Loading…
x
Reference in New Issue
Block a user