From a97f22586d5ec1085e5eb69f4c9cc696ff6b58d1 Mon Sep 17 00:00:00 2001 From: Yunxing Dai Date: Thu, 8 Oct 2020 12:55:51 -0700 Subject: [PATCH] Add add new option in HloComputation::AddInstruction to add instruction with different name. PiperOrigin-RevId: 336146593 Change-Id: I57ea0152169fc54188cf1002bda772b6153070c9 --- tensorflow/compiler/xla/service/hlo_computation.cc | 5 ++++- tensorflow/compiler/xla/service/hlo_computation.h | 3 ++- tensorflow/compiler/xla/service/hlo_rematerialization.cc | 9 ++++----- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 5d695b9c20f..6323d0903a4 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -93,10 +93,13 @@ HloComputation::HloComputation( } HloInstruction* HloComputation::AddInstruction( - std::unique_ptr instruction) { + std::unique_ptr instruction, const std::string& new_name) { CHECK(instruction->opcode() != HloOpcode::kParameter) << "Parameter instructions cannot be added to a computation after " << "it has been built"; + if (!new_name.empty()) { + instruction->SetAndSanitizeName(new_name); + } return AddInstructionInternal(std::move(instruction)); } diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 1dcf1d9d7d3..d618a527070 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -122,7 +122,8 @@ class HloComputation { // Add an instruction to the computation. The computation takes ownership of // the instruction. - HloInstruction* AddInstruction(std::unique_ptr instruction); + HloInstruction* AddInstruction(std::unique_ptr instruction, + const std::string& new_name = ""); // Remove the param_no'th parameter from the computation. // Note this is only applicatable to the computation for the fusion diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 790d4bfc2fb..59b1ac31e9b 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -1521,14 +1521,13 @@ StatusOr CompressInstruction(MemoryUsageTracker* memory_tracker, << ") to" << compact_shape.ToString(true); HloComputation* computation = best->parent(); - HloInstruction* compressed = computation->AddInstruction( - HloInstruction::CreateUnary(compact_shape, HloOpcode::kCopy, best)); - compressed->SetAndSanitizeName(best->name() + ".remat_compressed"); + HloInstruction::CreateUnary(compact_shape, HloOpcode::kCopy, best), + /*new_name=*/best->name() + ".remat_compressed"); HloInstruction* uncompressed = computation->AddInstruction( - HloInstruction::CreateUnary(best->shape(), HloOpcode::kCopy, compressed)); - uncompressed->SetAndSanitizeName(best->name() + ".remat_uncompressed"); + HloInstruction::CreateUnary(best->shape(), HloOpcode::kCopy, compressed), + /*new_name=*/best->name() + ".remat_uncompressed"); Item* compressed_item = instruction_list->CreateItem(compressed); compressed_item->placed = true;