Teach MemorySpaceAssignment to find buffers to keep in alternate memory across program scopes.
PiperOrigin-RevId: 302965410 Change-Id: Ie07d1c77add83740d56078b42b966ddd5a6c81d3
This commit is contained in:
parent
5943793102
commit
35a3591b3e
@ -358,6 +358,11 @@ message DynamicParameterBindingProto {
|
||||
repeated Binding entries = 1;
|
||||
}
|
||||
|
||||
// One buffer to be prefetched across programs; entries of this type are stored
// in HloModuleProto.cross_program_prefetches.
message CrossProgramPrefetch {
|
||||
// Parameter number of the program argument holding the buffer.
int64 parameter = 1;
|
||||
// Elements of the shape index locating the buffer within that parameter.
repeated int64 index = 2;
|
||||
}
|
||||
|
||||
// Serialization of HloModule.
|
||||
message HloModuleProto {
|
||||
string name = 1;
|
||||
@ -381,6 +386,8 @@ message HloModuleProto {
|
||||
HloInputOutputAliasProto input_output_alias = 8;
|
||||
|
||||
DynamicParameterBindingProto dynamic_parameter_binding = 9;
|
||||
|
||||
repeated CrossProgramPrefetch cross_program_prefetches = 10;
|
||||
}
|
||||
|
||||
// Serialization of LogicalBuffer.
|
||||
|
@ -263,6 +263,13 @@ HloModuleProto HloModule::ToProto() const {
|
||||
*proto.mutable_input_output_alias() = input_output_alias_config().ToProto();
|
||||
*proto.mutable_dynamic_parameter_binding() =
|
||||
dynamic_parameter_binding().ToProto();
|
||||
for (auto [parameter, indices] : CrossProgramPrefetches()) {
|
||||
auto* prefetch = proto.mutable_cross_program_prefetches()->Add();
|
||||
prefetch->set_parameter(parameter);
|
||||
for (auto index : indices) {
|
||||
prefetch->add_index(index);
|
||||
}
|
||||
}
|
||||
return proto;
|
||||
}
|
||||
|
||||
@ -389,6 +396,12 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
|
||||
TF_RETURN_IF_ERROR(module->set_schedule(std::move(schedule)));
|
||||
}
|
||||
|
||||
for (auto prefetch : proto.cross_program_prefetches()) {
|
||||
module->AddCrossProgramPrefetch(
|
||||
prefetch.parameter(),
|
||||
ShapeIndex(prefetch.index().begin(), prefetch.index().end()));
|
||||
}
|
||||
|
||||
return std::move(module);
|
||||
}
|
||||
|
||||
@ -669,6 +682,9 @@ std::unique_ptr<HloModule> HloModule::Clone(const HloModuleConfig& config,
|
||||
}
|
||||
TF_CHECK_OK(module->set_schedule(std::move(clone_schedule)));
|
||||
}
|
||||
for (auto [parameter, index] : CrossProgramPrefetches()) {
|
||||
module->AddCrossProgramPrefetch(parameter, index);
|
||||
}
|
||||
return module;
|
||||
}
|
||||
|
||||
|
@ -345,6 +345,17 @@ class HloModule {
|
||||
spmd_output_sharding_ = sharding;
|
||||
}
|
||||
|
||||
// Add a program argument to be prefetched across programs.
|
||||
void AddCrossProgramPrefetch(int64 parameter, const ShapeIndex& index) {
|
||||
cross_program_prefetches_.emplace_back(parameter, index);
|
||||
}
|
||||
|
||||
// Get the list of program arguments to be prefetch across programs.
|
||||
const absl::Span<const std::pair<int64, ShapeIndex>> CrossProgramPrefetches()
|
||||
const {
|
||||
return cross_program_prefetches_;
|
||||
}
|
||||
|
||||
private:
|
||||
HloComputation* AddComputationInternal(
|
||||
std::unique_ptr<HloComputation> computation, bool is_entry,
|
||||
@ -392,6 +403,9 @@ class HloModule {
|
||||
// The HLO sharding of the entry computation's output (root) for
|
||||
// SPMD-partitioned programs.
|
||||
absl::optional<HloSharding> spmd_output_sharding_;
|
||||
|
||||
// Arguments to be prefetched across programs.
|
||||
std::vector<std::pair<int64, ShapeIndex>> cross_program_prefetches_;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
@ -355,6 +355,17 @@ HeapSimulator::Result AlternateMemoryBestFitHeap::Finish() {
|
||||
continue;
|
||||
}
|
||||
|
||||
HloInstruction* inst = interval.buffer->instruction();
|
||||
HloModule* module = inst->GetModule();
|
||||
|
||||
// Don't intra-program prefetch a cross program prefetch
|
||||
if (inst->opcode() == HloOpcode::kParameter &&
|
||||
absl::c_count(module->CrossProgramPrefetches(),
|
||||
std::make_pair(inst->parameter_number(),
|
||||
interval.buffer->index())) > 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto colocated_intervals = GetSortedColocatedIntervals(interval);
|
||||
|
||||
if (AreIntervalsReservedInAlternateMemory(colocated_intervals)) {
|
||||
@ -561,6 +572,52 @@ AlternateMemoryBestFitHeap::GetLiveAllocationAt(
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Tries to place `prefetch_candidate` in alternate memory for its whole
// interval and, on success, registers the buffer on `module` as a
// cross-program prefetch.
//
// Returns without registering anything when:
//   * no candidate is supplied,
//   * the candidate's buffer has no uses (there would be no first use to
//     schedule the prefetch against), or
//   * the chunk found for the candidate is not at offset 0 or does not fit in
//     available_heap_size() (a warning is logged in this case).
//
// On success a new allocation sequence is appended for the buffer. It starts
// with a default-memory allocation at the buffer's defining position, then an
// async copy into the chosen alternate-memory chunk is added, and every use of
// the buffer is attached to the last allocation in the sequence.
void AlternateMemoryBestFitHeap::AllocateCrossProgramPrefetchBuffer(
    HloModule* module, absl::optional<BufferInterval> prefetch_candidate) {
  if (!prefetch_candidate) {
    return;
  }

  const HloValue* buffer = prefetch_candidate->buffer;
  // Bind by reference: the use list is only read here, so copying it is
  // unnecessary.
  const auto& uses = buffer->uses();
  if (uses.empty()) {
    // Without a use, the earliest-use lookup below would dereference end().
    return;
  }

  ChunkCandidate chunk_candidate = FindChunkCandidate(*prefetch_candidate);
  if (chunk_candidate.chunk.offset != 0 ||
      chunk_candidate.heap_size > available_heap_size()) {
    LOG(WARNING)
        << "Could not allocate preferred memory for cross program prefetch";
    return;
  }
  AddToPendingChunks(*prefetch_candidate, chunk_candidate);

  int64 parameter = buffer->instruction()->parameter_number();
  module->AddCrossProgramPrefetch(parameter, buffer->index());

  allocation_sequence_list_->push_back({buffer, {}});
  MemorySpaceAssignment::AllocationSequence& allocations =
      allocation_sequence_list_->back().sequence;

  // The buffer initially lives in default memory at its defining position.
  allocations.push_back(absl::make_unique<MemorySpaceAssignment::Allocation>(
      buffer->defining_position(), MemorySpace::kDefault, kDummyChunk,
      prefetch_candidate->start, prefetch_candidate->end));

  // Find the earliest scheduled use; the prefetch has to be issued no later
  // than that instruction.
  const auto& instruction_schedule = hlo_live_range_.instruction_schedule();
  auto first_use =
      absl::c_min_element(uses, [&](const HloUse& lhs, const HloUse& rhs) {
        return instruction_schedule.at(lhs.instruction) <
               instruction_schedule.at(rhs.instruction);
      });
  int64 latest_prefetch_time = instruction_schedule.at(first_use->instruction);

  AddAsyncCopy(*allocations.back(), MemorySpace::kAlternate,
               chunk_candidate.chunk, prefetch_candidate->start,
               prefetch_candidate->end, latest_prefetch_time, &allocations);
  // NOTE(review): allocations.back() is presumably the copy allocation that
  // AddAsyncCopy just appended -- confirm against AddAsyncCopy.
  for (const HloUse& use : uses) {
    allocations.back()->AddUse(use);
  }

  // Drop the pending bookkeeping accumulated above.
  // NOTE(review): presumably so this allocation is not treated as pending by
  // later processing -- confirm against the users of these two containers.
  pending_chunks_.clear();
  pending_async_copies_.clear();
}
|
||||
|
||||
void AlternateMemoryBestFitHeap::AddInputAndOutputRequiredAssignments() {
|
||||
// Go through the parameters and outputs and pin them to the corresponding
|
||||
// memory by adding a required assignment.
|
||||
@ -1207,6 +1264,90 @@ MemorySpaceAssignment::GetMemoryBoundednessBufferIntervalCompare(
|
||||
};
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
bool LooksLikeAnActivation(const HloInstruction* inst) {
|
||||
for (HloInstruction* user : inst->users()) {
|
||||
switch (user->opcode()) {
|
||||
case HloOpcode::kConvolution:
|
||||
case HloOpcode::kDot:
|
||||
if (user->operand(0) == inst) {
|
||||
return true;
|
||||
}
|
||||
break;
|
||||
case HloOpcode::kGather:
|
||||
if (user->operand(1) == inst) {
|
||||
return true;
|
||||
}
|
||||
break;
|
||||
case HloOpcode::kFusion:
|
||||
for (int i = 0; i < user->operand_count(); ++i) {
|
||||
if (user->operand(i) == inst &&
|
||||
LooksLikeAnActivation(user->fused_parameter(i))) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
break;
|
||||
default:
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
bool IsCrossProgramPrefetchCandidate(
|
||||
const HloValue& value, const MemorySpaceAssignment::Options& options) {
|
||||
return value.instruction()->parent() ==
|
||||
value.instruction()->GetModule()->entry_computation() &&
|
||||
value.instruction()->opcode() == HloOpcode::kParameter &&
|
||||
value.index().size() == 1 && value.shape().IsArray() &&
|
||||
!value.uses().empty() &&
|
||||
options.size_fn(value) <= options.max_size_in_bytes &&
|
||||
absl::c_all_of(value.uses(), [&](const HloUse& use) {
|
||||
const HloInstruction* gte =
|
||||
use.instruction->operand(use.operand_number);
|
||||
return gte->opcode() == HloOpcode::kGetTupleElement &&
|
||||
!LooksLikeAnActivation(gte);
|
||||
});
|
||||
}
|
||||
|
||||
// Scans every value known to the dataflow analysis and returns the buffer
// interval of the preferred cross-program prefetch candidate, or absl::nullopt
// when no value qualifies. Each candidate interval spans time 0 through the
// end of the schedule.
absl::optional<MemorySpaceAssignment::BufferInterval>
FindCrossProgramPrefetchCandidate(
    const HloAliasAnalysis& alias_analysis, const HloLiveRange& hlo_live_range,
    const MemorySpaceAssignment::Options& options) {
  using Interval = MemorySpaceAssignment::BufferInterval;

  std::vector<Interval> eligible;
  for (HloValue* value : alias_analysis.dataflow_analysis().values()) {
    if (!IsCrossProgramPrefetchCandidate(*value, options)) {
      continue;
    }
    Interval whole_program;
    whole_program.buffer = value;
    whole_program.size = options.size_fn(*value);
    whole_program.start = 0;
    whole_program.end = hlo_live_range.schedule_end_time();
    whole_program.need_allocation = true;
    eligible.push_back(whole_program);
  }

  // The buffer_interval_compare ought to do a good job picking the most
  // appropriate buffer to cross program prefetch, but empirically, it makes
  // worse choices than just picking the largest buffer.
  // TODO(b/152421603): Investigate.
  // NOTE(review): c_max_element below treats the comparator as a less-than;
  // confirm options.buffer_interval_compare has that sense when it is used.
  auto by_size = [](const auto& a, const auto& b) { return a.size < b.size; };
  const bool use_configured_compare =
      options.default_cross_program_prefetch_heuristic &&
      options.buffer_interval_compare;
  auto& ordering =
      use_configured_compare ? *options.buffer_interval_compare : by_size;

  auto winner = absl::c_max_element(eligible, ordering);
  if (winner != eligible.end()) {
    return *winner;
  }
  return absl::nullopt;
}
|
||||
} // namespace
|
||||
|
||||
/*static*/ StatusOr<std::unique_ptr<PresetAssignments>>
|
||||
MemorySpaceAssignment::Run(HloModule* module,
|
||||
const HloLiveRange& hlo_live_range,
|
||||
@ -1222,6 +1363,13 @@ MemorySpaceAssignment::Run(HloModule* module,
|
||||
&memory_space_assignment.allocation_sequence_list_, options,
|
||||
alias_analysis, hlo_live_range);
|
||||
|
||||
if (options.enable_cross_program_prefetch) {
|
||||
absl::optional<BufferInterval> prefetch_candiate =
|
||||
FindCrossProgramPrefetchCandidate(alias_analysis, hlo_live_range,
|
||||
options);
|
||||
algorithm->AllocateCrossProgramPrefetchBuffer(module, prefetch_candiate);
|
||||
}
|
||||
|
||||
HeapSimulator::Options heap_simulator_options;
|
||||
heap_simulator_options.may_reuse_operand_buffers = false;
|
||||
TF_RETURN_IF_ERROR(HeapSimulator::Run(std::move(algorithm), *module,
|
||||
|
@ -319,6 +319,14 @@ class MemorySpaceAssignment {
|
||||
// If true, verifies the memory space assignment against overlapping
|
||||
// buffers.
|
||||
bool verify = false;
|
||||
|
||||
// Enable prefetching buffers into preferred memory across program
|
||||
// boundaries
|
||||
bool enable_cross_program_prefetch = true;
|
||||
|
||||
// If true, use buffer_interval_compare to determine which buffers to
|
||||
// prefetch across program boundaries.
|
||||
bool default_cross_program_prefetch_heuristic = false;
|
||||
};
|
||||
|
||||
// This class represents an allocation that might either be in the default or
|
||||
@ -623,6 +631,12 @@ class AlternateMemoryBestFitHeap : public GlobalDecreasingSizeBestFitHeap {
|
||||
}
|
||||
}
|
||||
|
||||
// Allocates a buffer in preferred memory with whole program lifetime and
|
||||
// enables prefetching prefetch_candidate from default memory across program
|
||||
// boundaries.
|
||||
void AllocateCrossProgramPrefetchBuffer(
|
||||
HloModule* module, absl::optional<BufferInterval> prefetch_candidate);
|
||||
|
||||
HeapSimulator::Result Finish() override;
|
||||
|
||||
private:
|
||||
|
@ -3028,5 +3028,203 @@ TEST_F(AsynchronousCopyOrderingTest, Simple) {
|
||||
ordering.AddCopy({5, 14, alternate_mem_space});
|
||||
}
|
||||
|
||||
// An entry parameter tuple (lhs, rhs) feeds a dot through get-tuple-elements.
// Expects exactly one cross-program prefetch: parameter 0 at shape index {1},
// i.e. the tuple element used as the dot's rhs operand.
TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchTest) {
  constexpr int kBatch = 8;
  constexpr int kFeature = 8;
  constexpr int kOutput = 2;

  const auto lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature});
  const auto rhs_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput});
  const auto result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput});
  const auto tuple_shape = ShapeUtil::MakeTupleShape({lhs_shape, rhs_shape});

  HloComputation::Builder builder(TestName());
  HloInstruction* p0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p0"));
  HloInstruction* lhs_gte = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(lhs_shape, p0, 0));
  HloInstruction* rhs_gte = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(rhs_shape, p0, 1));

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);
  HloInstruction* dot_inst = builder.AddInstruction(HloInstruction::CreateDot(
      result_shape, lhs_gte, rhs_gte, dnums, DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  HloSchedule sched(module.get());
  sched.set_sequence(entry, {p0, lhs_gte, rhs_gte, dot_inst});
  TF_CHECK_OK(module->set_schedule(sched));

  AssignMemorySpace(module.get());

  auto cross_program_prefetches = module->CrossProgramPrefetches();
  EXPECT_EQ(cross_program_prefetches.size(), 1);
  // Only inspect the entry when one exists, so a size mismatch above does not
  // turn into an out-of-bounds access.
  if (!cross_program_prefetches.empty()) {
    EXPECT_EQ(cross_program_prefetches[0].first, 0);
    EXPECT_EQ(cross_program_prefetches[0].second, ShapeIndex({1}));
  }
}
|
||||
|
||||
// Same dot as the basic test, but the operands sit one tuple level deeper
// (parameter -> tuple -> element). Expects no cross-program prefetch to be
// registered for such nested buffers.
TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchNestedTupleTest) {
  constexpr int kBatch = 8;
  constexpr int kFeature = 8;
  constexpr int kOutput = 2;

  const auto lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature});
  const auto rhs_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput});
  const auto result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput});
  const auto tuple_shape = ShapeUtil::MakeTupleShape({lhs_shape, rhs_shape});
  const auto tuple_tuple_shape = ShapeUtil::MakeTupleShape({tuple_shape});

  HloComputation::Builder builder(TestName());
  HloInstruction* p0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_tuple_shape, "p0"));
  HloInstruction* inner_tuple = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(tuple_shape, p0, 0));
  HloInstruction* lhs_gte = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(lhs_shape, inner_tuple, 0));
  HloInstruction* rhs_gte = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(rhs_shape, inner_tuple, 1));

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);
  HloInstruction* dot_inst = builder.AddInstruction(HloInstruction::CreateDot(
      result_shape, lhs_gte, rhs_gte, dnums, DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  HloSchedule sched(module.get());
  sched.set_sequence(entry, {p0, inner_tuple, lhs_gte, rhs_gte, dot_inst});
  TF_CHECK_OK(module->set_schedule(sched));

  AssignMemorySpace(module.get());

  auto cross_program_prefetches = module->CrossProgramPrefetches();
  EXPECT_EQ(cross_program_prefetches.size(), 0);
}
|
||||
|
||||
// The entry computation consists of a single array parameter that nothing
// consumes. Expects no cross-program prefetch to be registered.
TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchUnusedParamTest) {
  constexpr int kFeature = 8;
  constexpr int kOutput = 2;
  const auto rhs_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput});

  HloComputation::Builder builder(TestName());
  HloInstruction* p0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, rhs_shape, "p0"));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  HloSchedule sched(module.get());
  sched.set_sequence(entry, {p0});
  TF_CHECK_OK(module->set_schedule(sched));

  AssignMemorySpace(module.get());

  auto cross_program_prefetches = module->CrossProgramPrefetches();
  EXPECT_EQ(cross_program_prefetches.size(), 0);
}
|
||||
|
||||
// Same graph as the basic test but with kOutput raised to 8, which makes the
// rhs operand larger. Expects no cross-program prefetch to be registered.
// NOTE(review): presumably the rhs now exceeds the fixture's alternate-memory
// size limit -- confirm against AssignMemorySpace's options.
TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchTooBigTest) {
  constexpr int kBatch = 8;
  constexpr int kFeature = 8;
  constexpr int kOutput = 8;

  const auto lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature});
  const auto rhs_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput});
  const auto result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput});
  const auto tuple_shape = ShapeUtil::MakeTupleShape({lhs_shape, rhs_shape});

  HloComputation::Builder builder(TestName());
  HloInstruction* p0 = builder.AddInstruction(
      HloInstruction::CreateParameter(0, tuple_shape, "p0"));
  HloInstruction* lhs_gte = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(lhs_shape, p0, 0));
  HloInstruction* rhs_gte = builder.AddInstruction(
      HloInstruction::CreateGetTupleElement(rhs_shape, p0, 1));

  DotDimensionNumbers dnums;
  dnums.add_lhs_contracting_dimensions(1);
  dnums.add_rhs_contracting_dimensions(0);
  HloInstruction* dot_inst = builder.AddInstruction(HloInstruction::CreateDot(
      result_shape, lhs_gte, rhs_gte, dnums, DefaultPrecisionConfig(2)));

  auto module = CreateNewVerifiedModule();
  HloComputation* entry = module->AddEntryComputation(builder.Build());

  HloSchedule sched(module.get());
  sched.set_sequence(entry, {p0, lhs_gte, rhs_gte, dot_inst});
  TF_CHECK_OK(module->set_schedule(sched));

  AssignMemorySpace(module.get());

  auto cross_program_prefetches = module->CrossProgramPrefetches();
  EXPECT_EQ(cross_program_prefetches.size(), 0);
}
|
||||
|
||||
// The dot lives inside a fusion computation whose tuple parameter is fed from
// constants in the entry computation; the entry computation itself has no
// parameters. Expects no cross-program prefetch to be registered.
TEST_P(MemorySpaceAssignmentTest, CrossProgramPrefetchFusionTest) {
  constexpr int kBatch = 2;
  constexpr int kFeature = 2;
  constexpr int kOutput = 2;

  const auto lhs_shape = ShapeUtil::MakeShape(F32, {kBatch, kFeature});
  const auto rhs_shape = ShapeUtil::MakeShape(F32, {kFeature, kOutput});
  const auto result_shape = ShapeUtil::MakeShape(F32, {kBatch, kOutput});
  const auto tuple_shape = ShapeUtil::MakeTupleShape({lhs_shape, rhs_shape});

  HloComputation::Builder builder(TestName());
  auto module = CreateNewVerifiedModule();

  // Fused computation: dot(gte(p0, 0), gte(p0, 1)).
  HloComputation::Builder fusion_builder("fusion");
  {
    HloInstruction* fused_p0 = fusion_builder.AddInstruction(
        HloInstruction::CreateParameter(0, tuple_shape, "p0"));
    HloInstruction* fused_lhs = fusion_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(lhs_shape, fused_p0, 0));
    HloInstruction* fused_rhs = fusion_builder.AddInstruction(
        HloInstruction::CreateGetTupleElement(rhs_shape, fused_p0, 1));
    DotDimensionNumbers dnums;
    dnums.add_lhs_contracting_dimensions(1);
    dnums.add_rhs_contracting_dimensions(0);
    // The dot is added last; its handle is not needed afterwards.
    fusion_builder.AddInstruction(HloInstruction::CreateDot(
        result_shape, fused_lhs, fused_rhs, dnums, DefaultPrecisionConfig(2)));
  }
  HloComputation* fusion_computation =
      module->AddEmbeddedComputation(fusion_builder.Build());

  // Entry computation: fusion(tuple(constant, constant)).
  HloInstruction* activations =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2<float>({{0.0, 1.0}, {2.0, 3.0}})));
  HloInstruction* weights =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2<float>({{0.0, 1.0}, {2.0, 3.0}})));
  HloInstruction* operands = builder.AddInstruction(
      HloInstruction::CreateTuple({activations, weights}));
  HloInstruction* fusion_inst =
      builder.AddInstruction(HloInstruction::CreateFusion(
          result_shape, HloInstruction::FusionKind::kCustom, {operands},
          fusion_computation));

  HloComputation* entry = module->AddEntryComputation(builder.Build());

  HloSchedule sched(module.get());
  sched.set_sequence(entry, {activations, weights, operands, fusion_inst});
  TF_CHECK_OK(module->set_schedule(sched));

  AssignMemorySpace(module.get());

  auto cross_program_prefetches = module->CrossProgramPrefetches();
  EXPECT_EQ(cross_program_prefetches.size(), 0);
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
Loading…
Reference in New Issue
Block a user