Teach MemorySpaceAssignment to find buffers to keep in alternate memory across program scopes.

PiperOrigin-RevId: 302965410
Change-Id: Ie07d1c77add83740d56078b42b966ddd5a6c81d3
This commit is contained in:
A. Unique TensorFlower 2020-03-25 14:07:26 -07:00 committed by TensorFlower Gardener
parent 5943793102
commit 35a3591b3e
6 changed files with 397 additions and 0 deletions

View File

@ -358,6 +358,11 @@ message DynamicParameterBindingProto {
repeated Binding entries = 1;
}
// Identifies a (sub-)buffer of an entry-computation parameter that should be
// prefetched into alternate (fast) memory across program executions.
message CrossProgramPrefetch {
// Parameter number of the entry computation argument to prefetch.
int64 parameter = 1;
// Shape index selecting the tuple element within that parameter.
repeated int64 index = 2;
}
// Serialization of HloModule.
message HloModuleProto {
string name = 1;
@ -381,6 +386,8 @@ message HloModuleProto {
HloInputOutputAliasProto input_output_alias = 8;
DynamicParameterBindingProto dynamic_parameter_binding = 9;
repeated CrossProgramPrefetch cross_program_prefetches = 10;
}
// Serialization of LogicalBuffer.

View File

@ -263,6 +263,13 @@ HloModuleProto HloModule::ToProto() const {
*proto.mutable_input_output_alias() = input_output_alias_config().ToProto();
*proto.mutable_dynamic_parameter_binding() =
dynamic_parameter_binding().ToProto();
for (auto [parameter, indices] : CrossProgramPrefetches()) {
auto* prefetch = proto.mutable_cross_program_prefetches()->Add();
prefetch->set_parameter(parameter);
for (auto index : indices) {
prefetch->add_index(index);
}
}
return proto;
}
@ -389,6 +396,12 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule)));
}
for (auto prefetch : proto.cross_program_prefetches()) {
module->AddCrossProgramPrefetch(
prefetch.parameter(),
ShapeIndex(prefetch.index().begin(), prefetch.index().end()));
}
return std::move(module);
}
@ -669,6 +682,9 @@ std::unique_ptr<HloModule> HloModule::Clone(const HloModuleConfig& config,
}
TF_CHECK_OK(module->set_schedule(std::move(clone_schedule)));
}
for (auto [parameter, index] : CrossProgramPrefetches()) {
module->AddCrossProgramPrefetch(parameter, index);
}
return module;
}

View File

@ -345,6 +345,17 @@ class HloModule {
spmd_output_sharding_ = sharding;
}
// Add a program argument to be prefetched across programs.
// `parameter` is the entry-computation parameter number and `index` is the
// shape index of the tuple element within that parameter.
void AddCrossProgramPrefetch(int64 parameter, const ShapeIndex& index) {
cross_program_prefetches_.emplace_back(parameter, index);
}
// Returns the list of program arguments to be prefetched across programs,
// as (parameter number, shape index) pairs. The span aliases internal
// storage and is invalidated by AddCrossProgramPrefetch.
// NOTE: the previous top-level `const` on the by-value return type was
// meaningless (clang-tidy readability-const-return-type) and is dropped.
absl::Span<const std::pair<int64, ShapeIndex>> CrossProgramPrefetches()
    const {
  return cross_program_prefetches_;
}
private:
HloComputation* AddComputationInternal(
std::unique_ptr<HloComputation> computation, bool is_entry,
@ -392,6 +403,9 @@ class HloModule {
// The HLO sharding of the entry computation's output (root) for
// SPMD-partitioned programs.
absl::optional<HloSharding> spmd_output_sharding_;
// Arguments to be prefetched across programs.
std::vector<std::pair<int64, ShapeIndex>> cross_program_prefetches_;
};
} // namespace xla

View File

