Preserve unique ids when serializing/deserializing HLO protos.

Re-assigning unique IDs broke serialization of HloSchedule, and keeping IDs stable improves the fidelity of the proto serialization. This change requires that instructions in HLO module protos have valid, module-scope-unique ids, so the XLA builder is changed to hand out module-scope-unique ids. Previously, instruction ids were unique only within the computation scope.

PiperOrigin-RevId: 212692339
This commit is contained in:
Mark Heffernan 2018-09-12 13:19:18 -07:00 committed by TensorFlower Gardener
parent 3fb474713b
commit 5d1de24583
11 changed files with 196 additions and 47 deletions

View File

@ -231,6 +231,7 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo_profile_printer",
"//tensorflow/core:lib",
"//tensorflow/core:regexp_internal",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//third_party/eigen3",

View File

@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@ -543,7 +544,13 @@ TEST(TFCompileTest, HloProfiling) {
string hlo_profile_as_string =
xla::PrintHloProfile(fn.hlo_profile_printer_data(), fn.profile_counters(),
/*clock_rate_ghz=*/1.0);
VLOG(1) << "HLO profile string:\n" << hlo_profile_as_string;
VLOG(1) << "Original HLO profile string:\n" << hlo_profile_as_string;
// Strip away identifier details from the profile string to avoid this test
// being a change detector for xla internals. Identifiers such as '%dot.0.7'
// just become '%dot'.
RE2::GlobalReplace(&hlo_profile_as_string, "(%[a-zA-Z0-9]*)[.0-9]*", "\\1");
VLOG(1) << "Stripped HLO profile string:\n" << hlo_profile_as_string;
std::vector<string> hlo_profile_lines =
absl::StrSplit(hlo_profile_as_string, '\n');
@ -551,16 +558,14 @@ TEST(TFCompileTest, HloProfiling) {
auto header = HasSubstr("Execution profile for");
auto total_cycles_profile_line = HasSubstr("[total]");
auto dot_profile_line = HasSubstr(
"%dot.0.4 = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} "
"%arg1.0.1)");
"%dot = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0, f32[2,2]{1,0} %arg1)");
auto add_profile_line = HasSubstr(
"%add.0.6 = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} "
"%arg1.0.1)");
"%add = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0, f32[2,2]{1,0} %arg1)");
auto tuple_profile_line = HasSubstr(
"%tuple.0.8 = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} "
"%dot.0.4, f32[2,2]{1,0} %add.0.6)");
auto arg0_profile_line = HasSubstr("%arg0.0.0 = f32[2,2]{1,0} parameter(0)");
auto arg1_profile_line = HasSubstr("%arg1.0.1 = f32[2,2]{1,0} parameter(1)");
"%tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} %dot, "
"f32[2,2]{1,0} %add)");
auto arg0_profile_line = HasSubstr("%arg0 = f32[2,2]{1,0} parameter(0)");
auto arg1_profile_line = HasSubstr("%arg1 = f32[2,2]{1,0} parameter(1)");
EXPECT_THAT(hlo_profile_lines,
IsSupersetOf({header, total_cycles_profile_line, dot_profile_line,

View File

@ -604,10 +604,17 @@ TEST_F(XlaCompilerTest, DeterministicCompilation) {
auto instr1 = c1.instructions(j);
auto instr2 = c2.instructions(j);
instr1.clear_name();
instr1.clear_id();
instr1.clear_operand_ids();
instr2.clear_name();
// The names of instructions were uniquified by the XlaBuilder, the rest
// of the fields should be identical.
instr2.clear_id();
instr2.clear_operand_ids();
// The names of instructions were uniquified by the XlaBuilder and the
// unique ids may be different, the rest of the fields should be
// identical.
string str1, str2;
LOG(INFO) << "instr1 = " << instr1.DebugString();
LOG(INFO) << "instr2 = " << instr2.DebugString();
instr1.AppendPartialToString(&str1);
instr2.AppendPartialToString(&str2);
EXPECT_EQ(str1, str2);

View File

@ -134,11 +134,12 @@ XlaOp XlaBuilder::ReportErrorOrReturn(
StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64 root_id) const {
TF_RETURN_IF_ERROR(first_error_);
TF_RET_CHECK((root_id >= 0) && (root_id < instructions_.size()));
TF_ASSIGN_OR_RETURN(const HloInstructionProto* root_proto,
LookUpInstructionByHandle(root_id));
ProgramShape program_shape;
*program_shape.mutable_result() = instructions_[root_id].shape();
*program_shape.mutable_result() = root_proto->shape();
// Check that the parameter numbers are continuous from 0, and add parameter
// shapes and names to the program shape.
@ -181,9 +182,8 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle,
return;
}
CHECK(op_handle < instructions_.size() && op_handle >= 0);
const HloInstructionProto& instr = instructions_[op_handle];
const HloInstructionProto& instr =
*(LookUpInstructionByHandle(op_handle).ValueOrDie());
const HloOpcode opcode = StringToHloOpcode(instr.opcode()).ValueOrDie();
switch (opcode) {
default:
@ -283,6 +283,7 @@ StatusOr<XlaComputation> XlaBuilder::Build(int64 root_id) {
// Clear data held by this builder.
this->instructions_.clear();
this->handle_to_index_.clear();
this->embedded_.clear();
this->parameter_numbers_.clear();
@ -2285,7 +2286,7 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
*program_shape->mutable_result() = root->shape();
// We use std::set to keep the instruction ids in ascending order (which is
// also a valid denpendency order). The related ops will be added to the
// also a valid dependency order). The related ops will be added to the
// subgraph in the same order.
std::set<int64> related_ops;
tensorflow::gtl::FlatSet<int64> related_calls; // Related computations.
@ -2293,14 +2294,16 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
worklist.push(root->id());
related_ops.insert(root->id());
while (!worklist.empty()) {
int64 node = worklist.front();
int64 handle = worklist.front();
worklist.pop();
for (int64 id : instructions_[node].operand_ids()) {
TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto,
LookUpInstructionByHandle(handle));
for (int64 id : instr_proto->operand_ids()) {
if (related_ops.insert(id).second) {
worklist.push(id);
}
}
for (int64 called_id : instructions_[node].called_computation_ids()) {
for (int64 called_id : instr_proto->called_computation_ids()) {
related_calls.insert(called_id);
}
}
@ -2308,7 +2311,9 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
// Add related ops to the computation.
for (int64 id : related_ops) {
auto* instr = entry.add_instructions();
*instr = instructions_[id];
TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_src,
LookUpInstructionByHandle(id));
*instr = *instr_src;
// Ensures that the instruction names are unique among the graph.
const string& new_name =
StrCat(instr->name(), ".", entry.id(), ".", instr->id());
@ -2415,7 +2420,7 @@ StatusOr<XlaOp> XlaBuilder::AddInstruction(HloInstructionProto&& instr,
absl::Span<const XlaOp> operands) {
TF_RETURN_IF_ERROR(first_error_);
const int64 handle = instructions_.size();
const int64 handle = GetUniqueId();
instr.set_id(handle);
instr.set_opcode(HloOpcodeString(opcode));
if (instr.name().empty()) {
@ -2437,7 +2442,8 @@ StatusOr<XlaOp> XlaBuilder::AddInstruction(HloInstructionProto&& instr,
*instr.mutable_sharding() = *sharding_;
}
instructions_.push_back(instr);
handle_to_index_[handle] = instructions_.size();
instructions_.push_back(std::move(instr));
XlaOp op(handle, this);
return op;
@ -2467,10 +2473,16 @@ StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstruction(
op.handle(), op.builder_->name(), this->name());
}
if (op.handle() >= instructions_.size() || op.handle() < 0) {
return InvalidArgument("no XlaOp value %d", op.handle());
return LookUpInstructionByHandle(op.handle());
}
StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstructionByHandle(
int64 handle) const {
auto it = handle_to_index_.find(handle);
if (it == handle_to_index_.end()) {
return InvalidArgument("No XlaOp with handle %d", handle);
}
return &instructions_[op.handle()];
return &instructions_[it->second];
}
// Enqueues a "retrieve parameter value" instruction for a parameter that was

View File

@ -34,6 +34,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stacktrace.h"
@ -955,6 +956,8 @@ class XlaBuilder {
HloInstructionProto* instr);
StatusOr<const HloInstructionProto*> LookUpInstruction(const XlaOp& op) const;
StatusOr<const HloInstructionProto*> LookUpInstructionByHandle(
int64 handle) const;
// Internal helper method that does the building for an arbitrary unary op.
XlaOp UnaryOp(HloOpcode unop, const XlaOp& operand);
@ -1024,6 +1027,10 @@ class XlaBuilder {
// The instructions of this computation.
std::vector<HloInstructionProto> instructions_;
// A map from XlaOp::Handle to the index in the instructions_ vector where the
// instruction is held.
tensorflow::gtl::FlatMap<int64, int64> handle_to_index_;
// The embedded computations used by this computation. Each computation was
// the entry computation of some XlaComputation, the key is the unique id of
// that XlaComputation.

View File

@ -1963,6 +1963,7 @@ tf_cc_test(
deps = [
":hlo",
":hlo_matchers",
":hlo_memory_scheduler",
":hlo_parser",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",

View File

@ -562,9 +562,11 @@ HloComputation::CreateFromProto(
return to_proto_id[a.get()] < to_proto_id[b.get()];
});
return absl::WrapUnique(new HloComputation(proto.name(), parameter_count,
&instructions, root,
/*fusion_instruction=*/nullptr));
auto computation = absl::WrapUnique(
new HloComputation(proto.name(), parameter_count, &instructions, root,
/*fusion_instruction=*/nullptr));
computation->unique_id_ = proto.id();
return std::move(computation);
}
void HloComputation::FuseInstructionsInto(

View File

@ -505,6 +505,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
instruction->SetAndSanitizeName(proto.name());
instruction->metadata_ = proto.metadata();
instruction->backend_config_ = proto.backend_config();
instruction->unique_id_ = proto.id();
if (proto.has_sharding()) {
TF_ASSIGN_OR_RETURN(const auto& sharding,

View File

@ -60,7 +60,7 @@ Status HloModule::set_schedule(HloSchedule schedule) {
HloComputation* HloModule::AddComputationInternal(
std::unique_ptr<HloComputation> computation, bool is_entry,
bool uniquify_names) {
bool uniquify_identifiers) {
if (is_entry) {
CHECK_EQ(nullptr, entry_computation_);
entry_computation_ = computation.get();
@ -73,30 +73,36 @@ HloComputation* HloModule::AddComputationInternal(
}
}
if (uniquify_names) {
if (uniquify_identifiers) {
computation->UniquifyName(&computation_name_uniquer_);
for (auto* instruction : computation->instructions()) {
instruction->UniquifyName(&instruction_name_uniquer_);
}
// Pick unique IDs for each instruction.
for (auto* instruction : computation->instructions()) {
instruction->SetUniqueId(NewUniqueInstructionId());
}
// Set unique id to this computation.
CHECK_NE(computation->root_instruction()->unique_id(), -1)
<< "Root has no valid id: " << computation->ToString();
computation->SetUniqueId(computation->root_instruction()->unique_id());
} else {
// Don't uniquify the names of the computation or instruction, but we must
// run the names through the uniquifiers to prevent future name collisions
// for computations and instructions created later.
// for computations and instructions created later. Also, set the
// next_unique_id_ to the one greater than the max unique id of any
// instruction (or the computation) to avoid ID collisions.
computation_name_uniquer_.GetUniqueName(computation->name());
for (auto* instruction : computation->instructions()) {
instruction_name_uniquer_.GetUniqueName(instruction->name());
next_unique_id_ = std::max(next_unique_id_, instruction->unique_id() + 1);
}
if (next_unique_id_ < computation->unique_id() + 1) {
next_unique_id_ = computation->unique_id() + 1;
}
}
// Pick unique IDs for each instruction.
for (auto* instruction : computation->instructions()) {
instruction->SetUniqueId(NewUniqueInstructionId());
}
// Set unique id to this computation.
CHECK_NE(computation->root_instruction()->unique_id(), -1)
<< "Root has no valid id: " << computation->ToString();
computation->SetUniqueId(computation->root_instruction()->unique_id());
computation->set_parent(this);
computations_.push_back(std::move(computation));
return computations_.back().get();
@ -105,7 +111,7 @@ HloComputation* HloModule::AddComputationInternal(
HloComputation* HloModule::AddEntryComputation(
std::unique_ptr<HloComputation> computation) {
return AddComputationInternal(std::move(computation), /*is_entry=*/true,
/*uniquify_names=*/true);
/*uniquify_identifiers=*/true);
}
Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) {
@ -122,7 +128,7 @@ Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) {
HloComputation* HloModule::AddEmbeddedComputation(
std::unique_ptr<HloComputation> computation) {
return AddComputationInternal(std::move(computation), /*is_entry=*/false,
/*uniquify_names=*/true);
/*uniquify_identifiers=*/true);
}
void HloModule::ReplaceComputations(
@ -249,6 +255,9 @@ HloModuleProto HloModule::ToProto() const {
/* static */
StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
const HloModuleProto& proto, const HloModuleConfig& module_config) {
VLOG(2) << "CreateFromProto()";
XLA_VLOG_LINES(2, proto.DebugString());
// The ProgramShape in the passed in module config must match the shapes of
// the entry parameters and root.
TF_RET_CHECK(proto.has_program_shape())
@ -312,22 +321,32 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
// Don't uniquify names because we want names to be stable across
// serialization and deserialization.
module->AddComputationInternal(std::move(computation), is_entry,
/*uniquify_names=*/false);
/*uniquify_identifiers=*/false);
}
TF_RET_CHECK(module->entry_computation_ != nullptr);
// Because we didn't uniquify the names, double-check that the instruction and
// computation names are unique from the proto.
// Because we didn't uniquify the names or the ids, double-check that the
// instruction and computation names and ids are unique from the proto.
tensorflow::gtl::FlatSet<string> computation_names;
tensorflow::gtl::FlatSet<string> instruction_names;
tensorflow::gtl::FlatSet<int> computation_ids;
tensorflow::gtl::FlatSet<int> instruction_ids;
for (HloComputation* computation : module->computations()) {
TF_RET_CHECK(!ContainsKey(computation_names, computation->name()))
<< "Computation name is not unique: " << computation->name();
computation_names.insert(computation->name());
TF_RET_CHECK(!ContainsKey(computation_ids, computation->unique_id()))
<< "Computation id is not unique: " << computation->unique_id();
computation_ids.insert(computation->unique_id());
for (HloInstruction* instruction : computation->instructions()) {
TF_RET_CHECK(!ContainsKey(instruction_names, instruction->name()))
<< "Instruction name is not unique: " << instruction->name();
instruction_names.insert(instruction->name());
TF_RET_CHECK(!ContainsKey(instruction_ids, instruction->unique_id()))
<< "Instruction id is not unique: " << instruction->unique_id();
instruction_ids.insert(instruction->unique_id());
}
}

View File

@ -253,7 +253,7 @@ class HloModule {
private:
HloComputation* AddComputationInternal(
std::unique_ptr<HloComputation> computation, bool is_entry,
bool uniquify_names);
bool uniquify_identifiers);
const string name_;
HloModuleConfig config_;

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
@ -253,6 +254,99 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
op::Broadcast(), op::Multiply(), op::Add()));
}
// Regression test: round-tripping an HloModule through its proto must keep
// the unique ids of every computation and instruction stable, even when the
// ids are sparse (non-consecutive), and must leave the module's id counter
// past every preserved id. See b/114712358.
TEST_F(HloModuleTest, ProtoSerializationPreservesIds) {
  // Verify that serializing then deserializing an HLO proto preserves the
  // unique IDs of the instruction and module.
  const string text =
      R"(HloModule ReduceR3ToR2_module
add_F32.v3 {
lhs = f32[] parameter(0)
rhs = f32[] parameter(1)
ROOT add = f32[] add(lhs, rhs)
}
ENTRY ReduceR3ToR2.v3 {
input = f32[8,16,256]{2,1,0} parameter(0)
constant = f32[] constant(0)
ROOT reduce = f32[8,16]{1,0} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(text));

  // Perform various transformations on the graph:
  //
  //  * clone the reduction function
  //  * replace use of reduction function with the clone.
  //  * add a random instruction to the entry computation.
  //
  // This will create instruction and computation IDs which are interesting:
  // not consecutive and not densely packed.
  HloComputation* entry = module->entry_computation();
  HloInstruction* root = entry->root_instruction();
  HloComputation* reduction = root->to_apply();
  // Cloning allocates fresh ids; removing the original leaves a gap in the
  // id space, which is exactly the sparseness this test wants to exercise.
  HloComputation* reduction_clone =
      module->AddEmbeddedComputation(reduction->Clone());
  root->set_to_apply(reduction_clone);
  TF_ASSERT_OK(module->RemoveEmbeddedComputation(reduction));
  HloInstruction* negate = entry->AddInstruction(
      HloInstruction::CreateUnary(root->shape(), HloOpcode::kNegate, root));
  entry->set_root_instruction(negate);

  // Schedule the transformed module, this verifies that the serialized schedule
  // is robust against non-consecutive IDs as well (b/114712358).
  auto size_fn = [](const BufferValue& buffer) {
    return ShapeUtil::ByteSizeOf(buffer.shape());
  };
  HloMemoryScheduler scheduler(size_fn);
  TF_ASSERT_OK(scheduler.Run(module.get()).status());
  ASSERT_TRUE(module->has_schedule());

  // Serialize and deserialize and verify that the instruction and computations
  // unique ids are the same.
  TF_ASSERT_OK_AND_ASSIGN(
      std::unique_ptr<HloModule> module_copy,
      HloModule::CreateFromProto(module->ToProto(), module->config()));

  // The module IDs should *not* be the same because module ids must be globally
  // unique.
  EXPECT_NE(module->unique_id(), module_copy->unique_id());

  // Verify that the computations and instructions all have the same unique id.
  // NOTE(review): this walk assumes module->computations() and
  // module_copy->computations() iterate in the same order — presumably
  // guaranteed by the proto round trip; confirm if the ordering contract
  // changes.
  auto computation_copy_it = module_copy->computations().begin();
  for (const HloComputation* computation_orig : module->computations()) {
    const HloComputation* computation_copy = *computation_copy_it++;
    EXPECT_EQ(computation_orig->unique_id(), computation_copy->unique_id())
        << absl::StrFormat(
               "ID of original computation %s != ID of deserialized "
               "computation %s: %d != %d",
               computation_orig->name(), computation_copy->name(),
               computation_orig->unique_id(), computation_copy->unique_id());

    auto instruction_copy_it = computation_copy->instructions().begin();
    for (const HloInstruction* instruction_orig :
         computation_orig->instructions()) {
      const HloInstruction* instruction_copy = *instruction_copy_it++;
      EXPECT_EQ(instruction_orig->unique_id(), instruction_copy->unique_id())
          << absl::StrFormat(
                 "ID of original instruction %s != ID of deserialized "
                 "instruction %s: %d != %d",
                 instruction_orig->name(), instruction_copy->name(),
                 instruction_orig->unique_id(), instruction_copy->unique_id());
    }
  }

  // Verify that the next unique ID which the module would have handed out is
  // greater than the unique id of any instruction.
  int next_id = module_copy->NewUniqueInstructionId();
  for (const HloComputation* computation : module_copy->computations()) {
    for (const HloInstruction* instruction : computation->instructions()) {
      EXPECT_GT(next_id, instruction->unique_id());
    }
  }
}
} // namespace
} // namespace xla