Add add new option in HloComputation::AddInstruction to add instruction with different name.

PiperOrigin-RevId: 336146593
Change-Id: I57ea0152169fc54188cf1002bda772b6153070c9
This commit is contained in:
Yunxing Dai 2020-10-08 12:55:51 -07:00 committed by TensorFlower Gardener
parent d39d5bf305
commit a97f22586d
3 changed files with 10 additions and 7 deletions

View File

@ -93,10 +93,13 @@ HloComputation::HloComputation(
} }
HloInstruction* HloComputation::AddInstruction( HloInstruction* HloComputation::AddInstruction(
std::unique_ptr<HloInstruction> instruction) { std::unique_ptr<HloInstruction> instruction, const std::string& new_name) {
CHECK(instruction->opcode() != HloOpcode::kParameter) CHECK(instruction->opcode() != HloOpcode::kParameter)
<< "Parameter instructions cannot be added to a computation after " << "Parameter instructions cannot be added to a computation after "
<< "it has been built"; << "it has been built";
if (!new_name.empty()) {
instruction->SetAndSanitizeName(new_name);
}
return AddInstructionInternal(std::move(instruction)); return AddInstructionInternal(std::move(instruction));
} }

View File

@ -122,7 +122,8 @@ class HloComputation {
// Add an instruction to the computation. The computation takes ownership of // Add an instruction to the computation. The computation takes ownership of
// the instruction. // the instruction.
HloInstruction* AddInstruction(std::unique_ptr<HloInstruction> instruction); HloInstruction* AddInstruction(std::unique_ptr<HloInstruction> instruction,
const std::string& new_name = "");
// Remove the param_no'th parameter from the computation. // Remove the param_no'th parameter from the computation.
// Note this is only applicatable to the computation for the fusion // Note this is only applicatable to the computation for the fusion

View File

@ -1521,14 +1521,13 @@ StatusOr<int64> CompressInstruction(MemoryUsageTracker* memory_tracker,
<< ") to" << compact_shape.ToString(true); << ") to" << compact_shape.ToString(true);
HloComputation* computation = best->parent(); HloComputation* computation = best->parent();
HloInstruction* compressed = computation->AddInstruction( HloInstruction* compressed = computation->AddInstruction(
HloInstruction::CreateUnary(compact_shape, HloOpcode::kCopy, best)); HloInstruction::CreateUnary(compact_shape, HloOpcode::kCopy, best),
compressed->SetAndSanitizeName(best->name() + ".remat_compressed"); /*new_name=*/best->name() + ".remat_compressed");
HloInstruction* uncompressed = computation->AddInstruction( HloInstruction* uncompressed = computation->AddInstruction(
HloInstruction::CreateUnary(best->shape(), HloOpcode::kCopy, compressed)); HloInstruction::CreateUnary(best->shape(), HloOpcode::kCopy, compressed),
uncompressed->SetAndSanitizeName(best->name() + ".remat_uncompressed"); /*new_name=*/best->name() + ".remat_uncompressed");
Item* compressed_item = instruction_list->CreateItem(compressed); Item* compressed_item = instruction_list->CreateItem(compressed);
compressed_item->placed = true; compressed_item->placed = true;