@ -355,6 +355,17 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
continue;
}
HloInstruction* inst = interval.buffer->instruction();
HloModule* module = inst->GetModule();
// Don't intra-program prefetch a cross program prefetch
if (inst->opcode() == HloOpcode::kParameter &&
absl::c_count(module->CrossProgramPrefetches(),
std::make_pair(inst->parameter_number(),
interval.buffer->index())) > 0) {
continue;
}
auto colocated_intervals = GetSortedColocatedIntervals(interval);
if (AreIntervalsReservedInAlternateMemory(colocated_intervals)) {
@ -561,6 +572,52 @@ AlternateMemoryBestFitHeap::GetLiveAllocationAt(
return nullptr;
}
// Reserves the candidate buffer in alternate memory for the whole program
// and emits an async copy from default memory that must complete before the
// buffer's earliest use. Records the (parameter, index) pair on `module` so
// it can be honored across program executions.
void AlternateMemoryBestFitHeap::AllocateCrossProgramPrefetchBuffer(
HloModule* module, absl::optional<BufferInterval> prefetch_candidate) {
// No candidate was found; nothing to do.
if (!prefetch_candidate) {
return;
}
ChunkCandidate chunk_candidate = FindChunkCandidate(*prefetch_candidate);
// Only accept the preferred placement: offset 0 and fitting within the
// available heap. Otherwise, skip the cross-program prefetch entirely.
if (chunk_candidate.chunk.offset != 0 ||
chunk_candidate.heap_size > available_heap_size()) {
LOG(WARNING)
<< "Could not allocate preferred memory for cross program prefetch";
return;
}
AddToPendingChunks(*prefetch_candidate, chunk_candidate);
const HloValue* buffer = prefetch_candidate->buffer;
int64 parameter = buffer->instruction()->parameter_number();
// Register the prefetch on the module so serialization and later passes
// (e.g. Finish()) know this parameter element is prefetched.
module->AddCrossProgramPrefetch(parameter, buffer->index());
// Start a fresh allocation sequence for this value: first a default-memory
// allocation spanning the whole interval, then the async copy into
// alternate memory.
allocation_sequence_list_->push_back({buffer, {}});
MemorySpaceAssignment::AllocationSequence& allocations =
allocation_sequence_list_->back().sequence;
allocations.push_back(absl::make_unique<MemorySpaceAssignment::Allocation>(
buffer->defining_position(), MemorySpace::kDefault, kDummyChunk,
prefetch_candidate->start, prefetch_candidate->end));
// Find the earliest use in schedule order; the prefetch copy must be
// complete by that time.
const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
auto uses = buffer->uses();
auto first_use =
absl::c_min_element(uses, [&](const HloUse& lhs, const HloUse& rhs) {
return instruction_schedule.at(lhs.instruction) <
instruction_schedule.at(rhs.instruction);
});
int64 latest_prefetch_time = instruction_schedule.at(first_use->instruction);
AddAsyncCopy(*allocations.back(), MemorySpace::kAlternate,
chunk_candidate.chunk, prefetch_candidate->start,
prefetch_candidate->end, latest_prefetch_time, &allocations);
// All uses of the buffer are served by the alternate-memory copy.
absl::c_for_each(uses, [&](auto& use) { allocations.back()->AddUse(use); });
// NOTE(review): pending bookkeeping is cleared rather than committed here —
// presumably the chunk reservation is intentionally kept only in the heap
// result; confirm against CommitPendingChunks semantics.
pending_chunks_.clear();
pending_async_copies_.clear();
}
void AlternateMemoryBestFitHeap::AddInputAndOutputRequiredAssignments() {
// Go through the parameters and outputs and pin them to the corresponding
// memory by adding a required assignment.
@ -1207,6 +1264,90 @@ MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare(
};
}
namespace {
// Conservatively decides whether `inst` is consumed as an activation rather
// than as a weight-like operand. Returns false only when every user is a
// convolution/dot reading it as the rhs, a gather reading it as the operand
// (not indices), or a fusion whose corresponding fused parameter also passes
// this check. Any other consumer counts as an activation use.
bool LooksLikeAnActivation(const HloInstruction* inst) {
  for (HloInstruction* user : inst->users()) {
    const HloOpcode opcode = user->opcode();
    if (opcode == HloOpcode::kConvolution || opcode == HloOpcode::kDot) {
      // Operand 0 of a conv/dot is the activation side.
      if (user->operand(0) == inst) {
        return true;
      }
      continue;
    }
    if (opcode == HloOpcode::kGather) {
      // Operand 1 of a gather holds the indices.
      if (user->operand(1) == inst) {
        return true;
      }
      continue;
    }
    if (opcode == HloOpcode::kFusion) {
      // Follow the value into the fusion body through the matching
      // fused parameter(s).
      for (int operand_index = 0; operand_index < user->operand_count();
           ++operand_index) {
        if (user->operand(operand_index) == inst &&
            LooksLikeAnActivation(user->fused_parameter(operand_index))) {
          return true;
        }
      }
      continue;
    }
    // Unknown consumer: treat as an activation use.
    return true;
  }
  return false;
}
// Returns true if `value` is eligible for cross-program prefetch: an array
// at a depth-1 tuple index of an entry-computation parameter, small enough
// for alternate memory, with at least one use, and read exclusively through
// get-tuple-elements that do not look like activations.
bool IsCrossProgramPrefetchCandidate(
    const HloValue& value, const MemorySpaceAssignment::Options& options) {
  const HloInstruction* defining_instruction = value.instruction();
  if (defining_instruction->parent() !=
      defining_instruction->GetModule()->entry_computation()) {
    return false;
  }
  if (defining_instruction->opcode() != HloOpcode::kParameter) {
    return false;
  }
  if (value.index().size() != 1 || !value.shape().IsArray()) {
    return false;
  }
  if (value.uses().empty()) {
    return false;
  }
  if (options.size_fn(value) > options.max_size_in_bytes) {
    return false;
  }
  return absl::c_all_of(value.uses(), [&](const HloUse& use) {
    const HloInstruction* gte = use.instruction->operand(use.operand_number);
    return gte->opcode() == HloOpcode::kGetTupleElement &&
           !LooksLikeAnActivation(gte);
  });
}
// Collects all cross-program-prefetch-eligible values and returns the best
// one (or nullopt when none exist). Candidates span the whole program
// (start 0 to schedule end).
absl::optional<MemorySpaceAssignment::BufferInterval>
FindCrossProgramPrefetchCandidate(
    const HloAliasAnalysis& alias_analysis, const HloLiveRange& hlo_live_range,
    const MemorySpaceAssignment::Options& options) {
  std::vector<MemorySpaceAssignment::BufferInterval> candidates;
  for (HloValue* value : alias_analysis.dataflow_analysis().values()) {
    if (IsCrossProgramPrefetchCandidate(*value, options)) {
      MemorySpaceAssignment::BufferInterval interval;
      interval.buffer = value;
      interval.size = options.size_fn(*value);
      interval.start = 0;
      interval.end = hlo_live_range.schedule_end_time();
      interval.need_allocation = true;
      candidates.emplace_back(interval);
    }
  }
  // The buffer_interval_compare ought to do a good job picking the most
  // appropriate buffer to cross program prefetch, but empirically, it makes
  // worse choices than just picking the largest buffer.
  // TODO(b/152421603): Investigate.
  auto size_compare = [](const auto& x, const auto& y) {
    return x.size < y.size;
  };
  // BUG FIX: the conditional expression converts `size_compare` to the
  // comparator's type, so its result is a prvalue; binding a non-const
  // lvalue reference (`auto&`) to it is ill-formed. A const reference binds
  // fine and extends the temporary's lifetime to the end of the function.
  const auto& compare = options.default_cross_program_prefetch_heuristic &&
                                options.buffer_interval_compare
                            ? *options.buffer_interval_compare
                            : size_compare;
  auto best_candidate = absl::c_max_element(candidates, compare);
  if (best_candidate == candidates.end()) {
    // No eligible buffer was found.
    return absl::nullopt;
  }
  return *best_candidate;
}
} // namespace
/*static*/ StatusOr<std::unique_ptr<PresetAssignments>>
MemorySpaceAssignment::Run(HloModule* module,
const HloLiveRange& hlo_live_range,
@ -1222,6 +1363,13 @@ MemorySpaceAssignment::Run(HloModule* module,
&memory_space_assignment.allocation_sequence_list_, options,
alias_analysis, hlo_live_range);
if (options.enable_cross_program_prefetch) {
absl::optional<BufferInterval> prefetch_candiate =
FindCrossProgramPrefetchCandidate(alias_analysis, hlo_live_range,
options);
algorithm->AllocateCrossProgramPrefetchBuffer(module, prefetch_candiate);
}
HeapSimulator::Options heap_simulator_options;
heap_simulator_options.may_reuse_operand_buffers = false;
TF_RETURN_IF_ERROR(HeapSimulator::Run(std::move(algorithm), *module,

View File

@ -319,6 +319,14 @@ class MemorySpaceAssignment {
// If true, verifies the memory space assignment against overlapping
// buffers.
bool verify = false;
// Enable prefetching buffers into preferred memory across program
// boundaries.
bool enable_cross_program_prefetch = true;
// If true, use buffer_interval_compare to determine which buffers to
// prefetch across program boundaries.
bool default_cross_program_prefetch_heuristic = false;
};
// This class represents an allocation that might either be in the default or
@ -623,6 +631,12 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
}
}
// Allocates a buffer in preferred memory with whole program lifetime and
// enables prefetching prefetch_candidate from default memory across program
// boundaries.
void AllocateCrossProgramPrefetchBuffer(
HloModule* module, absl::optional<BufferInterval> prefetch_candidate);
HeapSimulator::Result Finish() override;
private:

View File

@ -3028,5 +3028,203 @@ TEST_F(AsynchronousCopyOrderingTest, Simple) {
ordering.AddCopy({5, 14, alternate_mem_space});
}
// Verifies that the weight-like tuple element of the entry parameter is
// selected for cross-program prefetch: the rhs is only read through a
// get-tuple-element and feeds operand 1 of the dot, so it does not look
// like an activation.
TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchTest) {
HloComputation::Builder builder(TestName());
// lhs: [batch, feature] activations; rhs: [feature, output] weights.
constexpr int kBatch = 8;
constexpr int kFeature = 8;
constexpr int kOutput = 2;
auto lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature});
auto rhs_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput});
auto result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput});
auto tuple_shape = ShapeUtil::MakeTupleShape({lhs_shape, rhs_shape});
HloInstruction* param = builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape, "p0"));
auto lhs = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(lhs_shape, param, 0));
auto rhs = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(rhs_shape, param, 1));
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
auto dot = builder.AddInstruction(HloInstruction::CreateDot(
result_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
auto module = CreateNewVerifiedModule();
HloComputation* computation = module->AddEntryComputation(builder.Build());
HloSchedule schedule(module.get());
schedule.set_sequence(computation, {param, lhs, rhs, dot});
TF_CHECK_OK(module->set_schedule(schedule));
AssignMemorySpace(module.get());
auto cross_program_prefetches = module->CrossProgramPrefetches();
// Exactly one prefetch expected: parameter 0, tuple index {1} (the rhs).
// The lhs is the dot's operand 0 and therefore looks like an activation.
EXPECT_EQ(cross_program_prefetches.size(), 1);
if (!cross_program_prefetches.empty()) {
EXPECT_EQ(cross_program_prefetches[0].first, 0);
EXPECT_EQ(cross_program_prefetches[0].second, ShapeIndex({1}));
}
}
// Verifies that buffers inside a nested tuple parameter are NOT selected:
// the leaf arrays live at shape index depth 2, and cross-program prefetch
// only considers values at a depth-1 tuple index.
TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchNestedTupleTest) {
HloComputation::Builder builder(TestName());
constexpr int kBatch = 8;
constexpr int kFeature = 8;
constexpr int kOutput = 2;
auto lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature});
auto rhs_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput});
auto result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput});
auto tuple_shape = ShapeUtil::MakeTupleShape({lhs_shape, rhs_shape});
// Wrap the (lhs, rhs) tuple in another tuple so the arrays sit one level
// deeper than the prefetch heuristic accepts.
auto tuple_tuple_shape = ShapeUtil::MakeTupleShape({tuple_shape});
HloInstruction* param = builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_tuple_shape, "p0"));
auto gte = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(tuple_shape, param, 0));
auto lhs = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(lhs_shape, gte, 0));
auto rhs = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(rhs_shape, gte, 1));
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
auto dot = builder.AddInstruction(HloInstruction::CreateDot(
result_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
auto module = CreateNewVerifiedModule();
HloComputation* computation = module->AddEntryComputation(builder.Build());
HloSchedule schedule(module.get());
schedule.set_sequence(computation, {param, gte, lhs, rhs, dot});
TF_CHECK_OK(module->set_schedule(schedule));
AssignMemorySpace(module.get());
auto cross_program_prefetches = module->CrossProgramPrefetches();
// Nothing should be prefetched for nested tuple parameters.
EXPECT_EQ(cross_program_prefetches.size(), 0);
}
// A parameter with no uses must not be selected for cross-program prefetch
// (candidates are required to have at least one use).
TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchUnusedParamTest) {
  constexpr int kFeature = 8;
  constexpr int kOutput = 2;
  const Shape param_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput});
  HloComputation::Builder entry_builder(TestName());
  HloInstruction* unused_param = entry_builder.AddInstruction(
      HloInstruction::CreateParameter(0, param_shape, "p0"));
  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(entry_builder.Build());
  HloSchedule schedule(module.get());
  schedule.set_sequence(entry, {unused_param});
  TF_CHECK_OK(module->set_schedule(schedule));
  AssignMemorySpace(module.get());
  // No uses -> no prefetch recorded on the module.
  EXPECT_EQ(module->CrossProgramPrefetches().size(), 0);
}
// Same graph as CrossProgramPrefetchTest but with kOutput = 8 so the
// operands are larger. The buffers no longer fit in alternate memory
// (presumably exceeding the test's max_size_in_bytes or available heap —
// either check rejects the candidate), so nothing is prefetched.
TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchTooBigTest) {
HloComputation::Builder builder(TestName());
constexpr int kBatch = 8;
constexpr int kFeature = 8;
constexpr int kOutput = 8;
auto lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature});
auto rhs_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput});
auto result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput});
auto tuple_shape = ShapeUtil::MakeTupleShape({lhs_shape, rhs_shape});
HloInstruction* param = builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape, "p0"));
auto lhs = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(lhs_shape, param, 0));
auto rhs = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(rhs_shape, param, 1));
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
auto dot = builder.AddInstruction(HloInstruction::CreateDot(
result_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
auto module = CreateNewVerifiedModule();
HloComputation* computation = module->AddEntryComputation(builder.Build());
HloSchedule schedule(module.get());
schedule.set_sequence(computation, {param, lhs, rhs, dot});
TF_CHECK_OK(module->set_schedule(schedule));
AssignMemorySpace(module.get());
auto cross_program_prefetches = module->CrossProgramPrefetches();
EXPECT_EQ(cross_program_prefetches.size(), 0);
}
// The dot's operands are constants fed through a tuple into a custom
// fusion: the entry computation has no parameters, and cross-program
// prefetch candidates must be entry-computation parameters, so nothing is
// prefetched.
TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchFusionTest) {
HloComputation::Builder builder(TestName());
constexpr int kBatch = 2;
constexpr int kFeature = 2;
constexpr int kOutput = 2;
auto lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature});
auto rhs_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput});
auto result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput});
auto tuple_shape = ShapeUtil::MakeTupleShape({lhs_shape, rhs_shape});
auto module = CreateNewVerifiedModule();
// Build the fused computation: dot(gte(p0, 0), gte(p0, 1)).
HloComputation::Builder fusion_builder("fusion");
{
HloInstruction* param = fusion_builder.AddInstruction(
HloInstruction::CreateParameter(0, tuple_shape, "p0"));
auto lhs = fusion_builder.AddInstruction(
HloInstruction::CreateGetTupleElement(lhs_shape, param, 0));
auto rhs = fusion_builder.AddInstruction(
HloInstruction::CreateGetTupleElement(rhs_shape, param, 1));
DotDimensionNumbers dot_dnums;
dot_dnums.add_lhs_contracting_dimensions(1);
dot_dnums.add_rhs_contracting_dimensions(0);
auto dot = fusion_builder.AddInstruction(HloInstruction::CreateDot(
result_shape, lhs, rhs, dot_dnums, DefaultPrecisionConfig(2)));
// Silence the unused-variable warning; the dot stays in the computation.
(void)dot;
}
HloComputation* fusion_computation =
module->AddEmbeddedComputation(fusion_builder.Build());
auto activations = builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR2<float>({{0.0, 1.0}, {2.0, 3.0}})));
auto weights = builder.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::CreateR2<float>({{0.0, 1.0}, {2.0, 3.0}})));
HloInstruction* tuple = builder.AddInstruction(
HloInstruction::CreateTuple({activations, weights}));
HloInstruction* fusion = builder.AddInstruction(HloInstruction::CreateFusion(
result_shape, HloInstruction::FusionKind::kCustom, {tuple},
fusion_computation));
HloComputation* computation = module->AddEntryComputation(builder.Build());
HloSchedule schedule(module.get());
schedule.set_sequence(computation, {activations, weights, tuple, fusion});
TF_CHECK_OK(module->set_schedule(schedule));
AssignMemorySpace(module.get());
auto cross_program_prefetches = module->CrossProgramPrefetches();
EXPECT_EQ(cross_program_prefetches.size(), 0);
}
} // namespace
} // namespace xla