Preserve unique ids when serializing/deserializing HLO protos.
Re-assigning unique IDs broke serialization of HloSchedule, and keeping IDs stable improves the fidelity of the proto serialization. This change requires that instructions in HLO module protos have valid, module-scope-unique ids so change the XLA builder to hand out module-scope-unique ids. Previously, instruction ids were only unique in the computation scope. PiperOrigin-RevId: 212692339
This commit is contained in:
parent
3fb474713b
commit
5d1de24583
@ -231,6 +231,7 @@ tf_cc_test(
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/service:hlo_profile_printer",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:regexp_internal",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//third_party/eigen3",
|
||||
|
@ -33,6 +33,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/platform/regexp.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -543,7 +544,13 @@ TEST(TFCompileTest, HloProfiling) {
|
||||
string hlo_profile_as_string =
|
||||
xla::PrintHloProfile(fn.hlo_profile_printer_data(), fn.profile_counters(),
|
||||
/*clock_rate_ghz=*/1.0);
|
||||
VLOG(1) << "HLO profile string:\n" << hlo_profile_as_string;
|
||||
VLOG(1) << "Original HLO profile string:\n" << hlo_profile_as_string;
|
||||
|
||||
// Strip away identifier details from the profile string to avoid this test
|
||||
// being a change detector for xla internals. Identifiers such as '%dot.0.7'
|
||||
// just become '%dot'.
|
||||
RE2::GlobalReplace(&hlo_profile_as_string, "(%[a-zA-Z0-9]*)[.0-9]*", "\\1");
|
||||
VLOG(1) << "Stripped HLO profile string:\n" << hlo_profile_as_string;
|
||||
|
||||
std::vector<string> hlo_profile_lines =
|
||||
absl::StrSplit(hlo_profile_as_string, '\n');
|
||||
@ -551,16 +558,14 @@ TEST(TFCompileTest, HloProfiling) {
|
||||
auto header = HasSubstr("Execution profile for");
|
||||
auto total_cycles_profile_line = HasSubstr("[total]");
|
||||
auto dot_profile_line = HasSubstr(
|
||||
"%dot.0.4 = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} "
|
||||
"%arg1.0.1)");
|
||||
"%dot = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0, f32[2,2]{1,0} %arg1)");
|
||||
auto add_profile_line = HasSubstr(
|
||||
"%add.0.6 = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} "
|
||||
"%arg1.0.1)");
|
||||
"%add = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0, f32[2,2]{1,0} %arg1)");
|
||||
auto tuple_profile_line = HasSubstr(
|
||||
"%tuple.0.8 = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} "
|
||||
"%dot.0.4, f32[2,2]{1,0} %add.0.6)");
|
||||
auto arg0_profile_line = HasSubstr("%arg0.0.0 = f32[2,2]{1,0} parameter(0)");
|
||||
auto arg1_profile_line = HasSubstr("%arg1.0.1 = f32[2,2]{1,0} parameter(1)");
|
||||
"%tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} %dot, "
|
||||
"f32[2,2]{1,0} %add)");
|
||||
auto arg0_profile_line = HasSubstr("%arg0 = f32[2,2]{1,0} parameter(0)");
|
||||
auto arg1_profile_line = HasSubstr("%arg1 = f32[2,2]{1,0} parameter(1)");
|
||||
|
||||
EXPECT_THAT(hlo_profile_lines,
|
||||
IsSupersetOf({header, total_cycles_profile_line, dot_profile_line,
|
||||
|
@ -604,10 +604,17 @@ TEST_F(XlaCompilerTest, DeterministicCompilation) {
|
||||
auto instr1 = c1.instructions(j);
|
||||
auto instr2 = c2.instructions(j);
|
||||
instr1.clear_name();
|
||||
instr1.clear_id();
|
||||
instr1.clear_operand_ids();
|
||||
instr2.clear_name();
|
||||
// The names of instructions were uniquified by the XlaBuilder, the rest
|
||||
// of the fields should be identical.
|
||||
instr2.clear_id();
|
||||
instr2.clear_operand_ids();
|
||||
// The names of instructions were uniquified by the XlaBuilder and the
|
||||
// unique ids may be different, the rest of the fields should be
|
||||
// identical.
|
||||
string str1, str2;
|
||||
LOG(INFO) << "instr1 = " << instr1.DebugString();
|
||||
LOG(INFO) << "instr2 = " << instr2.DebugString();
|
||||
instr1.AppendPartialToString(&str1);
|
||||
instr2.AppendPartialToString(&str2);
|
||||
EXPECT_EQ(str1, str2);
|
||||
|
@ -134,11 +134,12 @@ XlaOp XlaBuilder::ReportErrorOrReturn(
|
||||
|
||||
StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64 root_id) const {
|
||||
TF_RETURN_IF_ERROR(first_error_);
|
||||
TF_RET_CHECK((root_id >= 0) && (root_id < instructions_.size()));
|
||||
TF_ASSIGN_OR_RETURN(const HloInstructionProto* root_proto,
|
||||
LookUpInstructionByHandle(root_id));
|
||||
|
||||
ProgramShape program_shape;
|
||||
|
||||
*program_shape.mutable_result() = instructions_[root_id].shape();
|
||||
*program_shape.mutable_result() = root_proto->shape();
|
||||
|
||||
// Check that the parameter numbers are continuous from 0, and add parameter
|
||||
// shapes and names to the program shape.
|
||||
@ -181,9 +182,8 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle,
|
||||
return;
|
||||
}
|
||||
|
||||
CHECK(op_handle < instructions_.size() && op_handle >= 0);
|
||||
|
||||
const HloInstructionProto& instr = instructions_[op_handle];
|
||||
const HloInstructionProto& instr =
|
||||
*(LookUpInstructionByHandle(op_handle).ValueOrDie());
|
||||
const HloOpcode opcode = StringToHloOpcode(instr.opcode()).ValueOrDie();
|
||||
switch (opcode) {
|
||||
default:
|
||||
@ -283,6 +283,7 @@ StatusOr<XlaComputation> XlaBuilder::Build(int64 root_id) {
|
||||
|
||||
// Clear data held by this builder.
|
||||
this->instructions_.clear();
|
||||
this->handle_to_index_.clear();
|
||||
this->embedded_.clear();
|
||||
this->parameter_numbers_.clear();
|
||||
|
||||
@ -2285,7 +2286,7 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
|
||||
*program_shape->mutable_result() = root->shape();
|
||||
|
||||
// We use std::set to keep the instruction ids in ascending order (which is
|
||||
// also a valid denpendency order). The related ops will be added to the
|
||||
// also a valid dependency order). The related ops will be added to the
|
||||
// subgraph in the same order.
|
||||
std::set<int64> related_ops;
|
||||
tensorflow::gtl::FlatSet<int64> related_calls; // Related computations.
|
||||
@ -2293,14 +2294,16 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
|
||||
worklist.push(root->id());
|
||||
related_ops.insert(root->id());
|
||||
while (!worklist.empty()) {
|
||||
int64 node = worklist.front();
|
||||
int64 handle = worklist.front();
|
||||
worklist.pop();
|
||||
for (int64 id : instructions_[node].operand_ids()) {
|
||||
TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto,
|
||||
LookUpInstructionByHandle(handle));
|
||||
for (int64 id : instr_proto->operand_ids()) {
|
||||
if (related_ops.insert(id).second) {
|
||||
worklist.push(id);
|
||||
}
|
||||
}
|
||||
for (int64 called_id : instructions_[node].called_computation_ids()) {
|
||||
for (int64 called_id : instr_proto->called_computation_ids()) {
|
||||
related_calls.insert(called_id);
|
||||
}
|
||||
}
|
||||
@ -2308,7 +2311,9 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
|
||||
// Add related ops to the computation.
|
||||
for (int64 id : related_ops) {
|
||||
auto* instr = entry.add_instructions();
|
||||
*instr = instructions_[id];
|
||||
TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_src,
|
||||
LookUpInstructionByHandle(id));
|
||||
*instr = *instr_src;
|
||||
// Ensures that the instruction names are unique among the graph.
|
||||
const string& new_name =
|
||||
StrCat(instr->name(), ".", entry.id(), ".", instr->id());
|
||||
@ -2415,7 +2420,7 @@ StatusOr<XlaOp> XlaBuilder::AddInstruction(HloInstructionProto&& instr,
|
||||
absl::Span<const XlaOp> operands) {
|
||||
TF_RETURN_IF_ERROR(first_error_);
|
||||
|
||||
const int64 handle = instructions_.size();
|
||||
const int64 handle = GetUniqueId();
|
||||
instr.set_id(handle);
|
||||
instr.set_opcode(HloOpcodeString(opcode));
|
||||
if (instr.name().empty()) {
|
||||
@ -2437,7 +2442,8 @@ StatusOr<XlaOp> XlaBuilder::AddInstruction(HloInstructionProto&& instr,
|
||||
*instr.mutable_sharding() = *sharding_;
|
||||
}
|
||||
|
||||
instructions_.push_back(instr);
|
||||
handle_to_index_[handle] = instructions_.size();
|
||||
instructions_.push_back(std::move(instr));
|
||||
|
||||
XlaOp op(handle, this);
|
||||
return op;
|
||||
@ -2467,10 +2473,16 @@ StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstruction(
|
||||
op.handle(), op.builder_->name(), this->name());
|
||||
}
|
||||
|
||||
if (op.handle() >= instructions_.size() || op.handle() < 0) {
|
||||
return InvalidArgument("no XlaOp value %d", op.handle());
|
||||
return LookUpInstructionByHandle(op.handle());
|
||||
}
|
||||
|
||||
StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstructionByHandle(
|
||||
int64 handle) const {
|
||||
auto it = handle_to_index_.find(handle);
|
||||
if (it == handle_to_index_.end()) {
|
||||
return InvalidArgument("No XlaOp with handle %d", handle);
|
||||
}
|
||||
return &instructions_[op.handle()];
|
||||
return &instructions_[it->second];
|
||||
}
|
||||
|
||||
// Enqueues a "retrieve parameter value" instruction for a parameter that was
|
||||
|
@ -34,6 +34,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/lib/gtl/flatmap.h"
|
||||
#include "tensorflow/core/lib/gtl/flatset.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/stacktrace.h"
|
||||
@ -955,6 +956,8 @@ class XlaBuilder {
|
||||
HloInstructionProto* instr);
|
||||
|
||||
StatusOr<const HloInstructionProto*> LookUpInstruction(const XlaOp& op) const;
|
||||
StatusOr<const HloInstructionProto*> LookUpInstructionByHandle(
|
||||
int64 handle) const;
|
||||
|
||||
// Internal helper method that does the building for an arbitrary unary op.
|
||||
XlaOp UnaryOp(HloOpcode unop, const XlaOp& operand);
|
||||
@ -1024,6 +1027,10 @@ class XlaBuilder {
|
||||
// The instructions of this computation.
|
||||
std::vector<HloInstructionProto> instructions_;
|
||||
|
||||
// A map from XlaOp::Handle to the index in the instructions_ vector where the
|
||||
// instruction is held.
|
||||
tensorflow::gtl::FlatMap<int64, int64> handle_to_index_;
|
||||
|
||||
// The embedded computations used by this computation. Each computation was
|
||||
// the entry computation of some XlaComputation, the key is the unique id of
|
||||
// that XlaComputation.
|
||||
|
@ -1963,6 +1963,7 @@ tf_cc_test(
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_matchers",
|
||||
":hlo_memory_scheduler",
|
||||
":hlo_parser",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
|
@ -562,9 +562,11 @@ HloComputation::CreateFromProto(
|
||||
return to_proto_id[a.get()] < to_proto_id[b.get()];
|
||||
});
|
||||
|
||||
return absl::WrapUnique(new HloComputation(proto.name(), parameter_count,
|
||||
&instructions, root,
|
||||
/*fusion_instruction=*/nullptr));
|
||||
auto computation = absl::WrapUnique(
|
||||
new HloComputation(proto.name(), parameter_count, &instructions, root,
|
||||
/*fusion_instruction=*/nullptr));
|
||||
computation->unique_id_ = proto.id();
|
||||
return std::move(computation);
|
||||
}
|
||||
|
||||
void HloComputation::FuseInstructionsInto(
|
||||
|
@ -505,6 +505,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
||||
instruction->SetAndSanitizeName(proto.name());
|
||||
instruction->metadata_ = proto.metadata();
|
||||
instruction->backend_config_ = proto.backend_config();
|
||||
instruction->unique_id_ = proto.id();
|
||||
|
||||
if (proto.has_sharding()) {
|
||||
TF_ASSIGN_OR_RETURN(const auto& sharding,
|
||||
|
@ -60,7 +60,7 @@ Status HloModule::set_schedule(HloSchedule schedule) {
|
||||
|
||||
HloComputation* HloModule::AddComputationInternal(
|
||||
std::unique_ptr<HloComputation> computation, bool is_entry,
|
||||
bool uniquify_names) {
|
||||
bool uniquify_identifiers) {
|
||||
if (is_entry) {
|
||||
CHECK_EQ(nullptr, entry_computation_);
|
||||
entry_computation_ = computation.get();
|
||||
@ -73,30 +73,36 @@ HloComputation* HloModule::AddComputationInternal(
|
||||
}
|
||||
}
|
||||
|
||||
if (uniquify_names) {
|
||||
if (uniquify_identifiers) {
|
||||
computation->UniquifyName(&computation_name_uniquer_);
|
||||
for (auto* instruction : computation->instructions()) {
|
||||
instruction->UniquifyName(&instruction_name_uniquer_);
|
||||
}
|
||||
|
||||
// Pick unique IDs for each instruction.
|
||||
for (auto* instruction : computation->instructions()) {
|
||||
instruction->SetUniqueId(NewUniqueInstructionId());
|
||||
}
|
||||
// Set unique id to this computation.
|
||||
CHECK_NE(computation->root_instruction()->unique_id(), -1)
|
||||
<< "Root has no valid id: " << computation->ToString();
|
||||
computation->SetUniqueId(computation->root_instruction()->unique_id());
|
||||
} else {
|
||||
// Don't uniquify the names of the computation or instruction, but we must
|
||||
// run the names through the uniquifiers to prevent future name collisions
|
||||
// for computations and instructions created later.
|
||||
// for computations and instructions created later. Also, set the
|
||||
// next_unique_id_ to the one greater than the max unique id of any
|
||||
// instruction (or the computation) to avoid ID collisions.
|
||||
computation_name_uniquer_.GetUniqueName(computation->name());
|
||||
for (auto* instruction : computation->instructions()) {
|
||||
instruction_name_uniquer_.GetUniqueName(instruction->name());
|
||||
next_unique_id_ = std::max(next_unique_id_, instruction->unique_id() + 1);
|
||||
}
|
||||
if (next_unique_id_ < computation->unique_id() + 1) {
|
||||
next_unique_id_ = computation->unique_id() + 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Pick unique IDs for each instruction.
|
||||
for (auto* instruction : computation->instructions()) {
|
||||
instruction->SetUniqueId(NewUniqueInstructionId());
|
||||
}
|
||||
// Set unique id to this computation.
|
||||
CHECK_NE(computation->root_instruction()->unique_id(), -1)
|
||||
<< "Root has no valid id: " << computation->ToString();
|
||||
computation->SetUniqueId(computation->root_instruction()->unique_id());
|
||||
|
||||
computation->set_parent(this);
|
||||
computations_.push_back(std::move(computation));
|
||||
return computations_.back().get();
|
||||
@ -105,7 +111,7 @@ HloComputation* HloModule::AddComputationInternal(
|
||||
HloComputation* HloModule::AddEntryComputation(
|
||||
std::unique_ptr<HloComputation> computation) {
|
||||
return AddComputationInternal(std::move(computation), /*is_entry=*/true,
|
||||
/*uniquify_names=*/true);
|
||||
/*uniquify_identifiers=*/true);
|
||||
}
|
||||
|
||||
Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) {
|
||||
@ -122,7 +128,7 @@ Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) {
|
||||
HloComputation* HloModule::AddEmbeddedComputation(
|
||||
std::unique_ptr<HloComputation> computation) {
|
||||
return AddComputationInternal(std::move(computation), /*is_entry=*/false,
|
||||
/*uniquify_names=*/true);
|
||||
/*uniquify_identifiers=*/true);
|
||||
}
|
||||
|
||||
void HloModule::ReplaceComputations(
|
||||
@ -249,6 +255,9 @@ HloModuleProto HloModule::ToProto() const {
|
||||
/* static */
|
||||
StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
|
||||
const HloModuleProto& proto, const HloModuleConfig& module_config) {
|
||||
VLOG(2) << "CreateFromProto()";
|
||||
XLA_VLOG_LINES(2, proto.DebugString());
|
||||
|
||||
// The ProgramShape in the passed in module config must match the shapes of
|
||||
// the entry parameters and root.
|
||||
TF_RET_CHECK(proto.has_program_shape())
|
||||
@ -312,22 +321,32 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
|
||||
// Don't uniquify names because we want names to be stable across
|
||||
// serialization and deserialization.
|
||||
module->AddComputationInternal(std::move(computation), is_entry,
|
||||
/*uniquify_names=*/false);
|
||||
/*uniquify_identifiers=*/false);
|
||||
}
|
||||
TF_RET_CHECK(module->entry_computation_ != nullptr);
|
||||
|
||||
// Because we didn't uniquify the names, double-check that the instruction and
|
||||
// computation names are unique from the proto.
|
||||
// Because we didn't uniquify the names or the ids, double-check that the
|
||||
// instruction and computation names and ids are unique from the proto.
|
||||
tensorflow::gtl::FlatSet<string> computation_names;
|
||||
tensorflow::gtl::FlatSet<string> instruction_names;
|
||||
tensorflow::gtl::FlatSet<int> computation_ids;
|
||||
tensorflow::gtl::FlatSet<int> instruction_ids;
|
||||
for (HloComputation* computation : module->computations()) {
|
||||
TF_RET_CHECK(!ContainsKey(computation_names, computation->name()))
|
||||
<< "Computation name is not unique: " << computation->name();
|
||||
computation_names.insert(computation->name());
|
||||
|
||||
TF_RET_CHECK(!ContainsKey(computation_ids, computation->unique_id()))
|
||||
<< "Computation id is not unique: " << computation->unique_id();
|
||||
computation_ids.insert(computation->unique_id());
|
||||
for (HloInstruction* instruction : computation->instructions()) {
|
||||
TF_RET_CHECK(!ContainsKey(instruction_names, instruction->name()))
|
||||
<< "Instruction name is not unique: " << instruction->name();
|
||||
instruction_names.insert(instruction->name());
|
||||
|
||||
TF_RET_CHECK(!ContainsKey(instruction_ids, instruction->unique_id()))
|
||||
<< "Instruction id is not unique: " << instruction->unique_id();
|
||||
instruction_ids.insert(instruction->unique_id());
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -253,7 +253,7 @@ class HloModule {
|
||||
private:
|
||||
HloComputation* AddComputationInternal(
|
||||
std::unique_ptr<HloComputation> computation, bool is_entry,
|
||||
bool uniquify_names);
|
||||
bool uniquify_identifiers);
|
||||
|
||||
const string name_;
|
||||
HloModuleConfig config_;
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_parser.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||
@ -253,6 +254,99 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
|
||||
op::Broadcast(), op::Multiply(), op::Add()));
|
||||
}
|
||||
|
||||
TEST_F(HloModuleTest, ProtoSerializationPreservesIds) {
|
||||
// Verify that serializing then deserializing an HLO proto preserves the
|
||||
// unique IDs of the instruction and module.
|
||||
const string text =
|
||||
R"(HloModule ReduceR3ToR2_module
|
||||
|
||||
add_F32.v3 {
|
||||
lhs = f32[] parameter(0)
|
||||
rhs = f32[] parameter(1)
|
||||
ROOT add = f32[] add(lhs, rhs)
|
||||
}
|
||||
|
||||
ENTRY ReduceR3ToR2.v3 {
|
||||
input = f32[8,16,256]{2,1,0} parameter(0)
|
||||
constant = f32[] constant(0)
|
||||
ROOT reduce = f32[8,16]{1,0} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseHloString(text));
|
||||
|
||||
// Perform various transformations on the graph:
|
||||
//
|
||||
// * clone the reduction function
|
||||
// * replace use of reduction function with the clone.
|
||||
// * add a random instruction to the entry computation.
|
||||
//
|
||||
// This will create instruction and computation IDs which are interesting:
|
||||
// not consecutive and not densely packed.
|
||||
HloComputation* entry = module->entry_computation();
|
||||
HloInstruction* root = entry->root_instruction();
|
||||
HloComputation* reduction = root->to_apply();
|
||||
HloComputation* reduction_clone =
|
||||
module->AddEmbeddedComputation(reduction->Clone());
|
||||
root->set_to_apply(reduction_clone);
|
||||
TF_ASSERT_OK(module->RemoveEmbeddedComputation(reduction));
|
||||
HloInstruction* negate = entry->AddInstruction(
|
||||
HloInstruction::CreateUnary(root->shape(), HloOpcode::kNegate, root));
|
||||
entry->set_root_instruction(negate);
|
||||
|
||||
// Schedule the transformed module, this verifies that the serialized schedule
|
||||
// is robust against non-consecutive IDs as well (b/114712358).
|
||||
auto size_fn = [](const BufferValue& buffer) {
|
||||
return ShapeUtil::ByteSizeOf(buffer.shape());
|
||||
};
|
||||
HloMemoryScheduler scheduler(size_fn);
|
||||
TF_ASSERT_OK(scheduler.Run(module.get()).status());
|
||||
ASSERT_TRUE(module->has_schedule());
|
||||
|
||||
// Serialize and deserialize and verify that the instruction and computations
|
||||
// unique ids are the same.
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<HloModule> module_copy,
|
||||
HloModule::CreateFromProto(module->ToProto(), module->config()));
|
||||
|
||||
// The module IDs should *not* be the same because module ids must be globally
|
||||
// unique.
|
||||
EXPECT_NE(module->unique_id(), module_copy->unique_id());
|
||||
|
||||
// Verify that the computations and instructions all have the same unique id.
|
||||
auto computation_copy_it = module_copy->computations().begin();
|
||||
for (const HloComputation* computation_orig : module->computations()) {
|
||||
const HloComputation* computation_copy = *computation_copy_it++;
|
||||
EXPECT_EQ(computation_orig->unique_id(), computation_copy->unique_id())
|
||||
<< absl::StrFormat(
|
||||
"ID of original computation %s != ID of deserialized "
|
||||
"computation %s: %d != %d",
|
||||
computation_orig->name(), computation_copy->name(),
|
||||
computation_orig->unique_id(), computation_copy->unique_id());
|
||||
|
||||
auto instruction_copy_it = computation_copy->instructions().begin();
|
||||
for (const HloInstruction* instruction_orig :
|
||||
computation_orig->instructions()) {
|
||||
const HloInstruction* instruction_copy = *instruction_copy_it++;
|
||||
EXPECT_EQ(instruction_orig->unique_id(), instruction_copy->unique_id())
|
||||
<< absl::StrFormat(
|
||||
"ID of original instruction %s != ID of deserialized "
|
||||
"instruction %s: %d != %d",
|
||||
instruction_orig->name(), instruction_copy->name(),
|
||||
instruction_orig->unique_id(), instruction_copy->unique_id());
|
||||
}
|
||||
}
|
||||
|
||||
// Verify that the next unique ID which the module would have handed out is
|
||||
// greater than the unique id of any instruction.
|
||||
int next_id = module_copy->NewUniqueInstructionId();
|
||||
for (const HloComputation* computation : module_copy->computations()) {
|
||||
for (const HloInstruction* instruction : computation->instructions()) {
|
||||
EXPECT_GT(next_id, instruction->unique_id());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace xla
|
||||
|
Loading…
x
Reference in New Issue
Block a user