Automated g4 rollback of changelist 200292049

PiperOrigin-RevId: 200309129
This commit is contained in:
A. Unique TensorFlower 2018-06-12 17:17:22 -07:00 committed by TensorFlower Gardener
parent db2f9fd007
commit 213810a0d6
7 changed files with 34 additions and 91 deletions

View File

@ -2123,7 +2123,6 @@ cc_library(
":buffer_liveness",
":buffer_value",
":call_graph",
":copy_insertion",
":flatten_call_graph",
":hlo",
":hlo_dce",
@ -2131,7 +2130,6 @@ cc_library(
":hlo_scheduling",
":logical_buffer",
":tuple_points_to_analysis",
":tuple_simplifier",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@ -2145,7 +2143,6 @@ tf_cc_test(
name = "hlo_rematerialization_test",
srcs = ["hlo_rematerialization_test.cc"],
deps = [
":flatten_call_graph",
":hlo",
":hlo_matchers",
":hlo_ordering",

View File

@ -613,10 +613,7 @@ class CopyRemover {
VLOG(2) << copy->name() << " is not removable";
return false;
}
if (!ShapeUtil::Equal(copy->shape(), copy->operand(0)->shape())) {
VLOG(2) << copy->name() << " is not removable (shape mismatch)";
return false;
}
const CopyNodes& copy_node = copy_map_.at(copy);
ValueNode* src = copy_node.src;
ValueNode* dest = copy_node.dest;
@ -950,6 +947,28 @@ class CopyRemover {
BufferValueTracker buffer_value_tracker_;
};
// Try to remove as many copies from the module as possible without introducing
// live range interference. Copy instructions (identified by their unique id) in
// the set copies_to_exclude are not considered for removal.
Status RemoveUnnecessaryCopies(
const HloOrdering& ordering,
const tensorflow::gtl::FlatSet<int>& copies_to_exclude, HloModule* module) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
HloAliasAnalysis::Run(module));
CopyRemover copy_remover(*alias_analysis, ordering, module);
XLA_VLOG_LINES(3, copy_remover.ToString());
for (HloComputation* computation : module->computations()) {
for (HloInstruction* instruction : computation->instructions()) {
if (instruction->opcode() == HloOpcode::kCopy &&
!ContainsKey(copies_to_exclude, instruction->unique_id())) {
TF_RETURN_IF_ERROR(copy_remover.TryElideCopy(instruction).status());
}
}
}
return Status::OK();
}
// Add copies to address special constraints on the roots of computations not
// related to live range interference:
//
@ -1046,23 +1065,13 @@ Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module) {
HloInstruction* instruction = pair.first;
const ShapeTree<bool>& indices_to_copy = pair.second;
ShapeTree<HloInstruction*> copies_added(indices_to_copy.shape());
std::vector<HloInstruction*> users = instruction->users();
TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy,
instruction->parent()->DeepCopyInstruction(
instruction, &indices_to_copy, &copies_added));
instruction, &indices_to_copy));
for (HloInstruction* user : users) {
TF_RETURN_IF_ERROR(instruction->ReplaceUseWith(user, deep_copy));
}
// Special case copies are not eligible for later copy elision passes.
indices_to_copy.ForEachElement([&](const ShapeIndex& index, bool has_copy) {
if (has_copy) {
HloInstruction* copy = *copies_added.mutable_element(index);
if (copy != nullptr) {
copy->SetCopyElisionAllowed(false);
}
}
});
if (instruction == instruction->parent()->root_instruction()) {
instruction->parent()->set_root_instruction(deep_copy);
}
@ -1088,31 +1097,6 @@ void MaybeDumpModule(const string& message, const HloModule& module) {
} // namespace
Status RemoveUnnecessaryCopies(
const HloOrdering& ordering,
const tensorflow::gtl::FlatSet<int>& copies_to_exclude, HloModule* module) {
MaybeDumpModule("after adding copies to resolve interference", *module);
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
HloAliasAnalysis::Run(module));
CopyRemover copy_remover(*alias_analysis, ordering, module);
XLA_VLOG_LINES(3, copy_remover.ToString());
std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
for (HloComputation* computation : module->computations()) {
for (HloInstruction* instruction : computation->instructions()) {
if (instruction->opcode() == HloOpcode::kCopy &&
!ContainsKey(copies_to_exclude, instruction->unique_id()) &&
instruction->CopyElisionAllowed()) {
TF_RETURN_IF_ERROR(copy_remover.TryElideCopy(instruction).status());
}
}
}
MaybeDumpModule("after removing unnecessary copies", *module);
return Status::OK();
}
StatusOr<bool> CopyInsertion::Run(HloModule* module) {
// Copy insertion is performed in three steps:
//
@ -1174,10 +1158,14 @@ StatusOr<bool> CopyInsertion::Run(HloModule* module) {
TF_DCHECK_OK(VerifyNoLiveRangeInterference(module));
MaybeDumpModule("after adding copies to resolve interference", *module);
DependencyHloOrdering ordering(module);
TF_RETURN_IF_ERROR(
RemoveUnnecessaryCopies(ordering, existing_copies, module));
MaybeDumpModule("after removing unnecessary copies", *module);
TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module));
MaybeDumpModule("after adding special-case copies", *module);

View File

