Compressing rematerialization
Adds a new kind of rematerialization that compresses the node into a compact form, uncompresses it back at a later program point. PiperOrigin-RevId: 264734096
This commit is contained in:
parent
e9a9801734
commit
02f3c946f4
@ -100,6 +100,17 @@ bool CanBeRematerialized(
|
||||
using BufferId = int64;
|
||||
using BufferIdList = absl::InlinedVector<BufferId, 3>;
|
||||
|
||||
struct RematStrategy {
|
||||
enum {
|
||||
// Recompute the node at a later program point.
|
||||
kRecompute,
|
||||
// Change the layout into a compact form and uncompress it back at a later
|
||||
// program point.
|
||||
kCompress,
|
||||
} kind;
|
||||
Shape compact_shape;
|
||||
};
|
||||
|
||||
// We wrap HloInstruction* with an Item that holds auxiliary
|
||||
// per-instruction state.
|
||||
struct Item {
|
||||
@ -117,6 +128,10 @@ struct Item {
|
||||
// The buffers defined by this instruction.
|
||||
BufferIdList buffers_defined;
|
||||
|
||||
// Output buffers of this instruction. This is used to track outputs by GTE
|
||||
// instructions (where the instruction doesn't define a buffer).
|
||||
BufferIdList buffers_output;
|
||||
|
||||
// The buffers used by this instruction.
|
||||
BufferIdList buffers_used;
|
||||
|
||||
@ -251,6 +266,34 @@ class InstructionList {
|
||||
return InsertBefore(to_insert, min_position_item);
|
||||
}
|
||||
|
||||
void InsertAfterInstructions(Item* to_insert,
|
||||
absl::Span<Item* const> after_instructions) {
|
||||
VLOG(3) << "InsertAfterInstructions: " << to_insert->instruction->name()
|
||||
<< " after {"
|
||||
<< absl::StrJoin(after_instructions, ", ",
|
||||
[](string* out, Item* item) {
|
||||
absl::StrAppend(out, item->instruction->name());
|
||||
})
|
||||
<< "}";
|
||||
|
||||
// Find the max position number of any instruction in
|
||||
// 'after_instructions'.
|
||||
CHECK(!after_instructions.empty());
|
||||
Item* max_position_item = nullptr;
|
||||
for (Item* item : after_instructions) {
|
||||
if (max_position_item == nullptr ||
|
||||
item->position > max_position_item->position) {
|
||||
max_position_item = item;
|
||||
}
|
||||
}
|
||||
if (max_position_item->next == nullptr) {
|
||||
InsertAfter(to_insert, max_position_item);
|
||||
|
||||
} else {
|
||||
InsertBeforeInstructions(to_insert, {max_position_item->next});
|
||||
}
|
||||
}
|
||||
|
||||
void Blacklist(const HloInstruction* inst) {
|
||||
GetItem(inst)->blacklisted = true;
|
||||
}
|
||||
@ -276,6 +319,24 @@ class InstructionList {
|
||||
item->position = before->position;
|
||||
}
|
||||
|
||||
void InsertAfter(Item* item, Item* after) {
|
||||
VLOG(3) << "InsertAfter: " << item->instruction->name() << " after "
|
||||
<< after->instruction->name();
|
||||
// Insert new item into linked list.
|
||||
item->next = after->next;
|
||||
item->prev = after;
|
||||
|
||||
after->next = item;
|
||||
if (item->next != nullptr) {
|
||||
item->next->prev = item;
|
||||
}
|
||||
|
||||
// Assign the same position number to the newly added instruction as
|
||||
// 'before'. This guarantees monotonicity of the position numbers, but not
|
||||
// uniqueness.
|
||||
item->position = after->position;
|
||||
}
|
||||
|
||||
Item* first_;
|
||||
|
||||
// Item for each instruction.
|
||||
@ -327,6 +388,7 @@ class MemoryUsageTracker {
|
||||
MemoryUsageTracker(
|
||||
const HloComputation* computation,
|
||||
const HloRematerialization::ShapeSizeFunction& size_function,
|
||||
const HloRematerialization::CompactShapeFunction& compact_shape_function,
|
||||
const TuplePointsToAnalysis& points_to_analysis,
|
||||
const InstructionList& instruction_list);
|
||||
|
||||
@ -338,6 +400,22 @@ class MemoryUsageTracker {
|
||||
// EndInstruction memory for dead operand(s) is freed.
|
||||
Status BeginInstruction(Item* item);
|
||||
|
||||
int64 RematerializationCost(const HloInstruction* instruction,
|
||||
int64 memory_reduced, int64 memory_limit_bytes) {
|
||||
// If none of the users of 'instruction' have been placed in the sequence
|
||||
// (as tracked by memory_tracker), then rematerialization of 'instruction'
|
||||
// is a zero-cost move of 'instruction' in the sequence.
|
||||
if (!absl::c_any_of(
|
||||
instruction->users(),
|
||||
[this](const HloInstruction* inst) { return IsPlaced(inst); })) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
CHECK_GT(memory_reduced, 0);
|
||||
// Return the inverse of the benefit of rematerialization.
|
||||
return memory_limit_bytes / memory_reduced;
|
||||
}
|
||||
|
||||
// Finishes the placement of the current instruction. This frees any dead
|
||||
// operands or dead result of the instruction. This must be called after
|
||||
// each call to BeginInstruction.
|
||||
@ -347,17 +425,28 @@ class MemoryUsageTracker {
|
||||
// if the given instruction is rematerialized.
|
||||
int64 MemoryReducedIfRematerialized(Item* item) const;
|
||||
|
||||
// Returns the number of bytes that the current memory usage will be reduced
|
||||
// if the given instruction is compact.
|
||||
int64 MemoryReducedIfCompressed(Item* item, const Shape& compact_shape) const;
|
||||
|
||||
// Returns the number of bytes that the current memory usage will be reduced
|
||||
// by if the given sequence of instructions is rematerialized.
|
||||
int64 MemoryReducedIfRematerialized(const absl::Span<Item*>& items) const;
|
||||
|
||||
Status AddCompressInstructions(Item* original_item, Item* compressed_item,
|
||||
Item* uncompressed_item);
|
||||
|
||||
// Adjusts memory usage to account for the rematerialization of
|
||||
// original_item for all remaining unplaced uses. The rematerialization
|
||||
// is remat_item. This method should be called after the HLO graph has
|
||||
// been transformed (rematerialization instruction created and connected to
|
||||
// uses).
|
||||
// been transformed (rematerialization instruction created and connected
|
||||
// to uses).
|
||||
Status AddRematerializedInstruction(Item* original_item, Item* remat_item);
|
||||
|
||||
std::pair<Item*, RematStrategy> PickRematerializationCandidate(
|
||||
const InstructionList& instruction_list, int64 memory_limit_bytes,
|
||||
absl::flat_hash_map<const HloInstruction*, bool>* remat_able);
|
||||
|
||||
// Returns whether the given instruction has been placed (BeginInstruction
|
||||
// has been called with 'instruction' as the argument).
|
||||
bool IsPlaced(const HloInstruction* instruction) const {
|
||||
@ -390,6 +479,9 @@ class MemoryUsageTracker {
|
||||
// The materialized size of the buffer in bytes.
|
||||
const int64 size;
|
||||
|
||||
// Shape of the buffer.
|
||||
Shape shape;
|
||||
|
||||
// Whether this buffer is live-out of the computation.
|
||||
bool live_out;
|
||||
|
||||
@ -412,19 +504,21 @@ class MemoryUsageTracker {
|
||||
}
|
||||
};
|
||||
|
||||
// Get the compact shape of given hlo instruction. An internal cache is used
|
||||
// to avoid computing the shape multiple times.
|
||||
StatusOr<Shape> GetCompactShape(const HloInstruction* hlo);
|
||||
|
||||
// Creates a Buffer representing the given logical buffer. The buffer is added
|
||||
// to buffers_ and a reference is returned.
|
||||
Buffer& CreateBufferFromLogicalBuffer(
|
||||
const LogicalBuffer* logical_buffer,
|
||||
const TuplePointsToAnalysis& points_to_analysis,
|
||||
const HloRematerialization::ShapeSizeFunction& size_function,
|
||||
bool live_out) {
|
||||
const TuplePointsToAnalysis& points_to_analysis, bool live_out) {
|
||||
bool has_indirect_uses = false;
|
||||
ItemList users = GetUsers(instruction_list_, logical_buffer,
|
||||
points_to_analysis, &has_indirect_uses);
|
||||
return NewBuffer(instruction_list_.GetItem(logical_buffer->instruction()),
|
||||
size_function(logical_buffer->shape()), std::move(users),
|
||||
live_out, has_indirect_uses);
|
||||
logical_buffer->shape(), std::move(users), live_out,
|
||||
has_indirect_uses);
|
||||
}
|
||||
|
||||
// Create a new buffer representing a rematerialization of given buffer for
|
||||
@ -438,7 +532,7 @@ class MemoryUsageTracker {
|
||||
for (Item* use : rematerialized_uses) {
|
||||
CHECK(!use->placed) << use->instruction->name();
|
||||
}
|
||||
return NewBuffer(remat_item, original_buffer.size,
|
||||
return NewBuffer(remat_item, original_buffer.shape,
|
||||
std::move(rematerialized_uses), /*live_out=*/false,
|
||||
/*has_indirect_uses=*/false);
|
||||
}
|
||||
@ -449,7 +543,8 @@ class MemoryUsageTracker {
|
||||
// different computation.
|
||||
int64 AllocatedSize(BufferId buffer_id) const {
|
||||
const Buffer& buffer = buffers_.at(buffer_id);
|
||||
HloOpcode def_opcode = buffer.defining_instruction->instruction->opcode();
|
||||
HloInstruction* inst = buffer.defining_instruction->instruction;
|
||||
HloOpcode def_opcode = inst->opcode();
|
||||
if (buffer.live_out || def_opcode == HloOpcode::kParameter) {
|
||||
return 0;
|
||||
} else {
|
||||
@ -482,12 +577,12 @@ class MemoryUsageTracker {
|
||||
}
|
||||
|
||||
// Create a new buffer, add it to buffers_, and return a reference.
|
||||
Buffer& NewBuffer(Item* defining_instruction, int64 size, ItemList&& users,
|
||||
bool live_out, bool has_indirect_uses) {
|
||||
Buffer& NewBuffer(Item* defining_instruction, const Shape& shape,
|
||||
ItemList&& users, bool live_out, bool has_indirect_uses) {
|
||||
int buffer_id = buffers_.size();
|
||||
buffers_.push_back(Buffer{buffer_id, defining_instruction, size, live_out,
|
||||
has_indirect_uses, users,
|
||||
static_cast<int64>(users.size())});
|
||||
buffers_.push_back(Buffer{
|
||||
buffer_id, defining_instruction, size_function_(shape), shape, live_out,
|
||||
has_indirect_uses, users, static_cast<int64>(users.size())});
|
||||
return buffers_.back();
|
||||
}
|
||||
|
||||
@ -498,6 +593,16 @@ class MemoryUsageTracker {
|
||||
// (BeginInstruction/EndInstruction calls).
|
||||
const InstructionList& instruction_list_;
|
||||
|
||||
// Size function returns the bytes of a given buffer.
|
||||
const HloRematerialization::ShapeSizeFunction& size_function_;
|
||||
|
||||
// Converts a shape into compact form, returns the same shape if a shape is
|
||||
// already considered compact.
|
||||
const HloRematerialization::CompactShapeFunction& compact_shape_function_;
|
||||
|
||||
// A map that caches existing known compact shape for each instruction.
|
||||
absl::flat_hash_map<const HloInstruction*, Shape> compact_shape_;
|
||||
|
||||
// Memory usage at the currently placed instruction.
|
||||
int64 memory_usage_ = 0;
|
||||
|
||||
@ -512,9 +617,13 @@ class MemoryUsageTracker {
|
||||
MemoryUsageTracker::MemoryUsageTracker(
|
||||
const HloComputation* computation,
|
||||
const HloRematerialization::ShapeSizeFunction& size_function,
|
||||
const HloRematerialization::CompactShapeFunction& compact_shape_function,
|
||||
const TuplePointsToAnalysis& points_to_analysis,
|
||||
const InstructionList& instruction_list)
|
||||
: computation_(computation), instruction_list_(instruction_list) {
|
||||
: computation_(computation),
|
||||
instruction_list_(instruction_list),
|
||||
size_function_(size_function),
|
||||
compact_shape_function_(compact_shape_function) {
|
||||
PointsToSet::BufferSet live_out_set =
|
||||
points_to_analysis.GetPointsToSet(computation_->root_instruction())
|
||||
.CreateFlattenedSet();
|
||||
@ -556,7 +665,7 @@ MemoryUsageTracker::MemoryUsageTracker(
|
||||
}
|
||||
} else {
|
||||
buffer = &CreateBufferFromLogicalBuffer(
|
||||
logical_buffer, points_to_analysis, size_function,
|
||||
logical_buffer, points_to_analysis,
|
||||
ContainsKey(live_out_set, logical_buffer));
|
||||
item->buffers_defined.push_back(buffer->id);
|
||||
for (Item* user : buffer->users) {
|
||||
@ -566,6 +675,14 @@ MemoryUsageTracker::MemoryUsageTracker(
|
||||
|
||||
logical_buffer_to_buffer_id[logical_buffer] = buffer->id;
|
||||
}
|
||||
|
||||
// Trace the output of each instruction. This is so that we can properly
|
||||
// track which outputs does GTEs have.
|
||||
for (const LogicalBuffer* logical_buffer :
|
||||
points_to_analysis.GetPointsToSet(instruction).CreateFlattenedSet()) {
|
||||
item->buffers_output.push_back(
|
||||
logical_buffer_to_buffer_id[logical_buffer]);
|
||||
}
|
||||
}
|
||||
XLA_VLOG_LINES(10, ToString());
|
||||
DCHECK(Check());
|
||||
@ -637,6 +754,29 @@ Status MemoryUsageTracker::EndInstruction() {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
int64 MemoryUsageTracker::MemoryReducedIfCompressed(
|
||||
Item* item, const Shape& compact_shape) const {
|
||||
CHECK_NE(in_progress_item_, nullptr);
|
||||
if (!item->placed || item == in_progress_item_) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
int64 memory_reduced = 0;
|
||||
|
||||
// We only compress a single piece of an output at one time.
|
||||
CHECK_EQ(item->buffers_output.size(), 1);
|
||||
BufferId buffer_id = item->buffers_output[0];
|
||||
if (IsCurrentlyLive(buffer_id) && !IsInUse(buffer_id)) {
|
||||
const Buffer& buffer = buffers_.at(buffer_id);
|
||||
memory_reduced += buffer.size;
|
||||
|
||||
int64 compact_shape_size = size_function_(compact_shape);
|
||||
// Account for buffers that are compress after instruction.
|
||||
memory_reduced -= compact_shape_size;
|
||||
}
|
||||
return memory_reduced;
|
||||
}
|
||||
|
||||
int64 MemoryUsageTracker::MemoryReducedIfRematerialized(Item* item) const {
|
||||
CHECK_NE(in_progress_item_, nullptr);
|
||||
if (!item->placed || item == in_progress_item_) {
|
||||
@ -736,6 +876,56 @@ int64 MemoryUsageTracker::MemoryReducedIfRematerialized(
|
||||
return memory_reduced;
|
||||
}
|
||||
|
||||
Status MemoryUsageTracker::AddCompressInstructions(Item* original_item,
|
||||
Item* compressed_item,
|
||||
Item* uncompressed_item) {
|
||||
// Original buffer is now dead.
|
||||
memory_usage_ -= size_function_(original_item->instruction->shape());
|
||||
// Compressed buffer is now alive.
|
||||
memory_usage_ += size_function_(compressed_item->instruction->shape());
|
||||
|
||||
ItemList placed_users;
|
||||
ItemList unplaced_users;
|
||||
CHECK_EQ(original_item->buffers_output.size(), 1);
|
||||
BufferId original_buffer_id = original_item->buffers_output[0];
|
||||
Buffer& original_buffer = buffers_.at(original_buffer_id);
|
||||
for (Item* user : original_buffer.users) {
|
||||
if (user->placed) {
|
||||
CHECK(IsFinished(user)) << user->instruction->name();
|
||||
placed_users.push_back(user);
|
||||
} else {
|
||||
unplaced_users.push_back(user);
|
||||
}
|
||||
}
|
||||
original_buffer.users = std::move(placed_users);
|
||||
original_buffer.unfinished_user_count = 0;
|
||||
original_buffer.users.push_back(compressed_item);
|
||||
Buffer& compressed_buffer =
|
||||
NewBuffer(compressed_item, compressed_item->instruction->shape(),
|
||||
{uncompressed_item}, /*live_out=*/false,
|
||||
/*has_indirect_uses=*/false);
|
||||
compressed_item->buffers_used = original_item->buffers_output;
|
||||
compressed_item->buffers_output = {compressed_buffer.id};
|
||||
compressed_item->buffers_defined.push_back(compressed_buffer.id);
|
||||
|
||||
Buffer& uncompressed_buffer =
|
||||
NewBuffer(uncompressed_item, uncompressed_item->instruction->shape(),
|
||||
std::move(unplaced_users), /*live_out=*/false,
|
||||
/*has_indirect_uses=*/false);
|
||||
|
||||
uncompressed_item->buffers_used = {compressed_item->buffers_output[0]};
|
||||
uncompressed_item->buffers_output = {uncompressed_buffer.id};
|
||||
uncompressed_item->buffers_defined = {uncompressed_buffer.id};
|
||||
|
||||
for (Item* user : uncompressed_buffer.users) {
|
||||
BufferIdList& buffers_used = user->buffers_used;
|
||||
std::replace(buffers_used.begin(), buffers_used.end(), original_buffer_id,
|
||||
uncompressed_buffer.id);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status MemoryUsageTracker::AddRematerializedInstruction(Item* original_item,
|
||||
Item* remat_item) {
|
||||
VLOG(3) << "AddRematerializedInstruction: original_instruction = "
|
||||
@ -831,6 +1021,17 @@ string MemoryUsageTracker::ToString() const {
|
||||
return output;
|
||||
}
|
||||
|
||||
StatusOr<Shape> MemoryUsageTracker::GetCompactShape(const HloInstruction* hlo) {
|
||||
auto it = compact_shape_.find(hlo);
|
||||
if (it != compact_shape_.end()) {
|
||||
return it->second;
|
||||
}
|
||||
const Shape& original_shape = hlo->shape();
|
||||
TF_ASSIGN_OR_RETURN(Shape min_shape, compact_shape_function_(original_shape));
|
||||
compact_shape_[hlo] = min_shape;
|
||||
return min_shape;
|
||||
}
|
||||
|
||||
bool MemoryUsageTracker::Check() const {
|
||||
auto elements_are_unique = [](const BufferIdList& vec) {
|
||||
return vec.size() == std::set<BufferId>(vec.begin(), vec.end()).size();
|
||||
@ -917,12 +1118,15 @@ int64 RematerializationCost(const HloInstruction* instruction,
|
||||
// candidate which reduce memory use at the program point of the current
|
||||
// instruction as indicated by memory_tracker. nullptr is returned if no
|
||||
// candidate can be found.
|
||||
Item* PickRematerializationCandidate(
|
||||
const MemoryUsageTracker& memory_tracker,
|
||||
std::pair<Item*, RematStrategy>
|
||||
MemoryUsageTracker::PickRematerializationCandidate(
|
||||
const InstructionList& instruction_list, int64 memory_limit_bytes,
|
||||
absl::flat_hash_map<const HloInstruction*, bool>* remat_able) {
|
||||
Item* best_item = nullptr;
|
||||
int64 best_cost = 0;
|
||||
RematStrategy best_strategy;
|
||||
|
||||
VLOG(5) << "Picking candidate";
|
||||
|
||||
// TODO(b/35244891): This is currently quadratic in the number of HLO
|
||||
// instructions.
|
||||
@ -947,44 +1151,215 @@ Item* PickRematerializationCandidate(
|
||||
if (!CanBeRematerialized(candidate, remat_able)) {
|
||||
VLOG(5) << "candidate " << candidate->name()
|
||||
<< " not viable: is not rematerializable";
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
// If any of the candidate's control successor has been placed, we need to
|
||||
// skip this candidate. Otherwise we will violate control dependency.
|
||||
bool control_successor_placed =
|
||||
std::any_of(candidate->control_successors().begin(),
|
||||
candidate->control_successors().end(),
|
||||
[&memory_tracker](const HloInstruction* inst) {
|
||||
return memory_tracker.IsPlaced(inst);
|
||||
});
|
||||
if (item->buffers_output.size() == 1) {
|
||||
// Only consider compressing single output instruction.
|
||||
const Buffer& output_buffer = buffers_.at(item->buffers_output[0]);
|
||||
|
||||
if (item->placed && item != in_progress_item_ &&
|
||||
!output_buffer.live_out) {
|
||||
const Shape& original_shape = item->instruction->shape();
|
||||
if (original_shape.IsArray()) {
|
||||
Shape compact_shape = GetCompactShape(item->instruction).ValueOrDie();
|
||||
const int64 memory_reduced =
|
||||
MemoryReducedIfCompressed(item, compact_shape);
|
||||
if (memory_reduced > 0) {
|
||||
const int64 cost = memory_limit_bytes / memory_reduced;
|
||||
if (best_item == nullptr || cost < best_cost) {
|
||||
VLOG(3) << "candidate " << candidate->name() << "("
|
||||
<< candidate->ToShortString() << ")"
|
||||
<< " now best when compressed into "
|
||||
<< compact_shape.ToString(true);
|
||||
RematStrategy strategy;
|
||||
strategy.kind = RematStrategy::kCompress;
|
||||
best_strategy = strategy;
|
||||
best_strategy.compact_shape = compact_shape;
|
||||
best_item = item;
|
||||
best_cost = cost;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If any of the candidate's control successor has been placed, we need
|
||||
// to skip this candidate. Otherwise we will violate control dependency.
|
||||
bool control_successor_placed = std::any_of(
|
||||
candidate->control_successors().begin(),
|
||||
candidate->control_successors().end(),
|
||||
[this](const HloInstruction* inst) { return IsPlaced(inst); });
|
||||
|
||||
if (control_successor_placed) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const int64 memory_reduced =
|
||||
memory_tracker.MemoryReducedIfRematerialized(item);
|
||||
const int64 memory_reduced = MemoryReducedIfRematerialized(item);
|
||||
|
||||
if (memory_reduced <= 0) {
|
||||
VLOG(5) << "candidate " << candidate->name()
|
||||
<< " memory reduced = " << memory_reduced << " <= 0";
|
||||
continue;
|
||||
}
|
||||
if (memory_reduced > 0) {
|
||||
const int cost =
|
||||
RematerializationCost(candidate, memory_reduced, memory_limit_bytes);
|
||||
|
||||
const int cost = RematerializationCost(candidate, memory_tracker,
|
||||
memory_reduced, memory_limit_bytes);
|
||||
VLOG(5) << "candidate " << candidate->name() << ", memory reduced "
|
||||
<< memory_reduced << ", cost per byte " << cost;
|
||||
|
||||
VLOG(5) << "candidate " << candidate->name() << ", memory reduced "
|
||||
<< memory_reduced << ", cost per byte " << cost;
|
||||
|
||||
if (best_item == nullptr || cost < best_cost) {
|
||||
VLOG(5) << "candidate " << candidate->name() << " now best";
|
||||
best_item = item;
|
||||
best_cost = cost;
|
||||
if (best_item == nullptr || cost < best_cost) {
|
||||
VLOG(5) << "candidate " << candidate->name() << " now best";
|
||||
best_strategy.kind = RematStrategy::kRecompute;
|
||||
best_item = item;
|
||||
best_cost = cost;
|
||||
}
|
||||
}
|
||||
}
|
||||
return best_item;
|
||||
return {best_item, best_strategy};
|
||||
}
|
||||
|
||||
StatusOr<int64> RematerializeInstruction(
|
||||
MemoryUsageTracker* memory_tracker, Item* best_item,
|
||||
absl::flat_hash_set<const HloInstruction*>* remat_move_instructions,
|
||||
InstructionList* instruction_list) {
|
||||
HloInstruction* best = best_item->instruction;
|
||||
VLOG(1) << "Rematerializing instruction " << best->name() << " (saving "
|
||||
<< HumanReadableNumBytes(
|
||||
memory_tracker->MemoryReducedIfRematerialized(best_item))
|
||||
<< ")";
|
||||
|
||||
int64 net_instructions_added = 0;
|
||||
|
||||
HloComputation* computation = best->parent();
|
||||
|
||||
HloInstruction* remat =
|
||||
computation->AddInstruction(best->Clone(/*suffix=*/"remat"));
|
||||
|
||||
// Add control dependencies to the new operation.
|
||||
for (auto successor : best->control_successors()) {
|
||||
TF_RETURN_IF_ERROR(remat->AddControlDependencyTo(successor));
|
||||
}
|
||||
for (auto predecessor : best->control_predecessors()) {
|
||||
TF_RETURN_IF_ERROR(predecessor->AddControlDependencyTo(remat));
|
||||
}
|
||||
|
||||
Item* remat_item = instruction_list->CreateItem(remat);
|
||||
|
||||
// Replace each remaining use of 'best' with the rematerialization.
|
||||
std::vector<HloInstruction*> best_users_copy = best->users();
|
||||
for (HloInstruction* user : best_users_copy) {
|
||||
if (!memory_tracker->IsPlaced(user)) {
|
||||
VLOG(2) << " Replacing use of " << best->name() << " in " << user->name()
|
||||
<< " with " << remat->name();
|
||||
TF_RETURN_IF_ERROR(best->ReplaceUseWith(user, remat));
|
||||
}
|
||||
}
|
||||
|
||||
// Account for the rematerialization in the memory tracker.
|
||||
TF_RETURN_IF_ERROR(
|
||||
memory_tracker->AddRematerializedInstruction(best_item, remat_item));
|
||||
|
||||
// Insert rematerialized instruction right before the earliest unplaced
|
||||
// use of the instruction *and* the earliest unplaced last use of any
|
||||
// operands of remat. Unplaced uses of the remat's operands are included
|
||||
// because we don't want to extend the live range of remat's operands as
|
||||
// this could increase memory usage.
|
||||
ItemList place_before;
|
||||
for (auto user : remat->users()) {
|
||||
place_before.push_back(instruction_list->GetItem(user));
|
||||
}
|
||||
for (auto* operand : remat->operands()) {
|
||||
for (auto* operand_user : operand->users()) {
|
||||
if (operand_user != remat) {
|
||||
Item* operand_user_item = instruction_list->GetItem(operand_user);
|
||||
if (!operand_user_item->placed) {
|
||||
place_before.push_back(operand_user_item);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Insert rematerialized instruction before any of its successors to
|
||||
// preserve ordering regarding control dependency.
|
||||
for (auto successor : remat->control_successors()) {
|
||||
Item* successor_item = instruction_list->GetItem(successor);
|
||||
// Assert to make sure we never remat an operation with control
|
||||
// successor already placed.
|
||||
CHECK(!successor_item->placed) << successor_item->instruction->name();
|
||||
place_before.push_back(successor_item);
|
||||
}
|
||||
instruction_list->InsertBeforeInstructions(remat_item, place_before);
|
||||
|
||||
// If the rematerialized instruction is dead then rematerialization is
|
||||
// essentially a move. Don't delete the instruction now because we don't
|
||||
// want duplicate HloInstruction* values during the course of the
|
||||
// transformation because we keep maps with HloInstruction* values as
|
||||
// keys.
|
||||
if (best->users().empty()) {
|
||||
VLOG(2) << best->name() << " is now dead";
|
||||
if (ContainsKey(*remat_move_instructions, best)) {
|
||||
// Previously, 'best' was a rematerialization which killed the
|
||||
// instruction it was a copying of. Now 'remat' is a rematerialization
|
||||
// of 'best' and kills 'best'. Stop rematerializing this instruction
|
||||
// to avoid an infinite loop.
|
||||
instruction_list->Blacklist(remat);
|
||||
}
|
||||
remat_move_instructions->insert(remat);
|
||||
|
||||
} else {
|
||||
net_instructions_added++;
|
||||
}
|
||||
return net_instructions_added;
|
||||
}
|
||||
|
||||
StatusOr<int64> CompressInstruction(MemoryUsageTracker* memory_tracker,
|
||||
Item* best_item, const Shape& compact_shape,
|
||||
InstructionList* instruction_list) {
|
||||
HloInstruction* best = best_item->instruction;
|
||||
VLOG(5) << "Transposing instruction " << best->name() << " (saving "
|
||||
<< HumanReadableNumBytes(memory_tracker->MemoryReducedIfCompressed(
|
||||
best_item, compact_shape))
|
||||
<< ") to" << compact_shape.ToString(true);
|
||||
|
||||
HloComputation* computation = best->parent();
|
||||
|
||||
HloInstruction* compressed = computation->AddInstruction(
|
||||
HloInstruction::CreateUnary(compact_shape, HloOpcode::kCopy, best));
|
||||
|
||||
HloInstruction* uncompressed = computation->AddInstruction(
|
||||
HloInstruction::CreateUnary(best->shape(), HloOpcode::kCopy, compressed));
|
||||
|
||||
Item* compressed_item = instruction_list->CreateItem(compressed);
|
||||
compressed_item->placed = true;
|
||||
|
||||
Item* uncompressed_item = instruction_list->CreateItem(uncompressed);
|
||||
|
||||
// Replace each remaining use of 'best' with the uncompressed.
|
||||
std::vector<HloInstruction*> best_users_copy = best->users();
|
||||
for (HloInstruction* user : best_users_copy) {
|
||||
if (!memory_tracker->IsPlaced(user)) {
|
||||
VLOG(5) << " Replacing use of " << best->name() << " in " << user->name()
|
||||
<< " with " << uncompressed->name();
|
||||
TF_RETURN_IF_ERROR(best->ReplaceUseWith(user, uncompressed));
|
||||
}
|
||||
}
|
||||
|
||||
// Account for the rematerialization in the memory tracker.
|
||||
TF_RETURN_IF_ERROR(memory_tracker->AddCompressInstructions(
|
||||
best_item, compressed_item, uncompressed_item));
|
||||
|
||||
// Insert rematerialized instruction right before the earliest unplaced
|
||||
// use of the instruction *and* the earliest unplaced last use of any
|
||||
// operands of remat. Unplaced uses of the remat's operands are included
|
||||
// because we don't want to extend the live range of remat's operands as
|
||||
// this could increase memory usage.
|
||||
ItemList place_before;
|
||||
for (auto user : uncompressed->users()) {
|
||||
place_before.push_back(instruction_list->GetItem(user));
|
||||
}
|
||||
|
||||
instruction_list->InsertBeforeInstructions(uncompressed_item, place_before);
|
||||
|
||||
instruction_list->InsertAfterInstructions(compressed_item, {best_item});
|
||||
|
||||
return 2;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@ -993,7 +1368,8 @@ StatusOr<int64> HloRematerialization::ComputePeakMemory(
|
||||
const HloComputation* computation,
|
||||
const HloInstructionSequence& order) const {
|
||||
InstructionList instruction_list(order);
|
||||
MemoryUsageTracker tracker(computation, size_function_, *points_to_analysis_,
|
||||
MemoryUsageTracker tracker(computation, size_function_,
|
||||
compact_shape_function_, *points_to_analysis_,
|
||||
instruction_list);
|
||||
int64 peak_memory = tracker.memory_usage();
|
||||
for (auto* item = instruction_list.first(); item != nullptr;
|
||||
@ -1037,6 +1413,7 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
|
||||
|
||||
InstructionList instruction_list(schedule->sequence(computation));
|
||||
MemoryUsageTracker memory_tracker(computation, size_function_,
|
||||
compact_shape_function_,
|
||||
*points_to_analysis_, instruction_list);
|
||||
bool changed = false;
|
||||
|
||||
@ -1086,8 +1463,11 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
|
||||
callee_usage)
|
||||
<< ", limit is " << HumanReadableNumBytes(memory_limit_bytes);
|
||||
|
||||
Item* best_item = PickRematerializationCandidate(
|
||||
memory_tracker, instruction_list, memory_limit_bytes, &remat_able);
|
||||
Item* best_item;
|
||||
RematStrategy best_strategy;
|
||||
std::tie(best_item, best_strategy) =
|
||||
memory_tracker.PickRematerializationCandidate(
|
||||
instruction_list, memory_limit_bytes, &remat_able);
|
||||
|
||||
if (best_item == nullptr) {
|
||||
VLOG(3) << "Unable to find rematerialization candidate at program "
|
||||
@ -1106,81 +1486,19 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
|
||||
changed = true;
|
||||
remat_count++;
|
||||
|
||||
HloInstruction* remat =
|
||||
computation->AddInstruction(best->Clone(/*suffix=*/"remat"));
|
||||
|
||||
// Add control dependencies to the new operation.
|
||||
for (auto successor : best->control_successors()) {
|
||||
TF_RETURN_IF_ERROR(remat->AddControlDependencyTo(successor));
|
||||
}
|
||||
for (auto predecessor : best->control_predecessors()) {
|
||||
TF_RETURN_IF_ERROR(predecessor->AddControlDependencyTo(remat));
|
||||
}
|
||||
|
||||
Item* remat_item = instruction_list.CreateItem(remat);
|
||||
|
||||
// Replace each remaining use of 'best' with the rematerialization.
|
||||
std::vector<HloInstruction*> best_users_copy = best->users();
|
||||
for (HloInstruction* user : best_users_copy) {
|
||||
if (!memory_tracker.IsPlaced(user)) {
|
||||
VLOG(2) << " Replacing use of " << best->name() << " in "
|
||||
<< user->name() << " with " << remat->name();
|
||||
TF_RETURN_IF_ERROR(best->ReplaceUseWith(user, remat));
|
||||
}
|
||||
}
|
||||
|
||||
// Account for the rematerialization in the memory tracker.
|
||||
TF_RETURN_IF_ERROR(
|
||||
memory_tracker.AddRematerializedInstruction(best_item, remat_item));
|
||||
|
||||
// Insert rematerialized instruction right before the earliest unplaced
|
||||
// use of the instruction *and* the earliest unplaced last use of any
|
||||
// operands of remat. Unplaced uses of the remat's operands are included
|
||||
// because we don't want to extend the live range of remat's operands as
|
||||
// this could increase memory usage.
|
||||
ItemList place_before;
|
||||
for (auto user : remat->users()) {
|
||||
place_before.push_back(instruction_list.GetItem(user));
|
||||
}
|
||||
for (auto* operand : remat->operands()) {
|
||||
for (auto* operand_user : operand->users()) {
|
||||
if (operand_user != remat) {
|
||||
Item* operand_user_item = instruction_list.GetItem(operand_user);
|
||||
if (!operand_user_item->placed) {
|
||||
place_before.push_back(operand_user_item);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Insert rematerialized instruction before any of its successors to
|
||||
// preserve ordering regarding control dependency.
|
||||
for (auto successor : remat->control_successors()) {
|
||||
Item* successor_item = instruction_list.GetItem(successor);
|
||||
// Assert to make sure we never remat an operation with control
|
||||
// successor already placed.
|
||||
CHECK(!successor_item->placed) << successor_item->instruction->name();
|
||||
place_before.push_back(successor_item);
|
||||
}
|
||||
instruction_list.InsertBeforeInstructions(remat_item, place_before);
|
||||
|
||||
// If the rematerialized instruction is dead then rematerialization is
|
||||
// essentially a move. Don't delete the instruction now because we don't
|
||||
// want duplicate HloInstruction* values during the course of the
|
||||
// transformation because we keep maps with HloInstruction* values as
|
||||
// keys.
|
||||
if (best->users().empty()) {
|
||||
VLOG(2) << best->name() << " is now dead";
|
||||
if (ContainsKey(remat_move_instructions, best)) {
|
||||
// Previously, 'best' was a rematerialization which killed the
|
||||
// instruction it was a copying of. Now 'remat' is a rematerialization
|
||||
// of 'best' and kills 'best'. Stop rematerializing this instruction
|
||||
// to avoid an infinite loop.
|
||||
instruction_list.Blacklist(remat);
|
||||
}
|
||||
remat_move_instructions.insert(remat);
|
||||
int64 added_instruction = 0;
|
||||
if (best_strategy.kind == RematStrategy::kCompress) {
|
||||
TF_ASSIGN_OR_RETURN(added_instruction,
|
||||
CompressInstruction(&memory_tracker, best_item,
|
||||
best_strategy.compact_shape,
|
||||
&instruction_list));
|
||||
} else {
|
||||
net_instructions_added++;
|
||||
TF_ASSIGN_OR_RETURN(added_instruction,
|
||||
RematerializeInstruction(&memory_tracker, best_item,
|
||||
&remat_move_instructions,
|
||||
&instruction_list));
|
||||
}
|
||||
net_instructions_added += added_instruction;
|
||||
|
||||
VLOG(1) << "memory_usage after rematerialization = "
|
||||
<< HumanReadableNumBytes(memory_tracker.memory_usage());
|
||||
@ -1357,7 +1675,7 @@ StatusOr<bool> HloRematerialization::Run(HloModule* module) {
|
||||
sizes_->after_bytes = current_peak_memory;
|
||||
}
|
||||
|
||||
XLA_VLOG_LINES(3, "After HloRematerialization:\n" + module->ToString());
|
||||
XLA_VLOG_LINES(5, "After HloRematerialization:\n" + module->ToString());
|
||||
|
||||
if (current_peak_memory > memory_limit_bytes_) {
|
||||
LOG(WARNING) << absl::StrFormat(
|
||||
|
@ -24,6 +24,8 @@
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
|
||||
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
|
||||
#include "tensorflow/compiler/xla/shape.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
@ -38,6 +40,8 @@ class HloRematerialization : public HloModulePass {
|
||||
public:
|
||||
using ShapeSizeFunction = std::function<int64(const Shape&)>;
|
||||
|
||||
using CompactShapeFunction = std::function<StatusOr<Shape>(const Shape&)>;
|
||||
|
||||
// Helper struct that communicates the before / after sizes for the
|
||||
// rematerialization process.
|
||||
struct RematerializationSizes {
|
||||
@ -45,6 +49,8 @@ class HloRematerialization : public HloModulePass {
|
||||
int64 after_bytes;
|
||||
};
|
||||
|
||||
static Shape DefaultCompactShapeFunction(const Shape& shape) { return shape; }
|
||||
|
||||
// Constructor parameters:
|
||||
//
|
||||
// size_function: Function which returns the size in bytes of the top-level
|
||||
@ -57,12 +63,20 @@ class HloRematerialization : public HloModulePass {
|
||||
// sizes: Pointer to data structure which records the peak memory usage of
|
||||
// the HLO module before/after rematerialization. Value are set during
|
||||
// Run(). Can be nullptr.
|
||||
HloRematerialization(const ShapeSizeFunction& size_function,
|
||||
int64 memory_limit_bytes, RematerializationSizes* sizes)
|
||||
//
|
||||
// compact_shape_function: Function which returns the compact form of a
|
||||
// shape. If nullptr is provided, an default identity function is used.
|
||||
explicit HloRematerialization(
|
||||
const ShapeSizeFunction& size_function, int64 memory_limit_bytes,
|
||||
RematerializationSizes* sizes,
|
||||
CompactShapeFunction compact_shape_function = nullptr)
|
||||
: size_function_(size_function),
|
||||
memory_limit_bytes_(memory_limit_bytes),
|
||||
sizes_(sizes) {}
|
||||
~HloRematerialization() {}
|
||||
sizes_(sizes),
|
||||
compact_shape_function_(compact_shape_function == nullptr
|
||||
? DefaultCompactShapeFunction
|
||||
: std::move(compact_shape_function)) {}
|
||||
~HloRematerialization() override = default;
|
||||
|
||||
absl::string_view name() const override { return "rematerialization"; }
|
||||
|
||||
@ -109,6 +123,10 @@ class HloRematerialization : public HloModulePass {
|
||||
// module before/after rematerialization
|
||||
RematerializationSizes* sizes_;
|
||||
|
||||
// Converts a shape into compact form, returns the same shape if a shape is
|
||||
// already considered compact.
|
||||
const CompactShapeFunction compact_shape_function_;
|
||||
|
||||
// Call graph of the hlo_module.
|
||||
std::unique_ptr<CallGraph> call_graph_;
|
||||
|
||||
|
@ -27,7 +27,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
|
||||
namespace xla {
|
||||
@ -534,6 +533,142 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) {
|
||||
INSTANTIATE_TEST_SUITE_P(IndirectUseTestInstantiation, IndirectUseTest,
|
||||
::testing::Values(true, false));
|
||||
|
||||
class CompressingRematerializationTest : public RematerializationTestBase {
|
||||
protected:
|
||||
// A special shape size function, which pads the most minor dimension to 64.
|
||||
static int64 ShapeSizePadMinorTo64(const Shape& shape) {
|
||||
if (shape.IsTuple()) {
|
||||
// Size of a tuple is 4 bytes.
|
||||
return 4;
|
||||
}
|
||||
Shape descending_shape =
|
||||
ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(shape);
|
||||
int64 size =
|
||||
ShapeUtil::ByteSizeOfPrimitiveType(descending_shape.element_type());
|
||||
for (int64 i = 0; i < descending_shape.rank(); ++i) {
|
||||
int64 dim = shape.dimensions(i);
|
||||
if (i == descending_shape.rank() - 1) {
|
||||
dim = RoundUpToNearest<int64>(dim, 64);
|
||||
}
|
||||
size *= dim;
|
||||
}
|
||||
return size;
|
||||
}
|
||||
|
||||
// Swap the two most-minor dimensions if the second-minor dimension is bigger
|
||||
// than the most-minor dimension.
|
||||
static StatusOr<Shape> ChooseCompactLayoutForShape(const Shape& shape) {
|
||||
Shape result = shape;
|
||||
Layout layout = result.layout();
|
||||
int64 most_minor_index = layout.minor_to_major()[0];
|
||||
int64 second_minor_index = layout.minor_to_major()[1];
|
||||
int64 most_minor = result.dimensions(most_minor_index);
|
||||
int64 second_minor = result.dimensions(second_minor_index);
|
||||
if (most_minor < second_minor) {
|
||||
result.set_dimensions(most_minor_index, second_minor);
|
||||
result.set_dimensions(second_minor_index, most_minor);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
StatusOr<bool> RunHloRematerialization(int64 memory_limit_bytes,
|
||||
HloModule* module) {
|
||||
TF_EXPECT_OK(verifier().Run(module).status());
|
||||
HloRematerialization remat(ShapeSizePadMinorTo64, memory_limit_bytes,
|
||||
/*sizes=*/nullptr, ChooseCompactLayoutForShape);
|
||||
return remat.Run(module);
|
||||
}
|
||||
};
|
||||
|
||||
// Test rematerialization of a single instruction.
|
||||
TEST_F(CompressingRematerializationTest, SingleRemat) {
|
||||
const string& hlo_string = R"(
|
||||
HloModule fusion, is_scheduled=true
|
||||
|
||||
%add_float {
|
||||
%x = f32[] parameter(0)
|
||||
%y = f32[] parameter(1)
|
||||
ROOT %add = f32[] add(f32[] %x, f32[] %y)
|
||||
}
|
||||
|
||||
ENTRY %entry {
|
||||
%param.0 = f32[] parameter(0)
|
||||
%constant = f32[] constant(0)
|
||||
%broadcast.0 = f32[64,2]{1,0} broadcast(f32[] %param.0), dimensions={}
|
||||
%negate = f32[64,2]{1,0} negate(f32[64,2]{1,0} broadcast.0)
|
||||
%reduce.0 = f32[] reduce(f32[64,2]{1,0} %negate, f32[] %constant), dimensions={1, 0}, to_apply=%add_float
|
||||
%reduce.1 = f32[] reduce(f32[64,2]{1,0} %broadcast.0, f32[] %constant), dimensions={1, 0}, to_apply=%add_float
|
||||
%add = f32[] add(f32[] %reduce.0, f32[] %reduce.1)
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
auto module,
|
||||
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()));
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool changed,
|
||||
RunHloRematerialization(
|
||||
/*memory_limit_bytes=*/30 * 1024, module.get()));
|
||||
EXPECT_TRUE(changed);
|
||||
HloInstruction* broadcast =
|
||||
module->entry_computation()->GetInstructionWithName("broadcast.0");
|
||||
HloInstruction* reduce =
|
||||
module->entry_computation()->GetInstructionWithName("reduce.1");
|
||||
EXPECT_THAT(reduce,
|
||||
op::Reduce(op::Copy(op::Copy(broadcast)), op::Constant()));
|
||||
}
|
||||
|
||||
TEST_F(CompressingRematerializationTest, AllUsersUseSameCopy) {
|
||||
const string& hlo_string = R"(
|
||||
HloModule fusion, is_scheduled=true
|
||||
|
||||
%add_float {
|
||||
%x = f32[] parameter(0)
|
||||
%y = f32[] parameter(1)
|
||||
ROOT %add = f32[] add(f32[] %x, f32[] %y)
|
||||
}
|
||||
|
||||
ENTRY %entry {
|
||||
%param.0 = f32[] parameter(0)
|
||||
%constant = f32[] constant(0)
|
||||
%broadcast.0 = f32[64,2]{1,0} broadcast(f32[] %param.0), dimensions={}
|
||||
%negate = f32[64,2]{1,0} negate(f32[64,2]{1,0} broadcast.0)
|
||||
%reduce.0 = f32[] reduce(f32[64,2]{1,0} %negate, f32[] %constant), dimensions={1, 0}, to_apply=%add_float
|
||||
%reduce.1 = f32[] reduce(f32[64,2]{1,0} %negate, f32[] %constant), dimensions={1, 0}, to_apply=%add_float
|
||||
%reduce.2 = f32[] reduce(f32[64,2]{1,0} %broadcast.0, f32[] %constant), dimensions={1, 0}, to_apply=%add_float
|
||||
%add = f32[] add(f32[] %reduce.0, f32[] %reduce.1)
|
||||
%reduce.3 = f32[] reduce(f32[64,2]{1,0} %broadcast.0, f32[] %constant), dimensions={1, 0}, to_apply=%add_float
|
||||
%add.2 = f32[] add(f32[] %reduce.2, f32[] %reduce.3)
|
||||
ROOT %tuple = (f32[], f32[]) tuple (f32[] add, f32[] add.2)
|
||||
}
|
||||
)";
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
auto module,
|
||||
HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()));
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool changed,
|
||||
RunHloRematerialization(
|
||||
/*memory_limit_bytes=*/30 * 1024, module.get()));
|
||||
EXPECT_TRUE(changed);
|
||||
|
||||
HloInstruction* broadcast =
|
||||
module->entry_computation()->GetInstructionWithName("broadcast.0");
|
||||
|
||||
// Both reduces reuse the same copy instruction.
|
||||
HloInstruction* reduce_2 =
|
||||
module->entry_computation()->GetInstructionWithName("reduce.2");
|
||||
|
||||
HloInstruction* reduce_3 =
|
||||
module->entry_computation()->GetInstructionWithName("reduce.3");
|
||||
|
||||
EXPECT_THAT(reduce_2,
|
||||
op::Reduce(op::Copy(op::Copy(broadcast)), op::Constant()));
|
||||
|
||||
EXPECT_THAT(reduce_3,
|
||||
op::Reduce(op::Copy(op::Copy(broadcast)), op::Constant()));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace xla
|
||||
|
Loading…
Reference in New Issue
Block a user