Automated g4 rollback of changelist 200292049
PiperOrigin-RevId: 200309129
This commit is contained in:
parent
db2f9fd007
commit
213810a0d6
@ -2123,7 +2123,6 @@ cc_library(
|
|||||||
":buffer_liveness",
|
":buffer_liveness",
|
||||||
":buffer_value",
|
":buffer_value",
|
||||||
":call_graph",
|
":call_graph",
|
||||||
":copy_insertion",
|
|
||||||
":flatten_call_graph",
|
":flatten_call_graph",
|
||||||
":hlo",
|
":hlo",
|
||||||
":hlo_dce",
|
":hlo_dce",
|
||||||
@ -2131,7 +2130,6 @@ cc_library(
|
|||||||
":hlo_scheduling",
|
":hlo_scheduling",
|
||||||
":logical_buffer",
|
":logical_buffer",
|
||||||
":tuple_points_to_analysis",
|
":tuple_points_to_analysis",
|
||||||
":tuple_simplifier",
|
|
||||||
"//tensorflow/compiler/xla:shape_util",
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
"//tensorflow/compiler/xla:status_macros",
|
"//tensorflow/compiler/xla:status_macros",
|
||||||
"//tensorflow/compiler/xla:statusor",
|
"//tensorflow/compiler/xla:statusor",
|
||||||
@ -2145,7 +2143,6 @@ tf_cc_test(
|
|||||||
name = "hlo_rematerialization_test",
|
name = "hlo_rematerialization_test",
|
||||||
srcs = ["hlo_rematerialization_test.cc"],
|
srcs = ["hlo_rematerialization_test.cc"],
|
||||||
deps = [
|
deps = [
|
||||||
":flatten_call_graph",
|
|
||||||
":hlo",
|
":hlo",
|
||||||
":hlo_matchers",
|
":hlo_matchers",
|
||||||
":hlo_ordering",
|
":hlo_ordering",
|
||||||
|
@ -613,10 +613,7 @@ class CopyRemover {
|
|||||||
VLOG(2) << copy->name() << " is not removable";
|
VLOG(2) << copy->name() << " is not removable";
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (!ShapeUtil::Equal(copy->shape(), copy->operand(0)->shape())) {
|
|
||||||
VLOG(2) << copy->name() << " is not removable (shape mismatch)";
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
const CopyNodes& copy_node = copy_map_.at(copy);
|
const CopyNodes& copy_node = copy_map_.at(copy);
|
||||||
ValueNode* src = copy_node.src;
|
ValueNode* src = copy_node.src;
|
||||||
ValueNode* dest = copy_node.dest;
|
ValueNode* dest = copy_node.dest;
|
||||||
@ -950,6 +947,28 @@ class CopyRemover {
|
|||||||
BufferValueTracker buffer_value_tracker_;
|
BufferValueTracker buffer_value_tracker_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Try to remove as many copies from the module as possible without introducing
|
||||||
|
// live range interference. Copy instructions (identified by their unique id) in
|
||||||
|
// the set copies_to_exclude are not considered for removal.
|
||||||
|
Status RemoveUnnecessaryCopies(
|
||||||
|
const HloOrdering& ordering,
|
||||||
|
const tensorflow::gtl::FlatSet<int>& copies_to_exclude, HloModule* module) {
|
||||||
|
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
|
||||||
|
HloAliasAnalysis::Run(module));
|
||||||
|
CopyRemover copy_remover(*alias_analysis, ordering, module);
|
||||||
|
XLA_VLOG_LINES(3, copy_remover.ToString());
|
||||||
|
|
||||||
|
for (HloComputation* computation : module->computations()) {
|
||||||
|
for (HloInstruction* instruction : computation->instructions()) {
|
||||||
|
if (instruction->opcode() == HloOpcode::kCopy &&
|
||||||
|
!ContainsKey(copies_to_exclude, instruction->unique_id())) {
|
||||||
|
TF_RETURN_IF_ERROR(copy_remover.TryElideCopy(instruction).status());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
// Add copies to address special constraints on the roots of computations not
|
// Add copies to address special constraints on the roots of computations not
|
||||||
// related to live range interference:
|
// related to live range interference:
|
||||||
//
|
//
|
||||||
@ -1046,23 +1065,13 @@ Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module) {
|
|||||||
HloInstruction* instruction = pair.first;
|
HloInstruction* instruction = pair.first;
|
||||||
const ShapeTree<bool>& indices_to_copy = pair.second;
|
const ShapeTree<bool>& indices_to_copy = pair.second;
|
||||||
|
|
||||||
ShapeTree<HloInstruction*> copies_added(indices_to_copy.shape());
|
|
||||||
std::vector<HloInstruction*> users = instruction->users();
|
std::vector<HloInstruction*> users = instruction->users();
|
||||||
TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy,
|
TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy,
|
||||||
instruction->parent()->DeepCopyInstruction(
|
instruction->parent()->DeepCopyInstruction(
|
||||||
instruction, &indices_to_copy, &copies_added));
|
instruction, &indices_to_copy));
|
||||||
for (HloInstruction* user : users) {
|
for (HloInstruction* user : users) {
|
||||||
TF_RETURN_IF_ERROR(instruction->ReplaceUseWith(user, deep_copy));
|
TF_RETURN_IF_ERROR(instruction->ReplaceUseWith(user, deep_copy));
|
||||||
}
|
}
|
||||||
// Special case copies are not eligible for later copy elision passes.
|
|
||||||
indices_to_copy.ForEachElement([&](const ShapeIndex& index, bool has_copy) {
|
|
||||||
if (has_copy) {
|
|
||||||
HloInstruction* copy = *copies_added.mutable_element(index);
|
|
||||||
if (copy != nullptr) {
|
|
||||||
copy->SetCopyElisionAllowed(false);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
if (instruction == instruction->parent()->root_instruction()) {
|
if (instruction == instruction->parent()->root_instruction()) {
|
||||||
instruction->parent()->set_root_instruction(deep_copy);
|
instruction->parent()->set_root_instruction(deep_copy);
|
||||||
}
|
}
|
||||||
@ -1088,31 +1097,6 @@ void MaybeDumpModule(const string& message, const HloModule& module) {
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
Status RemoveUnnecessaryCopies(
|
|
||||||
const HloOrdering& ordering,
|
|
||||||
const tensorflow::gtl::FlatSet<int>& copies_to_exclude, HloModule* module) {
|
|
||||||
MaybeDumpModule("after adding copies to resolve interference", *module);
|
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
|
|
||||||
HloAliasAnalysis::Run(module));
|
|
||||||
CopyRemover copy_remover(*alias_analysis, ordering, module);
|
|
||||||
XLA_VLOG_LINES(3, copy_remover.ToString());
|
|
||||||
|
|
||||||
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
|
|
||||||
for (HloComputation* computation : module->computations()) {
|
|
||||||
for (HloInstruction* instruction : computation->instructions()) {
|
|
||||||
if (instruction->opcode() == HloOpcode::kCopy &&
|
|
||||||
!ContainsKey(copies_to_exclude, instruction->unique_id()) &&
|
|
||||||
instruction->CopyElisionAllowed()) {
|
|
||||||
TF_RETURN_IF_ERROR(copy_remover.TryElideCopy(instruction).status());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
MaybeDumpModule("after removing unnecessary copies", *module);
|
|
||||||
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
StatusOr<bool> CopyInsertion::Run(HloModule* module) {
|
StatusOr<bool> CopyInsertion::Run(HloModule* module) {
|
||||||
// Copy insertion is performed in three steps:
|
// Copy insertion is performed in three steps:
|
||||||
//
|
//
|
||||||
@ -1174,10 +1158,14 @@ StatusOr<bool> CopyInsertion::Run(HloModule* module) {
|
|||||||
|
|
||||||
TF_DCHECK_OK(VerifyNoLiveRangeInterference(module));
|
TF_DCHECK_OK(VerifyNoLiveRangeInterference(module));
|
||||||
|
|
||||||
|
MaybeDumpModule("after adding copies to resolve interference", *module);
|
||||||
|
|
||||||
DependencyHloOrdering ordering(module);
|
DependencyHloOrdering ordering(module);
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
RemoveUnnecessaryCopies(ordering, existing_copies, module));
|
RemoveUnnecessaryCopies(ordering, existing_copies, module));
|
||||||
|
|
||||||
|
MaybeDumpModule("after removing unnecessary copies", *module);
|
||||||
|
|
||||||
TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module));
|
TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module));
|
||||||
|
|
||||||
MaybeDumpModule("after adding special-case copies", *module);
|
MaybeDumpModule("after adding special-case copies", *module);
|
||||||
|
@ -64,13 +64,6 @@ class CopyInsertion : public HloPassInterface {
|
|||||||
static StatusOr<bool> AddCopiesForBufferAssignment(HloModule* module);
|
static StatusOr<bool> AddCopiesForBufferAssignment(HloModule* module);
|
||||||
};
|
};
|
||||||
|
|
||||||
// Try to remove as many copies from the module as possible without introducing
|
|
||||||
// live range interference. Copy instructions (identified by their unique id) in
|
|
||||||
// the set copies_to_exclude are not considered for removal.
|
|
||||||
Status RemoveUnnecessaryCopies(
|
|
||||||
const HloOrdering& ordering,
|
|
||||||
const tensorflow::gtl::FlatSet<int>& copies_to_exclude, HloModule* module);
|
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|
||||||
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COPY_INSERTION_H_
|
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COPY_INSERTION_H_
|
||||||
|
@ -1073,19 +1073,6 @@ class HloInstruction {
|
|||||||
// instruction.
|
// instruction.
|
||||||
void SetupDerivedInstruction(HloInstruction* derived_instruction) const;
|
void SetupDerivedInstruction(HloInstruction* derived_instruction) const;
|
||||||
|
|
||||||
// TODO(b/80249101): Remove these methods once HLO scheduling and copy
|
|
||||||
// insertion are integrated, and we don't need to run a separate pass
|
|
||||||
// of copy elision anymore.
|
|
||||||
bool CopyElisionAllowed() const {
|
|
||||||
CHECK_EQ(HloOpcode::kCopy, opcode_);
|
|
||||||
return copy_elision_allowed_;
|
|
||||||
}
|
|
||||||
|
|
||||||
void SetCopyElisionAllowed(bool value) {
|
|
||||||
CHECK_EQ(HloOpcode::kCopy, opcode_);
|
|
||||||
copy_elision_allowed_ = value;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns the size of the slice in the given dimension for a dynamic
|
// Returns the size of the slice in the given dimension for a dynamic
|
||||||
// slice node.
|
// slice node.
|
||||||
//
|
//
|
||||||
@ -1608,9 +1595,6 @@ class HloInstruction {
|
|||||||
std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers_;
|
std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers_;
|
||||||
std::vector<int64> gather_window_bounds_;
|
std::vector<int64> gather_window_bounds_;
|
||||||
|
|
||||||
// Used to tag kCopy instructions that are eligible for copy elision.
|
|
||||||
bool copy_elision_allowed_ = true;
|
|
||||||
|
|
||||||
// The bit sizes for a reduce-precision operation.
|
// The bit sizes for a reduce-precision operation.
|
||||||
int32 exponent_bits_ = 0;
|
int32 exponent_bits_ = 0;
|
||||||
int32 mantissa_bits_ = 0;
|
int32 mantissa_bits_ = 0;
|
||||||
|
@ -23,7 +23,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/map_util.h"
|
#include "tensorflow/compiler/xla/map_util.h"
|
||||||
#include "tensorflow/compiler/xla/primitive_util.h"
|
#include "tensorflow/compiler/xla/primitive_util.h"
|
||||||
#include "tensorflow/compiler/xla/service/buffer_value.h"
|
#include "tensorflow/compiler/xla/service/buffer_value.h"
|
||||||
#include "tensorflow/compiler/xla/service/copy_insertion.h"
|
|
||||||
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
|
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_dce.h"
|
#include "tensorflow/compiler/xla/service/hlo_dce.h"
|
||||||
@ -1202,8 +1201,7 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
|
|||||||
|
|
||||||
StatusOr<bool> HloRematerialization::Run(
|
StatusOr<bool> HloRematerialization::Run(
|
||||||
HloModule* module, SequentialHloOrdering::HloModuleSequence* sequence,
|
HloModule* module, SequentialHloOrdering::HloModuleSequence* sequence,
|
||||||
int64 memory_limit_bytes, RematerializationSizes* sizes,
|
int64 memory_limit_bytes, RematerializationSizes* sizes) {
|
||||||
bool run_copy_elision) {
|
|
||||||
// The sequence is constructed entirely by this method.
|
// The sequence is constructed entirely by this method.
|
||||||
TF_RET_CHECK(sequence->empty());
|
TF_RET_CHECK(sequence->empty());
|
||||||
|
|
||||||
@ -1238,15 +1236,6 @@ StatusOr<bool> HloRematerialization::Run(
|
|||||||
return size_function_(buffer.shape());
|
return size_function_(buffer.shape());
|
||||||
},
|
},
|
||||||
scheduler_algorithm_));
|
scheduler_algorithm_));
|
||||||
if (run_copy_elision) {
|
|
||||||
// We run a separate pass of copy elision here because the sequential
|
|
||||||
// ordering from the HLO schedule allows for more copies to be eliminated.
|
|
||||||
// TODO(b/80249101): Instead of a separate copy elision pass, use the
|
|
||||||
// ordering from the HLO schedule directly for copy insertion.
|
|
||||||
SequentialHloOrdering ordering(module, *sequence);
|
|
||||||
TF_RETURN_IF_ERROR(RemoveUnnecessaryCopies(ordering, {}, module));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Compute peak memory usage of all computations in the module called in a
|
// Compute peak memory usage of all computations in the module called in a
|
||||||
// sequential context.
|
// sequential context.
|
||||||
call_graph_ = CallGraph::Build(module);
|
call_graph_ = CallGraph::Build(module);
|
||||||
@ -1349,10 +1338,9 @@ StatusOr<bool> HloRematerialization::Run(
|
|||||||
int64 memory_limit_bytes, HloModule* hlo_module,
|
int64 memory_limit_bytes, HloModule* hlo_module,
|
||||||
MemorySchedulerAlgorithm scheduler_algorithm,
|
MemorySchedulerAlgorithm scheduler_algorithm,
|
||||||
SequentialHloOrdering::HloModuleSequence* sequence,
|
SequentialHloOrdering::HloModuleSequence* sequence,
|
||||||
RematerializationSizes* sizes, bool run_copy_elision) {
|
RematerializationSizes* sizes) {
|
||||||
HloRematerialization remat(scheduler_algorithm, size_function);
|
HloRematerialization remat(scheduler_algorithm, size_function);
|
||||||
return remat.Run(hlo_module, sequence, memory_limit_bytes, sizes,
|
return remat.Run(hlo_module, sequence, memory_limit_bytes, sizes);
|
||||||
run_copy_elision);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
@ -57,12 +57,6 @@ class HloRematerialization {
|
|||||||
// sizes: Optional outparam that indicates the peak memory usage of the HLO
|
// sizes: Optional outparam that indicates the peak memory usage of the HLO
|
||||||
// module before/after rematerialization.
|
// module before/after rematerialization.
|
||||||
//
|
//
|
||||||
// run_copy_elision: Enable copy elision. This pass is used to eliminate
|
|
||||||
// copies that were inserted before HLO scheduling.
|
|
||||||
//
|
|
||||||
// TODO(b/80249101): Remove the 'run_copy_elision' parameter when copy
|
|
||||||
// insertion is integrated with HLO scheduling.
|
|
||||||
//
|
|
||||||
// Returns whether any instructions were rematerialized. If memory use is
|
// Returns whether any instructions were rematerialized. If memory use is
|
||||||
// already below the given limit then no instructions are rematerialized and
|
// already below the given limit then no instructions are rematerialized and
|
||||||
// false is returned.
|
// false is returned.
|
||||||
@ -74,7 +68,7 @@ class HloRematerialization {
|
|||||||
const ShapeSizeFunction& size_function, int64 memory_limit_bytes,
|
const ShapeSizeFunction& size_function, int64 memory_limit_bytes,
|
||||||
HloModule* hlo_module, MemorySchedulerAlgorithm scheduler_algorithm,
|
HloModule* hlo_module, MemorySchedulerAlgorithm scheduler_algorithm,
|
||||||
SequentialHloOrdering::HloModuleSequence* sequence,
|
SequentialHloOrdering::HloModuleSequence* sequence,
|
||||||
RematerializationSizes* sizes, bool run_copy_elision = true);
|
RematerializationSizes* sizes = nullptr);
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
HloRematerialization(MemorySchedulerAlgorithm scheduler_algorithm,
|
HloRematerialization(MemorySchedulerAlgorithm scheduler_algorithm,
|
||||||
@ -89,8 +83,7 @@ class HloRematerialization {
|
|||||||
// contains the memory-minimizing order in which to emit the HLO instructions.
|
// contains the memory-minimizing order in which to emit the HLO instructions.
|
||||||
StatusOr<bool> Run(HloModule* module,
|
StatusOr<bool> Run(HloModule* module,
|
||||||
SequentialHloOrdering::HloModuleSequence* sequence,
|
SequentialHloOrdering::HloModuleSequence* sequence,
|
||||||
int64 memory_limit, RematerializationSizes* sizes,
|
int64 memory_limit, RematerializationSizes* sizes);
|
||||||
bool run_copy_elision);
|
|
||||||
|
|
||||||
// Rematerializes instructions within the given computation. 'order' is the
|
// Rematerializes instructions within the given computation. 'order' is the
|
||||||
// order in which the computation's instructions will be emitted in the
|
// order in which the computation's instructions will be emitted in the
|
||||||
|
@ -147,7 +147,7 @@ class HloRematerializationTest : public HloTestBase {
|
|||||||
TF_EXPECT_OK(verifier().Run(module).status());
|
TF_EXPECT_OK(verifier().Run(module).status());
|
||||||
return HloRematerialization::RematerializeAndSchedule(
|
return HloRematerialization::RematerializeAndSchedule(
|
||||||
ByteSizeOf, memory_limit_bytes, module, DefaultMemoryScheduler,
|
ByteSizeOf, memory_limit_bytes, module, DefaultMemoryScheduler,
|
||||||
sequence, /*sizes=*/nullptr, /*run_copy_elision=*/false);
|
sequence);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Various shapes used in the canned computations.
|
// Various shapes used in the canned computations.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user