[XLA] Add an optional bool is_cross_program_prefetch field to kCopyStart HLOs.

This change to HLO is needed to disambiguate cross-program-prefetches and other
prefetches performed over the same HloValue. This CL is in preparation for
supporting freeing cross-program-prefetched buffers after their last use.

PiperOrigin-RevId: 327881463
Change-Id: Id2ea6cd543589a7d49c689d44a2631a96ee9ddeb
This commit is contained in:
Berkin Ilbeyi 2020-08-21 15:05:44 -07:00 committed by TensorFlower Gardener
parent 2f29946403
commit ac47af2254
12 changed files with 131 additions and 18 deletions

View File

@ -35,7 +35,7 @@ import "tensorflow/compiler/xla/xla_data.proto";
option cc_enable_arenas = true;
// Serialization of HloInstruction.
// Next ID: 73
// Next ID: 74
message HloInstructionProto {
reserved 10;
reserved "parameter_name";
@ -251,6 +251,9 @@ message HloInstructionProto {
// The comparison type used for kCompare.
string comparison_type = 72;
// Specifies if this is a cross-program-prefetch, used by kCopyStart.
bool is_cross_program_prefetch = 73;
}
// Serialization of HloComputation.

View File

@ -1229,10 +1229,10 @@ TEST_P(HloDataflowAnalysisTest, CopyStartAndCopyDone) {
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto copy_start = builder.AddInstruction(HloInstruction::CreateUnary(
auto copy_start = builder.AddInstruction(HloInstruction::CreateCopyStart(
ShapeUtil::MakeTupleShape({constant->shape(), constant->shape(),
ShapeUtil::MakeShape(U32, {})}),
HloOpcode::kCopyStart, constant));
constant));
auto copy_done = builder.AddInstruction(HloInstruction::CreateUnary(
constant->shape(), HloOpcode::kCopyDone, copy_start));
module_->AddEntryComputation(builder.Build());

View File

@ -167,6 +167,11 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
absl::Span<const int64>(fft_length));
break;
}
case HloOpcode::kCopyStart: {
instruction = CreateCopyStart(shape, operands(0),
proto.is_cross_program_prefetch());
break;
}
case HloOpcode::kCompare: {
// Auto-upgraded from deprecated opcode skips the following.
if (!comparison_direction) {
@ -839,7 +844,6 @@ HloInstruction::CreateRngBitGenerator(const Shape& shape, HloInstruction* state,
case HloOpcode::kCeil:
case HloOpcode::kCollectivePermuteDone:
case HloOpcode::kCopy:
case HloOpcode::kCopyStart:
case HloOpcode::kCopyDone:
case HloOpcode::kCos:
case HloOpcode::kClz:
@ -946,6 +950,13 @@ HloInstruction::CreateRngBitGenerator(const Shape& shape, HloInstruction* state,
fft_length);
}
// Factory for copy-start instructions. Forwards `shape`, `operand` and
// `is_cross_program_prefetch` unchanged to the HloCopyStartInstruction
// constructor and returns ownership of the new instruction to the caller.
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCopyStart(
    const Shape& shape, HloInstruction* operand,
    bool is_cross_program_prefetch) {
  std::unique_ptr<HloInstruction> copy_start =
      absl::make_unique<HloCopyStartInstruction>(shape, operand,
                                                 is_cross_program_prefetch);
  return copy_start;
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCompare(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
ComparisonDirection direction, absl::optional<Comparison::Type> type) {
@ -4118,6 +4129,10 @@ const DomainMetadata& HloInstruction::user_side_metadata() const {
return Cast<HloDomainInstruction>(this)->user_side_metadata();
}
// Base-class convenience accessor: downcasts `this` to
// HloCopyStartInstruction through Cast<> and returns that instruction's
// cross-program-prefetch flag.
// NOTE(review): presumably only meaningful on copy-start instructions --
// confirm how Cast<> behaves for other opcodes.
bool HloInstruction::is_cross_program_prefetch() const {
  const auto* copy_start = Cast<HloCopyStartInstruction>(this);
  return copy_start->is_cross_program_prefetch();
}
ComparisonDirection HloInstruction::comparison_direction() const {
return Cast<HloCompareInstruction>(this)->direction();
}

View File

@ -592,6 +592,12 @@ class HloInstruction {
const Shape& shape, HloInstruction* operand, FftType fft_type,
absl::Span<const int64> fft_length);
// Creates a copy-start op, indicating whether this is a cross-program
// prefetch or not. `operand` is the single operand of the new instruction.
// `is_cross_program_prefetch` defaults to false, so callers creating an
// ordinary copy-start may omit it.
static std::unique_ptr<HloInstruction> CreateCopyStart(
const Shape& shape, HloInstruction* operand,
bool is_cross_program_prefetch = false);
// Creates a compare op, performing the comparison specified in direction.
static std::unique_ptr<HloInstruction> CreateCompare(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
@ -1865,6 +1871,9 @@ class HloInstruction {
// Delegates to HloDomainInstruction::user_side_metadata().
const DomainMetadata& user_side_metadata() const;
// Delegates to HloCopyStartInstruction::is_cross_program_prefetch().
// NOTE(review): presumably only valid to call on copy-start instructions,
// since the implementation downcasts to HloCopyStartInstruction -- confirm.
bool is_cross_program_prefetch() const;
// Delegates to HloCompareInstruction::direction().
ComparisonDirection comparison_direction() const;

View File

@ -204,6 +204,47 @@ std::unique_ptr<HloInstruction> HloFftInstruction::CloneWithNewOperandsImpl(
fft_length_);
}
HloCopyStartInstruction::HloCopyStartInstruction(const Shape& shape,
HloInstruction* operand,
bool is_cross_program_prefetch)
: HloInstruction(HloOpcode::kCopyStart, shape),
is_cross_program_prefetch_(is_cross_program_prefetch) {
AppendOperand(operand);
}
// Serializes this instruction: starts from the base-class proto and then
// records the cross-program-prefetch flag in `is_cross_program_prefetch`.
HloInstructionProto HloCopyStartInstruction::ToProto() const {
  HloInstructionProto serialized = HloInstruction::ToProto();
  serialized.set_is_cross_program_prefetch(is_cross_program_prefetch_);
  return serialized;
}
std::vector<string> HloCopyStartInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& options) const {
std::vector<string> result;
if (is_cross_program_prefetch()) {
result.push_back("is_cross_program_prefetch=true");
}
return result;
}
bool HloCopyStartInstruction::IdenticalSlowPath(
const HloInstruction& other,
const std::function<bool(const HloComputation*, const HloComputation*)>&
eq_computations) const {
const auto& casted_other = static_cast<const HloCopyStartInstruction&>(other);
return is_cross_program_prefetch() ==
casted_other.is_cross_program_prefetch();
}
std::unique_ptr<HloInstruction>
HloCopyStartInstruction::CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
return absl::make_unique<HloCopyStartInstruction>(
shape, new_operands[0], is_cross_program_prefetch());
}
HloCompareInstruction::HloCompareInstruction(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
ComparisonDirection direction, absl::optional<Comparison::Type> type)

View File

@ -132,6 +132,28 @@ class HloFftInstruction : public HloInstruction {
std::vector<int64> fft_length_;
};
// HloInstruction subclass for copy-start ops. In addition to the base
// instruction state it stores a boolean recording whether the copy is a
// cross-program prefetch.
class HloCopyStartInstruction : public HloInstruction {
 public:
  // Creates a copy-start of `operand` with result shape `shape`.
  explicit HloCopyStartInstruction(const Shape& shape, HloInstruction* operand,
                                   bool is_cross_program_prefetch);

  // Returns a serialized representation of this instruction.
  HloInstructionProto ToProto() const override;

  // Whether this copy-start was created as a cross-program prefetch.
  bool is_cross_program_prefetch() const { return is_cross_program_prefetch_; }

 private:
  // Implementation for non-common logic of CloneWithNewOperands.
  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
      HloCloneContext* context) const override;

  // Compares the subclass-specific state of this instruction and `other`.
  bool IdenticalSlowPath(
      const HloInstruction& other,
      const std::function<bool(const HloComputation*, const HloComputation*)>&
          eq_computations) const override;

  // Supplies the subclass-specific attributes for printing.
  std::vector<string> ExtraAttributesToStringImpl(
      const HloPrintOptions& options) const override;

  bool is_cross_program_prefetch_;
};
class HloCompareInstruction : public HloInstruction {
public:
explicit HloCompareInstruction(const Shape& shape, HloInstruction* lhs,

View File

@ -276,10 +276,10 @@ TEST_F(HloMatchersTest, AsyncCopyMatcher) {
/*element_size_in_bits=*/0, /*memory_space=*/2);
auto p0 = HloInstruction::CreateParameter(0, shape_memspace1, "p0");
auto copy_start = HloInstruction::CreateUnary(
auto copy_start = HloInstruction::CreateCopyStart(
ShapeUtil::MakeTupleShape(
{shape_memspace2, shape_memspace1, ShapeUtil::MakeShape(U32, {})}),
HloOpcode::kCopyStart, p0.get());
p0.get());
auto copy_done = HloInstruction::CreateUnary(
shape_memspace2, HloOpcode::kCopyDone, copy_start.get());

View File

@ -883,7 +883,6 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
case HloOpcode::kClz:
case HloOpcode::kCollectivePermuteDone:
case HloOpcode::kCopy:
case HloOpcode::kCopyStart:
case HloOpcode::kCopyDone:
case HloOpcode::kCos:
case HloOpcode::kExp:
@ -1091,6 +1090,20 @@ bool HloParserImpl::ParseInstructionRhs(HloComputation::Builder* builder,
}
break;
}
case HloOpcode::kCopyStart: {
// If the is_cross_program_prefetch attribute is not present then default
// to false.
optional<bool> is_cross_program_prefetch = false;
attrs["is_cross_program_prefetch"] = {/*required=*/false, AttrTy::kBool,
&is_cross_program_prefetch};
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateCopyStart(
shape, operands[0], *is_cross_program_prefetch));
break;
}
case HloOpcode::kReplicaId: {
if (!ParseOperands(&operands, /*expected_size=*/0) ||
!ParseAttributes(attrs)) {

View File

@ -318,7 +318,7 @@ R"(HloModule CopyStartAndCopyDone_module
ENTRY %CopyStartAndCopyDone (v1: f32[], v2: f32[2,3]) -> (f32[], f32[2,3]) {
%v1 = f32[] parameter(0)
%copy-start.1 = (f32[], f32[], u32[]) copy-start(f32[] %v1)
%copy-start.1 = (f32[], f32[], u32[]) copy-start(f32[] %v1), is_cross_program_prefetch=true
%copy-done.1 = f32[] copy-done((f32[], f32[], u32[]) %copy-start.1)
%v2 = f32[2,3]{1,0:S(1)} parameter(1)
%copy-start.2 = (f32[2,3]{1,0:S(2)}, f32[2,3]{1,0:S(1)}, u32[]) copy-start(f32[2,3]{1,0:S(1)} %v2)

View File

@ -1409,7 +1409,8 @@ void AlternateMemoryBestFitHeap::AllocateCrossProgramPrefetchBuffer(
AddAsyncCopy(*allocations.back(), MemorySpace::kAlternate,
chunk_candidate.chunk, prefetch_candidate->start,
prefetch_candidate->end, latest_prefetch_time, &allocations);
prefetch_candidate->end, latest_prefetch_time, &allocations,
/*is_cross_program_prefetch=*/true);
absl::c_for_each(uses, [&](auto& use) { allocations.back()->AddUse(use); });
for (auto& allocation : allocations) {
allocations_->push_back(std::move(allocation));
@ -1887,7 +1888,8 @@ void AlternateMemoryBestFitHeap::AddAsyncCopy(
const MemorySpaceAssignment::Allocation& prev_allocation,
MemorySpace memory_space, absl::optional<Chunk> chunk, int64 start_time,
int64 end_time, int64 copy_done_schedule_before_time,
MemorySpaceAssignment::AllocationSequence* allocations) {
MemorySpaceAssignment::AllocationSequence* allocations,
bool is_cross_program_prefetch) {
VLOG(3) << "Copy to "
<< (memory_space == MemorySpaceAssignment::MemorySpace::kDefault
? "default"
@ -1899,7 +1901,7 @@ void AlternateMemoryBestFitHeap::AddAsyncCopy(
allocations->push_back(
absl::make_unique<MemorySpaceAssignment::CopyAllocation>(
prev_allocation, memory_space, chunk, start_time, end_time,
copy_done_schedule_before_time));
copy_done_schedule_before_time, is_cross_program_prefetch));
// Register the additional async copy with the interval tree to keep track of
// the limit at any given time.
@ -2713,9 +2715,9 @@ Status MemorySpaceAssignment::CopyAllocation::Process(
Shape shape = defining_position().shape();
HloInstruction* producing_instruction = AddGetTupleElements();
HloComputation* computation = producing_instruction->parent();
copy_start_ = computation->AddInstruction(HloInstruction::CreateUnary(
copy_start_ = computation->AddInstruction(HloInstruction::CreateCopyStart(
ShapeUtil::MakeTupleShape({shape, shape, ShapeUtil::MakeShape(U32, {})}),
HloOpcode::kCopyStart, producing_instruction));
producing_instruction, is_cross_program_prefetch_));
copy_done_ = computation->AddInstruction(
HloInstruction::CreateUnary(shape, HloOpcode::kCopyDone, copy_start_));
VLOG(4) << "Created " << copy_start_->name()

View File

@ -581,12 +581,14 @@ class MemorySpaceAssignment {
public:
CopyAllocation(const Allocation& prev_allocation, MemorySpace memory_space,
absl::optional<Chunk> chunk, int64 start_time,
int64 end_time, int64 copy_done_schedule_before_time)
int64 end_time, int64 copy_done_schedule_before_time,
bool is_cross_program_prefetch = false)
: Allocation(/*defining_position=*/{nullptr, {}}, memory_space, chunk,
start_time, end_time),
prev_allocation_(prev_allocation),
copy_start_schedule_after_(start_time),
copy_done_schedule_before_(copy_done_schedule_before_time) {}
copy_done_schedule_before_(copy_done_schedule_before_time),
is_cross_program_prefetch_(is_cross_program_prefetch) {}
bool is_copy_allocation() const override { return true; }
@ -626,6 +628,10 @@ class MemorySpaceAssignment {
copy_start_schedule_after_ = copy_start_schedule_after;
}
// Reports the cross-program-prefetch flag stored on this allocation.
bool is_cross_program_prefetch() const { return is_cross_program_prefetch_; }
bool operator==(const CopyAllocation& other) const;
std::string ToString() const override;
@ -637,6 +643,7 @@ class MemorySpaceAssignment {
// is before copy_done_schedule_before_.
int64 copy_start_schedule_after_;
int64 copy_done_schedule_before_;
bool is_cross_program_prefetch_;
HloInstruction* copy_start_;
HloInstruction* copy_done_;
};
@ -1208,7 +1215,8 @@ class AlternateMemoryBestFitHeap
MemorySpace memory_space, absl::optional<Chunk> chunk,
int64 start_time, int64 end_time,
int64 copy_done_schedule_before_time,
MemorySpaceAssignment::AllocationSequence* allocations);
MemorySpaceAssignment::AllocationSequence* allocations,
bool is_cross_program_prefetch = false);
// This method is used for committing the chunk candidate but adding it to
// pending_chunks_ so that we can "uncommit" them in case we need to roll back

View File

@ -333,10 +333,10 @@ TEST_F(TuplePointsToAnalysisTest, CopyStartAndCopyDone) {
auto builder = HloComputation::Builder(TestName());
auto constant = builder.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
auto copy_start = builder.AddInstruction(HloInstruction::CreateUnary(
auto copy_start = builder.AddInstruction(HloInstruction::CreateCopyStart(
ShapeUtil::MakeTupleShape({constant->shape(), constant->shape(),
ShapeUtil::MakeShape(U32, {})}),
HloOpcode::kCopyStart, constant));
constant));
auto copy_done = builder.AddInstruction(HloInstruction::CreateUnary(
constant->shape(), HloOpcode::kCopyDone, copy_start));