@ -64,13 +64,6 @@ class CopyInsertion : public HloPassInterface {
static StatusOr<bool> AddCopiesForBufferAssignment(HloModule* module);
};
// Try to remove as many copies from the module as possible without introducing
// live range interference. Copy instructions (identified by their unique id) in
// the set copies_to_exclude are not considered for removal.
Status RemoveUnnecessaryCopies(
const HloOrdering& ordering,
const tensorflow::gtl::FlatSet<int>& copies_to_exclude, HloModule* module);
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_COPY_INSERTION_H_

View File

@ -1073,19 +1073,6 @@ class HloInstruction {
// instruction.
void SetupDerivedInstruction(HloInstruction* derived_instruction) const;
// TODO(b/80249101): Remove these methods once HLO scheduling and copy
// insertion are integrated, and we don't need to run a separate pass
// of copy elision anymore.
bool CopyElisionAllowed() const {
CHECK_EQ(HloOpcode::kCopy, opcode_);
return copy_elision_allowed_;
}
void SetCopyElisionAllowed(bool value) {
CHECK_EQ(HloOpcode::kCopy, opcode_);
copy_elision_allowed_ = value;
}
// Returns the size of the slice in the given dimension for a dynamic
// slice node.
//
@ -1608,9 +1595,6 @@ class HloInstruction {
std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers_;
std::vector<int64> gather_window_bounds_;
// Used to tag kCopy instructions that are eligible for copy elision.
bool copy_elision_allowed_ = true;
// The bit sizes for a reduce-precision operation.
int32 exponent_bits_ = 0;
int32 mantissa_bits_ = 0;

View File

@ -23,7 +23,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
@ -1202,8 +1201,7 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
StatusOr<bool> HloRematerialization::Run(
HloModule* module, SequentialHloOrdering::HloModuleSequence* sequence,
int64 memory_limit_bytes, RematerializationSizes* sizes,
bool run_copy_elision) {
int64 memory_limit_bytes, RematerializationSizes* sizes) {
// The sequence is constructed entirely by this method.
TF_RET_CHECK(sequence->empty());
@ -1238,15 +1236,6 @@ StatusOr<bool> HloRematerialization::Run(
return size_function_(buffer.shape());
},
scheduler_algorithm_));
if (run_copy_elision) {
// We run a separate pass of copy elision here because the sequential
// ordering from the HLO schedule allows for more copies to be eliminated.
// TODO(b/80249101): Instead of a separate copy elision pass, use the
// ordering from the HLO schedule directly for copy insertion.
SequentialHloOrdering ordering(module, *sequence);
TF_RETURN_IF_ERROR(RemoveUnnecessaryCopies(ordering, {}, module));
}
// Compute peak memory usage of all computations in the module called in a
// sequential context.
call_graph_ = CallGraph::Build(module);
@ -1349,10 +1338,9 @@ StatusOr<bool> HloRematerialization::Run(
int64 memory_limit_bytes, HloModule* hlo_module,
MemorySchedulerAlgorithm scheduler_algorithm,
SequentialHloOrdering::HloModuleSequence* sequence,
RematerializationSizes* sizes, bool run_copy_elision) {
RematerializationSizes* sizes) {
HloRematerialization remat(scheduler_algorithm, size_function);
return remat.Run(hlo_module, sequence, memory_limit_bytes, sizes,
run_copy_elision);
return remat.Run(hlo_module, sequence, memory_limit_bytes, sizes);
}
} // namespace xla

View File

@ -57,12 +57,6 @@ class HloRematerialization {
// sizes: Optional outparam that indicates the peak memory usage of the HLO
// module before/after rematerialization.
//
// run_copy_elision: Enable copy elision. This pass is used to eliminate
// copies that were inserted before HLO scheduling.
//
// TODO(b/80249101): Remove the 'run_copy_elision' parameter when copy
// insertion is integrated with HLO scheduling.
//
// Returns whether any instructions were rematerialized. If memory use is
// already below the given limit then no instructions are rematerialized and
// false is returned.
@ -74,7 +68,7 @@ class HloRematerialization {
const ShapeSizeFunction& size_function, int64 memory_limit_bytes,
HloModule* hlo_module, MemorySchedulerAlgorithm scheduler_algorithm,
SequentialHloOrdering::HloModuleSequence* sequence,
RematerializationSizes* sizes, bool run_copy_elision = true);
RematerializationSizes* sizes = nullptr);
protected:
HloRematerialization(MemorySchedulerAlgorithm scheduler_algorithm,
@ -89,8 +83,7 @@ class HloRematerialization {
// contains the memory-minimizing order in which to emit the HLO instructions.
StatusOr<bool> Run(HloModule* module,
SequentialHloOrdering::HloModuleSequence* sequence,
int64 memory_limit, RematerializationSizes* sizes,
bool run_copy_elision);
int64 memory_limit, RematerializationSizes* sizes);
// Rematerializes instructions within the given computation. 'order' is the
// order in which the computation's instructions will be emitted in the

View File

@ -147,7 +147,7 @@ class HloRematerializationTest : public HloTestBase {
TF_EXPECT_OK(verifier().Run(module).status());
return HloRematerialization::RematerializeAndSchedule(
ByteSizeOf, memory_limit_bytes, module, DefaultMemoryScheduler,
sequence, /*sizes=*/nullptr, /*run_copy_elision=*/false);
sequence);
}
// Various shapes used in the canned computations.