[XLA] Add an optional bool is_cross_program_prefetch field to kCopyStart HLOs.
This change to HLO is needed to disambiguate cross-program-prefetches and other prefetches performed over the same HloValue. This CL is in preparation for supporting freeing cross-program-prefetched buffers after their last use. PiperOrigin-RevId: 327881463 Change-Id: Id2ea6cd543589a7d49c689d44a2631a96ee9ddeb
This commit is contained in:
parent
2f29946403
commit
ac47af2254
@ -35,7 +35,7 @@ import "tensorflow/compiler/xla/xla_data.proto";
|
||||
option cc_enable_arenas = true;
|
||||
|
||||
// Serialization of HloInstruction.
|
||||
// Next ID: 73
|
||||
// Next ID: 74
|
||||
message HloInstructionProto {
|
||||
reserved 10;
|
||||
reserved "parameter_name";
|
||||
@ -251,6 +251,9 @@ message HloInstructionProto {
|
||||
|
||||
// The comparison type used for kCompare.
|
||||
string comparison_type = 72;
|
||||
|
||||
// Specifies if this is a cross-program-prefetch, used by kCopyStart.
|
||||
bool is_cross_program_prefetch = 73;
|
||||
}
|
||||
|
||||
// Serialization of HloComputation.
|
||||
|
@ -1229,10 +1229,10 @@ TEST_P(HloDataflowAnalysisTest, CopyStartAndCopyDone) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto constant = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
|
||||
auto copy_start = builder.AddInstruction(HloInstruction::CreateUnary(
|
||||
auto copy_start = builder.AddInstruction(HloInstruction::CreateCopyStart(
|
||||
ShapeUtil::MakeTupleShape({constant->shape(), constant->shape(),
|
||||
ShapeUtil::MakeShape(U32, {})}),
|
||||
HloOpcode::kCopyStart, constant));
|
||||
constant));
|
||||
auto copy_done = builder.AddInstruction(HloInstruction::CreateUnary(
|
||||
constant->shape(), HloOpcode::kCopyDone, copy_start));
|
||||
module_->AddEntryComputation(builder.Build());
|
||||
|
@ -167,6 +167,11 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
absl::Span<const int64>(fft_length));
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kCopyStart: {
|
||||
instruction = CreateCopyStart(shape, operands(0),
|
||||
proto.is_cross_program_prefetch());
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kCompare: {
|
||||
// Auto-upgraded from deprecated opcode skips the following.
|
||||
if (!comparison_direction) {
|
||||
@ -839,7 +844,6 @@ HloInstruction::CreateRngBitGenerator(const Shape& shape, HloInstruction* state,
|
||||
case HloOpcode::kCeil:
|
||||
case HloOpcode::kCollectivePermuteDone:
|
||||
case HloOpcode::kCopy:
|
||||
case HloOpcode::kCopyStart:
|
||||
case HloOpcode::kCopyDone:
|
||||
case HloOpcode::kCos:
|
||||
case HloOpcode::kClz:
|
||||
@ -946,6 +950,13 @@ HloInstruction::CreateRngBitGenerator(const Shape& shape, HloInstruction* state,
|
||||
fft_length);
|
||||
}
|
||||
|
||||
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCopyStart(
|
||||
const Shape& shape, HloInstruction* operand,
|
||||
bool is_cross_program_prefetch) {
|
||||
return absl::make_unique<HloCopyStartInstruction>(shape, operand,
|
||||
is_cross_program_prefetch);
|
||||
}
|
||||
|
||||
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCompare(
|
||||
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
|
||||
ComparisonDirection direction, absl::optional<Comparison::Type> type) {
|
||||
@ -4118,6 +4129,10 @@ const DomainMetadata& HloInstruction::user_side_metadata() const {
|
||||
return Cast<HloDomainInstruction>(this)->user_side_metadata();
|
||||
}
|
||||
|
||||
bool HloInstruction::is_cross_program_prefetch() const {
|
||||
return Cast<HloCopyStartInstruction>(this)->is_cross_program_prefetch();
|
||||
}
|
||||
|
||||
// Forwards to HloCompareInstruction::direction().
// NOTE(review): relies on Cast<> to deal with `this` not being a compare
// instruction; confirm Cast's failure behavior.
ComparisonDirection HloInstruction::comparison_direction() const {
  const auto* compare = Cast<HloCompareInstruction>(this);
  return compare->direction();
}
|
||||
|
@ -592,6 +592,12 @@ class HloInstruction {
|
||||
const Shape& shape, HloInstruction* operand, FftType fft_type,
|
||||
absl::Span<const int64> fft_length);
|
||||
|
||||
// Creates a copy-start op, indicating whether this is a cross-program
|
||||
// prefetch or not.
|
||||
static std::unique_ptr<HloInstruction> CreateCopyStart(
|
||||
const Shape& shape, HloInstruction* operand,
|
||||
bool is_cross_program_prefetch = false);
|
||||
|
||||
// Creates a compare op, performing the comparison specified in direction.
|
||||
static std::unique_ptr<HloInstruction> CreateCompare(
|
||||
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
|
||||
@ -1865,6 +1871,9 @@ class HloInstruction {
|
||||
// Delegates to HloDomainInstruction::user_side_metadata().
|
||||
const DomainMetadata& user_side_metadata() const;
|
||||
|
||||
// Delegates to HloCopyStartInstruction::is_cross_program_prefetch().
|
||||
bool is_cross_program_prefetch() const;
|
||||
|
||||
// Delegates to HloCompareInstruction::direction().
|
||||
ComparisonDirection comparison_direction() const;
|
||||
|
||||
|
@ -204,6 +204,47 @@ std::unique_ptr<HloInstruction> HloFftInstruction::CloneWithNewOperandsImpl(
|
||||
fft_length_);
|
||||
}
|
||||
|
||||
// Constructs a kCopyStart instruction with the given result `shape`, records
// whether it is a cross-program prefetch, and registers `operand` as its only
// operand.
HloCopyStartInstruction::HloCopyStartInstruction(
    const Shape& shape, HloInstruction* operand,
    bool is_cross_program_prefetch)
    : HloInstruction(HloOpcode::kCopyStart, shape),
      is_cross_program_prefetch_(is_cross_program_prefetch) {
  AppendOperand(operand);
}
|
||||
|
||||
// Serializes this instruction: starts from the base-class proto and then
// stores the cross-program-prefetch flag in its dedicated field.
HloInstructionProto HloCopyStartInstruction::ToProto() const {
  HloInstructionProto serialized = HloInstruction::ToProto();
  serialized.set_is_cross_program_prefetch(is_cross_program_prefetch());
  return serialized;
}
|
||||
|
||||
std::vector<string> HloCopyStartInstruction::ExtraAttributesToStringImpl(
|
||||
const HloPrintOptions& options) const {
|
||||
std::vector<string> result;
|
||||
if (is_cross_program_prefetch()) {
|
||||
result.push_back("is_cross_program_prefetch=true");
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
bool HloCopyStartInstruction::IdenticalSlowPath(
|
||||
const HloInstruction& other,
|
||||
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
||||
eq_computations) const {
|
||||
const auto& casted_other = static_cast<const HloCopyStartInstruction&>(other);
|
||||
return is_cross_program_prefetch() ==
|
||||
casted_other.is_cross_program_prefetch();
|
||||
}
|
||||
|
||||
std::unique_ptr<HloInstruction>
|
||||
HloCopyStartInstruction::CloneWithNewOperandsImpl(
|
||||
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
||||
HloCloneContext* context) const {
|
||||
CHECK_EQ(new_operands.size(), 1);
|
||||
return absl::make_unique<HloCopyStartInstruction>(
|
||||
shape, new_operands[0], is_cross_program_prefetch());
|
||||
}
|
||||
|
||||
HloCompareInstruction::HloCompareInstruction(
|
||||
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
|
||||
ComparisonDirection direction, absl::optional<Comparison::Type> type)
|
||||
|
@ -132,6 +132,28 @@ class HloFftInstruction : public HloInstruction {
|
||||
std::vector<int64> fft_length_;
|
||||
};
|
||||
|
||||
class HloCopyStartInstruction : public HloInstruction {
|
||||
public:
|
||||
explicit HloCopyStartInstruction(const Shape& shape, HloInstruction* operand,
|
||||
bool is_cross_program_prefetch);
|
||||
|
||||
bool is_cross_program_prefetch() const { return is_cross_program_prefetch_; }
|
||||
HloInstructionProto ToProto() const override;
|
||||
|
||||
private:
|
||||
std::vector<string> ExtraAttributesToStringImpl(
|
||||
const HloPrintOptions& options) const override;
|
||||
bool IdenticalSlowPath(
|
||||
const HloInstruction& other,
|
||||
const std::function<bool(const HloComputation*, const HloComputation*)>&
|
||||
eq_computations) const override;
|
||||
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
|
||||
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
|
||||
HloCloneContext* context) const override;
|
||||
|
||||
bool is_cross_program_prefetch_;
|
||||
};
|
||||
|
||||
class HloCompareInstruction : public HloInstruction {
|
||||
public:
|
||||
explicit HloCompareInstruction(const Shape& shape, HloInstruction* lhs,
|
||||
|
@ -276,10 +276,10 @@ TEST_F(HloMatchersTest, AsyncCopyMatcher) {
|
||||
/*element_size_in_bits=*/0, /*memory_space=*/2);
|
||||
|
||||
auto p0 = HloInstruction::CreateParameter(0, shape_memspace1, "p0");
|
||||
auto copy_start = HloInstruction::CreateUnary(
|
||||
auto copy_start = HloInstruction::CreateCopyStart(
|
||||
ShapeUtil::MakeTupleShape(
|
||||
{shape_memspace2, shape_memspace1, ShapeUtil::MakeShape(U32, {})}),
|
||||
HloOpcode::kCopyStart, p0.get());
|
||||
p0.get());
|
||||
auto copy_done = HloInstruction::CreateUnary(
|
||||
shape_memspace2, HloOpcode::kCopyDone, copy_start.get());
|
||||
|
||||
|
@ -883,7 +883,6 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
|
||||
case HloOpcode::kClz:
|
||||
case HloOpcode::kCollectivePermuteDone:
|
||||
case HloOpcode::kCopy:
|
||||
case HloOpcode::kCopyStart:
|
||||
case HloOpcode::kCopyDone:
|
||||
case HloOpcode::kCos:
|
||||
case HloOpcode::kExp:
|
||||
@ -1091,6 +1090,20 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
|
||||
}
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kCopyStart: {
|
||||
// If the is_cross_program_prefetch attribute is not present then default
|
||||
// to false.
|
||||
optional<bool> is_cross_program_prefetch = false;
|
||||
attrs["is_cross_program_prefetch"] = {/*required=*/false, AttrTy::kBool,
|
||||
&is_cross_program_prefetch};
|
||||
if (!ParseOperands(&operands, /*expected_size=*/1) ||
|
||||
!ParseAttributes(attrs)) {
|
||||
return false;
|
||||
}
|
||||
instruction = builder->AddInstruction(HloInstruction::CreateCopyStart(
|
||||
shape, operands[0], *is_cross_program_prefetch));
|
||||
break;
|
||||
}
|
||||
case HloOpcode::kReplicaId: {
|
||||
if (!ParseOperands(&operands, /*expected_size=*/0) ||
|
||||
!ParseAttributes(attrs)) {
|
||||
|
@ -318,7 +318,7 @@ R"(HloModule CopyStartAndCopyDone_module
|
||||
|
||||
ENTRY %CopyStartAndCopyDone (v1: f32[], v2: f32[2,3]) -> (f32[], f32[2,3]) {
|
||||
%v1 = f32[] parameter(0)
|
||||
%copy-start.1 = (f32[], f32[], u32[]) copy-start(f32[] %v1)
|
||||
%copy-start.1 = (f32[], f32[], u32[]) copy-start(f32[] %v1), is_cross_program_prefetch=true
|
||||
%copy-done.1 = f32[] copy-done((f32[], f32[], u32[]) %copy-start.1)
|
||||
%v2 = f32[2,3]{1,0:S(1)} parameter(1)
|
||||
%copy-start.2 = (f32[2,3]{1,0:S(2)}, f32[2,3]{1,0:S(1)}, u32[]) copy-start(f32[2,3]{1,0:S(1)} %v2)
|
||||
|
@ -1409,7 +1409,8 @@ void AlternateMemoryBestFitHeap::AllocateCrossProgramPrefetchBuffer(
|
||||
|
||||
AddAsyncCopy(*allocations.back(), MemorySpace::kAlternate,
|
||||
chunk_candidate.chunk, prefetch_candidate->start,
|
||||
prefetch_candidate->end, latest_prefetch_time, &allocations);
|
||||
prefetch_candidate->end, latest_prefetch_time, &allocations,
|
||||
/*is_cross_program_prefetch=*/true);
|
||||
absl::c_for_each(uses, [&](auto& use) { allocations.back()->AddUse(use); });
|
||||
for (auto& allocation : allocations) {
|
||||
allocations_->push_back(std::move(allocation));
|
||||
@ -1887,7 +1888,8 @@ void AlternateMemoryBestFitHeap::AddAsyncCopy(
|
||||
const MemorySpaceAssignment::Allocation& prev_allocation,
|
||||
MemorySpace memory_space, absl::optional<Chunk> chunk, int64 start_time,
|
||||
int64 end_time, int64 copy_done_schedule_before_time,
|
||||
MemorySpaceAssignment::AllocationSequence* allocations) {
|
||||
MemorySpaceAssignment::AllocationSequence* allocations,
|
||||
bool is_cross_program_prefetch) {
|
||||
VLOG(3) << "Copy to "
|
||||
<< (memory_space == MemorySpaceAssignment::MemorySpace::kDefault
|
||||
? "default"
|
||||
@ -1899,7 +1901,7 @@ void AlternateMemoryBestFitHeap::AddAsyncCopy(
|
||||
allocations->push_back(
|
||||
absl::make_unique<MemorySpaceAssignment::CopyAllocation>(
|
||||
prev_allocation, memory_space, chunk, start_time, end_time,
|
||||
copy_done_schedule_before_time));
|
||||
copy_done_schedule_before_time, is_cross_program_prefetch));
|
||||
|
||||
// Register the additional async copy with the interval tree to keep track of
|
||||
// the limit at any given time.
|
||||
@ -2713,9 +2715,9 @@ Status MemorySpaceAssignment::CopyAllocation::Process(
|
||||
Shape shape = defining_position().shape();
|
||||
HloInstruction* producing_instruction = AddGetTupleElements();
|
||||
HloComputation* computation = producing_instruction->parent();
|
||||
copy_start_ = computation->AddInstruction(HloInstruction::CreateUnary(
|
||||
copy_start_ = computation->AddInstruction(HloInstruction::CreateCopyStart(
|
||||
ShapeUtil::MakeTupleShape({shape, shape, ShapeUtil::MakeShape(U32, {})}),
|
||||
HloOpcode::kCopyStart, producing_instruction));
|
||||
producing_instruction, is_cross_program_prefetch_));
|
||||
copy_done_ = computation->AddInstruction(
|
||||
HloInstruction::CreateUnary(shape, HloOpcode::kCopyDone, copy_start_));
|
||||
VLOG(4) << "Created " << copy_start_->name()
|
||||
|
@ -581,12 +581,14 @@ class MemorySpaceAssignment {
|
||||
public:
|
||||
CopyAllocation(const Allocation& prev_allocation, MemorySpace memory_space,
|
||||
absl::optional<Chunk> chunk, int64 start_time,
|
||||
int64 end_time, int64 copy_done_schedule_before_time)
|
||||
int64 end_time, int64 copy_done_schedule_before_time,
|
||||
bool is_cross_program_prefetch = false)
|
||||
: Allocation(/*defining_position=*/{nullptr, {}}, memory_space, chunk,
|
||||
start_time, end_time),
|
||||
prev_allocation_(prev_allocation),
|
||||
copy_start_schedule_after_(start_time),
|
||||
copy_done_schedule_before_(copy_done_schedule_before_time) {}
|
||||
copy_done_schedule_before_(copy_done_schedule_before_time),
|
||||
is_cross_program_prefetch_(is_cross_program_prefetch) {}
|
||||
|
||||
bool is_copy_allocation() const override { return true; }
|
||||
|
||||
@ -626,6 +628,10 @@ class MemorySpaceAssignment {
|
||||
copy_start_schedule_after_ = copy_start_schedule_after;
|
||||
}
|
||||
|
||||
// Returns the is_cross_program_prefetch_ flag stored on this allocation.
bool is_cross_program_prefetch() const {
|
||||
return is_cross_program_prefetch_;
|
||||
}
|
||||
|
||||
bool operator==(const CopyAllocation& other) const;
|
||||
std::string ToString() const override;
|
||||
|
||||
@ -637,6 +643,7 @@ class MemorySpaceAssignment {
|
||||
// is before copy_done_schedule_before_.
|
||||
int64 copy_start_schedule_after_;
|
||||
int64 copy_done_schedule_before_;
|
||||
bool is_cross_program_prefetch_;
|
||||
HloInstruction* copy_start_;
|
||||
HloInstruction* copy_done_;
|
||||
};
|
||||
@ -1208,7 +1215,8 @@ class AlternateMemoryBestFitHeap
|
||||
MemorySpace memory_space, absl::optional<Chunk> chunk,
|
||||
int64 start_time, int64 end_time,
|
||||
int64 copy_done_schedule_before_time,
|
||||
MemorySpaceAssignment::AllocationSequence* allocations);
|
||||
MemorySpaceAssignment::AllocationSequence* allocations,
|
||||
bool is_cross_program_prefetch = false);
|
||||
|
||||
// This method is used for committing the chunk candidate but adding it to
|
||||
// pending_chunks_ so that we can "uncommit" them in case we need to roll back
|
||||
|
@ -333,10 +333,10 @@ TEST_F(TuplePointsToAnalysisTest, CopyStartAndCopyDone) {
|
||||
auto builder = HloComputation::Builder(TestName());
|
||||
auto constant = builder.AddInstruction(
|
||||
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
|
||||
auto copy_start = builder.AddInstruction(HloInstruction::CreateUnary(
|
||||
auto copy_start = builder.AddInstruction(HloInstruction::CreateCopyStart(
|
||||
ShapeUtil::MakeTupleShape({constant->shape(), constant->shape(),
|
||||
ShapeUtil::MakeShape(U32, {})}),
|
||||
HloOpcode::kCopyStart, constant));
|
||||
constant));
|
||||
auto copy_done = builder.AddInstruction(HloInstruction::CreateUnary(
|
||||
constant->shape(), HloOpcode::kCopyDone, copy_start));
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user