Preserve unique ids when serializing/deserializing HLO protos.
Re-assigning unique IDs broke serialization of HloSchedule, and keeping IDs stable improves the fidelity of the proto serialization. This change requires that instructions in HLO module protos have valid, module-scope-unique ids, so it also changes the XLA builder to hand out module-scope-unique ids. Previously, instruction ids were only unique within the computation scope. PiperOrigin-RevId: 212692339
This commit is contained in:
parent
3fb474713b
commit
5d1de24583
@ -231,6 +231,7 @@ tf_cc_test(
|
|||||||
"//tensorflow/compiler/xla:xla_data_proto",
|
"//tensorflow/compiler/xla:xla_data_proto",
|
||||||
"//tensorflow/compiler/xla/service:hlo_profile_printer",
|
"//tensorflow/compiler/xla/service:hlo_profile_printer",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core:regexp_internal",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
"//third_party/eigen3",
|
"//third_party/eigen3",
|
||||||
|
@ -33,6 +33,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/shape_util.h"
|
#include "tensorflow/compiler/xla/shape_util.h"
|
||||||
#include "tensorflow/compiler/xla/test.h"
|
#include "tensorflow/compiler/xla/test.h"
|
||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
|
#include "tensorflow/core/platform/regexp.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
@ -543,7 +544,13 @@ TEST(TFCompileTest, HloProfiling) {
|
|||||||
string hlo_profile_as_string =
|
string hlo_profile_as_string =
|
||||||
xla::PrintHloProfile(fn.hlo_profile_printer_data(), fn.profile_counters(),
|
xla::PrintHloProfile(fn.hlo_profile_printer_data(), fn.profile_counters(),
|
||||||
/*clock_rate_ghz=*/1.0);
|
/*clock_rate_ghz=*/1.0);
|
||||||
VLOG(1) << "HLO profile string:\n" << hlo_profile_as_string;
|
VLOG(1) << "Original HLO profile string:\n" << hlo_profile_as_string;
|
||||||
|
|
||||||
|
// Strip away identifier details from the profile string to avoid this test
|
||||||
|
// being a change detector for xla internals. Identifiers such as '%dot.0.7'
|
||||||
|
// just become '%dot'.
|
||||||
|
RE2::GlobalReplace(&hlo_profile_as_string, "(%[a-zA-Z0-9]*)[.0-9]*", "\\1");
|
||||||
|
VLOG(1) << "Stripped HLO profile string:\n" << hlo_profile_as_string;
|
||||||
|
|
||||||
std::vector<string> hlo_profile_lines =
|
std::vector<string> hlo_profile_lines =
|
||||||
absl::StrSplit(hlo_profile_as_string, '\n');
|
absl::StrSplit(hlo_profile_as_string, '\n');
|
||||||
@ -551,16 +558,14 @@ TEST(TFCompileTest, HloProfiling) {
|
|||||||
auto header = HasSubstr("Execution profile for");
|
auto header = HasSubstr("Execution profile for");
|
||||||
auto total_cycles_profile_line = HasSubstr("[total]");
|
auto total_cycles_profile_line = HasSubstr("[total]");
|
||||||
auto dot_profile_line = HasSubstr(
|
auto dot_profile_line = HasSubstr(
|
||||||
"%dot.0.4 = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} "
|
"%dot = f32[2,2]{1,0} dot(f32[2,2]{1,0} %arg0, f32[2,2]{1,0} %arg1)");
|
||||||
"%arg1.0.1)");
|
|
||||||
auto add_profile_line = HasSubstr(
|
auto add_profile_line = HasSubstr(
|
||||||
"%add.0.6 = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0.0.0, f32[2,2]{1,0} "
|
"%add = f32[2,2]{1,0} add(f32[2,2]{1,0} %arg0, f32[2,2]{1,0} %arg1)");
|
||||||
"%arg1.0.1)");
|
|
||||||
auto tuple_profile_line = HasSubstr(
|
auto tuple_profile_line = HasSubstr(
|
||||||
"%tuple.0.8 = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} "
|
"%tuple = (f32[2,2]{1,0}, f32[2,2]{1,0}) tuple(f32[2,2]{1,0} %dot, "
|
||||||
"%dot.0.4, f32[2,2]{1,0} %add.0.6)");
|
"f32[2,2]{1,0} %add)");
|
||||||
auto arg0_profile_line = HasSubstr("%arg0.0.0 = f32[2,2]{1,0} parameter(0)");
|
auto arg0_profile_line = HasSubstr("%arg0 = f32[2,2]{1,0} parameter(0)");
|
||||||
auto arg1_profile_line = HasSubstr("%arg1.0.1 = f32[2,2]{1,0} parameter(1)");
|
auto arg1_profile_line = HasSubstr("%arg1 = f32[2,2]{1,0} parameter(1)");
|
||||||
|
|
||||||
EXPECT_THAT(hlo_profile_lines,
|
EXPECT_THAT(hlo_profile_lines,
|
||||||
IsSupersetOf({header, total_cycles_profile_line, dot_profile_line,
|
IsSupersetOf({header, total_cycles_profile_line, dot_profile_line,
|
||||||
|
@ -604,10 +604,17 @@ TEST_F(XlaCompilerTest, DeterministicCompilation) {
|
|||||||
auto instr1 = c1.instructions(j);
|
auto instr1 = c1.instructions(j);
|
||||||
auto instr2 = c2.instructions(j);
|
auto instr2 = c2.instructions(j);
|
||||||
instr1.clear_name();
|
instr1.clear_name();
|
||||||
|
instr1.clear_id();
|
||||||
|
instr1.clear_operand_ids();
|
||||||
instr2.clear_name();
|
instr2.clear_name();
|
||||||
// The names of instructions were uniquified by the XlaBuilder, the rest
|
instr2.clear_id();
|
||||||
// of the fields should be identical.
|
instr2.clear_operand_ids();
|
||||||
|
// The names of instructions were uniquified by the XlaBuilder and the
|
||||||
|
// unique ids may be different, the rest of the fields should be
|
||||||
|
// identical.
|
||||||
string str1, str2;
|
string str1, str2;
|
||||||
|
LOG(INFO) << "instr1 = " << instr1.DebugString();
|
||||||
|
LOG(INFO) << "instr2 = " << instr2.DebugString();
|
||||||
instr1.AppendPartialToString(&str1);
|
instr1.AppendPartialToString(&str1);
|
||||||
instr2.AppendPartialToString(&str2);
|
instr2.AppendPartialToString(&str2);
|
||||||
EXPECT_EQ(str1, str2);
|
EXPECT_EQ(str1, str2);
|
||||||
|
@ -134,11 +134,12 @@ XlaOp XlaBuilder::ReportErrorOrReturn(
|
|||||||
|
|
||||||
StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64 root_id) const {
|
StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64 root_id) const {
|
||||||
TF_RETURN_IF_ERROR(first_error_);
|
TF_RETURN_IF_ERROR(first_error_);
|
||||||
TF_RET_CHECK((root_id >= 0) && (root_id < instructions_.size()));
|
TF_ASSIGN_OR_RETURN(const HloInstructionProto* root_proto,
|
||||||
|
LookUpInstructionByHandle(root_id));
|
||||||
|
|
||||||
ProgramShape program_shape;
|
ProgramShape program_shape;
|
||||||
|
|
||||||
*program_shape.mutable_result() = instructions_[root_id].shape();
|
*program_shape.mutable_result() = root_proto->shape();
|
||||||
|
|
||||||
// Check that the parameter numbers are continuous from 0, and add parameter
|
// Check that the parameter numbers are continuous from 0, and add parameter
|
||||||
// shapes and names to the program shape.
|
// shapes and names to the program shape.
|
||||||
@ -181,9 +182,8 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle,
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
CHECK(op_handle < instructions_.size() && op_handle >= 0);
|
const HloInstructionProto& instr =
|
||||||
|
*(LookUpInstructionByHandle(op_handle).ValueOrDie());
|
||||||
const HloInstructionProto& instr = instructions_[op_handle];
|
|
||||||
const HloOpcode opcode = StringToHloOpcode(instr.opcode()).ValueOrDie();
|
const HloOpcode opcode = StringToHloOpcode(instr.opcode()).ValueOrDie();
|
||||||
switch (opcode) {
|
switch (opcode) {
|
||||||
default:
|
default:
|
||||||
@ -283,6 +283,7 @@ StatusOr<XlaComputation> XlaBuilder::Build(int64 root_id) {
|
|||||||
|
|
||||||
// Clear data held by this builder.
|
// Clear data held by this builder.
|
||||||
this->instructions_.clear();
|
this->instructions_.clear();
|
||||||
|
this->handle_to_index_.clear();
|
||||||
this->embedded_.clear();
|
this->embedded_.clear();
|
||||||
this->parameter_numbers_.clear();
|
this->parameter_numbers_.clear();
|
||||||
|
|
||||||
@ -2285,7 +2286,7 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
|
|||||||
*program_shape->mutable_result() = root->shape();
|
*program_shape->mutable_result() = root->shape();
|
||||||
|
|
||||||
// We use std::set to keep the instruction ids in ascending order (which is
|
// We use std::set to keep the instruction ids in ascending order (which is
|
||||||
// also a valid denpendency order). The related ops will be added to the
|
// also a valid dependency order). The related ops will be added to the
|
||||||
// subgraph in the same order.
|
// subgraph in the same order.
|
||||||
std::set<int64> related_ops;
|
std::set<int64> related_ops;
|
||||||
tensorflow::gtl::FlatSet<int64> related_calls; // Related computations.
|
tensorflow::gtl::FlatSet<int64> related_calls; // Related computations.
|
||||||
@ -2293,14 +2294,16 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
|
|||||||
worklist.push(root->id());
|
worklist.push(root->id());
|
||||||
related_ops.insert(root->id());
|
related_ops.insert(root->id());
|
||||||
while (!worklist.empty()) {
|
while (!worklist.empty()) {
|
||||||
int64 node = worklist.front();
|
int64 handle = worklist.front();
|
||||||
worklist.pop();
|
worklist.pop();
|
||||||
for (int64 id : instructions_[node].operand_ids()) {
|
TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_proto,
|
||||||
|
LookUpInstructionByHandle(handle));
|
||||||
|
for (int64 id : instr_proto->operand_ids()) {
|
||||||
if (related_ops.insert(id).second) {
|
if (related_ops.insert(id).second) {
|
||||||
worklist.push(id);
|
worklist.push(id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (int64 called_id : instructions_[node].called_computation_ids()) {
|
for (int64 called_id : instr_proto->called_computation_ids()) {
|
||||||
related_calls.insert(called_id);
|
related_calls.insert(called_id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -2308,7 +2311,9 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
|
|||||||
// Add related ops to the computation.
|
// Add related ops to the computation.
|
||||||
for (int64 id : related_ops) {
|
for (int64 id : related_ops) {
|
||||||
auto* instr = entry.add_instructions();
|
auto* instr = entry.add_instructions();
|
||||||
*instr = instructions_[id];
|
TF_ASSIGN_OR_RETURN(const HloInstructionProto* instr_src,
|
||||||
|
LookUpInstructionByHandle(id));
|
||||||
|
*instr = *instr_src;
|
||||||
// Ensures that the instruction names are unique among the graph.
|
// Ensures that the instruction names are unique among the graph.
|
||||||
const string& new_name =
|
const string& new_name =
|
||||||
StrCat(instr->name(), ".", entry.id(), ".", instr->id());
|
StrCat(instr->name(), ".", entry.id(), ".", instr->id());
|
||||||
@ -2415,7 +2420,7 @@ StatusOr<XlaOp> XlaBuilder::AddInstruction(HloInstructionProto&& instr,
|
|||||||
absl::Span<const XlaOp> operands) {
|
absl::Span<const XlaOp> operands) {
|
||||||
TF_RETURN_IF_ERROR(first_error_);
|
TF_RETURN_IF_ERROR(first_error_);
|
||||||
|
|
||||||
const int64 handle = instructions_.size();
|
const int64 handle = GetUniqueId();
|
||||||
instr.set_id(handle);
|
instr.set_id(handle);
|
||||||
instr.set_opcode(HloOpcodeString(opcode));
|
instr.set_opcode(HloOpcodeString(opcode));
|
||||||
if (instr.name().empty()) {
|
if (instr.name().empty()) {
|
||||||
@ -2437,7 +2442,8 @@ StatusOr<XlaOp> XlaBuilder::AddInstruction(HloInstructionProto&& instr,
|
|||||||
*instr.mutable_sharding() = *sharding_;
|
*instr.mutable_sharding() = *sharding_;
|
||||||
}
|
}
|
||||||
|
|
||||||
instructions_.push_back(instr);
|
handle_to_index_[handle] = instructions_.size();
|
||||||
|
instructions_.push_back(std::move(instr));
|
||||||
|
|
||||||
XlaOp op(handle, this);
|
XlaOp op(handle, this);
|
||||||
return op;
|
return op;
|
||||||
@ -2467,10 +2473,16 @@ StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstruction(
|
|||||||
op.handle(), op.builder_->name(), this->name());
|
op.handle(), op.builder_->name(), this->name());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (op.handle() >= instructions_.size() || op.handle() < 0) {
|
return LookUpInstructionByHandle(op.handle());
|
||||||
return InvalidArgument("no XlaOp value %d", op.handle());
|
}
|
||||||
|
|
||||||
|
StatusOr<const HloInstructionProto*> XlaBuilder::LookUpInstructionByHandle(
|
||||||
|
int64 handle) const {
|
||||||
|
auto it = handle_to_index_.find(handle);
|
||||||
|
if (it == handle_to_index_.end()) {
|
||||||
|
return InvalidArgument("No XlaOp with handle %d", handle);
|
||||||
}
|
}
|
||||||
return &instructions_[op.handle()];
|
return &instructions_[it->second];
|
||||||
}
|
}
|
||||||
|
|
||||||
// Enqueues a "retrieve parameter value" instruction for a parameter that was
|
// Enqueues a "retrieve parameter value" instruction for a parameter that was
|
||||||
|
@ -34,6 +34,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/statusor.h"
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
#include "tensorflow/compiler/xla/types.h"
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
|
#include "tensorflow/core/lib/gtl/flatmap.h"
|
||||||
#include "tensorflow/core/lib/gtl/flatset.h"
|
#include "tensorflow/core/lib/gtl/flatset.h"
|
||||||
#include "tensorflow/core/platform/macros.h"
|
#include "tensorflow/core/platform/macros.h"
|
||||||
#include "tensorflow/core/platform/stacktrace.h"
|
#include "tensorflow/core/platform/stacktrace.h"
|
||||||
@ -955,6 +956,8 @@ class XlaBuilder {
|
|||||||
HloInstructionProto* instr);
|
HloInstructionProto* instr);
|
||||||
|
|
||||||
StatusOr<const HloInstructionProto*> LookUpInstruction(const XlaOp& op) const;
|
StatusOr<const HloInstructionProto*> LookUpInstruction(const XlaOp& op) const;
|
||||||
|
StatusOr<const HloInstructionProto*> LookUpInstructionByHandle(
|
||||||
|
int64 handle) const;
|
||||||
|
|
||||||
// Internal helper method that does the building for an arbitrary unary op.
|
// Internal helper method that does the building for an arbitrary unary op.
|
||||||
XlaOp UnaryOp(HloOpcode unop, const XlaOp& operand);
|
XlaOp UnaryOp(HloOpcode unop, const XlaOp& operand);
|
||||||
@ -1024,6 +1027,10 @@ class XlaBuilder {
|
|||||||
// The instructions of this computation.
|
// The instructions of this computation.
|
||||||
std::vector<HloInstructionProto> instructions_;
|
std::vector<HloInstructionProto> instructions_;
|
||||||
|
|
||||||
|
// A map from XlaOp::Handle to the index in the instructions_ vector where the
|
||||||
|
// instruction is held.
|
||||||
|
tensorflow::gtl::FlatMap<int64, int64> handle_to_index_;
|
||||||
|
|
||||||
// The embedded computations used by this computation. Each computation was
|
// The embedded computations used by this computation. Each computation was
|
||||||
// the entry computation of some XlaComputation, the key is the unique id of
|
// the entry computation of some XlaComputation, the key is the unique id of
|
||||||
// that XlaComputation.
|
// that XlaComputation.
|
||||||
|
@ -1963,6 +1963,7 @@ tf_cc_test(
|
|||||||
deps = [
|
deps = [
|
||||||
":hlo",
|
":hlo",
|
||||||
":hlo_matchers",
|
":hlo_matchers",
|
||||||
|
":hlo_memory_scheduler",
|
||||||
":hlo_parser",
|
":hlo_parser",
|
||||||
"//tensorflow/compiler/xla:literal",
|
"//tensorflow/compiler/xla:literal",
|
||||||
"//tensorflow/compiler/xla:shape_util",
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
|
@ -562,9 +562,11 @@ HloComputation::CreateFromProto(
|
|||||||
return to_proto_id[a.get()] < to_proto_id[b.get()];
|
return to_proto_id[a.get()] < to_proto_id[b.get()];
|
||||||
});
|
});
|
||||||
|
|
||||||
return absl::WrapUnique(new HloComputation(proto.name(), parameter_count,
|
auto computation = absl::WrapUnique(
|
||||||
&instructions, root,
|
new HloComputation(proto.name(), parameter_count, &instructions, root,
|
||||||
/*fusion_instruction=*/nullptr));
|
/*fusion_instruction=*/nullptr));
|
||||||
|
computation->unique_id_ = proto.id();
|
||||||
|
return std::move(computation);
|
||||||
}
|
}
|
||||||
|
|
||||||
void HloComputation::FuseInstructionsInto(
|
void HloComputation::FuseInstructionsInto(
|
||||||
|
@ -505,6 +505,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
|
|||||||
instruction->SetAndSanitizeName(proto.name());
|
instruction->SetAndSanitizeName(proto.name());
|
||||||
instruction->metadata_ = proto.metadata();
|
instruction->metadata_ = proto.metadata();
|
||||||
instruction->backend_config_ = proto.backend_config();
|
instruction->backend_config_ = proto.backend_config();
|
||||||
|
instruction->unique_id_ = proto.id();
|
||||||
|
|
||||||
if (proto.has_sharding()) {
|
if (proto.has_sharding()) {
|
||||||
TF_ASSIGN_OR_RETURN(const auto& sharding,
|
TF_ASSIGN_OR_RETURN(const auto& sharding,
|
||||||
|
@ -60,7 +60,7 @@ Status HloModule::set_schedule(HloSchedule schedule) {
|
|||||||
|
|
||||||
HloComputation* HloModule::AddComputationInternal(
|
HloComputation* HloModule::AddComputationInternal(
|
||||||
std::unique_ptr<HloComputation> computation, bool is_entry,
|
std::unique_ptr<HloComputation> computation, bool is_entry,
|
||||||
bool uniquify_names) {
|
bool uniquify_identifiers) {
|
||||||
if (is_entry) {
|
if (is_entry) {
|
||||||
CHECK_EQ(nullptr, entry_computation_);
|
CHECK_EQ(nullptr, entry_computation_);
|
||||||
entry_computation_ = computation.get();
|
entry_computation_ = computation.get();
|
||||||
@ -73,20 +73,11 @@ HloComputation* HloModule::AddComputationInternal(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (uniquify_names) {
|
if (uniquify_identifiers) {
|
||||||
computation->UniquifyName(&computation_name_uniquer_);
|
computation->UniquifyName(&computation_name_uniquer_);
|
||||||
for (auto* instruction : computation->instructions()) {
|
for (auto* instruction : computation->instructions()) {
|
||||||
instruction->UniquifyName(&instruction_name_uniquer_);
|
instruction->UniquifyName(&instruction_name_uniquer_);
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
// Don't uniquify the names of the computation or instruction, but we must
|
|
||||||
// run the names through the uniquifiers to prevent future name collisions
|
|
||||||
// for computations and instructions created later.
|
|
||||||
computation_name_uniquer_.GetUniqueName(computation->name());
|
|
||||||
for (auto* instruction : computation->instructions()) {
|
|
||||||
instruction_name_uniquer_.GetUniqueName(instruction->name());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pick unique IDs for each instruction.
|
// Pick unique IDs for each instruction.
|
||||||
for (auto* instruction : computation->instructions()) {
|
for (auto* instruction : computation->instructions()) {
|
||||||
@ -96,6 +87,21 @@ HloComputation* HloModule::AddComputationInternal(
|
|||||||
CHECK_NE(computation->root_instruction()->unique_id(), -1)
|
CHECK_NE(computation->root_instruction()->unique_id(), -1)
|
||||||
<< "Root has no valid id: " << computation->ToString();
|
<< "Root has no valid id: " << computation->ToString();
|
||||||
computation->SetUniqueId(computation->root_instruction()->unique_id());
|
computation->SetUniqueId(computation->root_instruction()->unique_id());
|
||||||
|
} else {
|
||||||
|
// Don't uniquify the names of the computation or instruction, but we must
|
||||||
|
// run the names through the uniquifiers to prevent future name collisions
|
||||||
|
// for computations and instructions created later. Also, set the
|
||||||
|
// next_unique_id_ to one greater than the max unique id of any
|
||||||
|
// instruction (or the computation) to avoid ID collisions.
|
||||||
|
computation_name_uniquer_.GetUniqueName(computation->name());
|
||||||
|
for (auto* instruction : computation->instructions()) {
|
||||||
|
instruction_name_uniquer_.GetUniqueName(instruction->name());
|
||||||
|
next_unique_id_ = std::max(next_unique_id_, instruction->unique_id() + 1);
|
||||||
|
}
|
||||||
|
if (next_unique_id_ < computation->unique_id() + 1) {
|
||||||
|
next_unique_id_ = computation->unique_id() + 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
computation->set_parent(this);
|
computation->set_parent(this);
|
||||||
computations_.push_back(std::move(computation));
|
computations_.push_back(std::move(computation));
|
||||||
@ -105,7 +111,7 @@ HloComputation* HloModule::AddComputationInternal(
|
|||||||
HloComputation* HloModule::AddEntryComputation(
|
HloComputation* HloModule::AddEntryComputation(
|
||||||
std::unique_ptr<HloComputation> computation) {
|
std::unique_ptr<HloComputation> computation) {
|
||||||
return AddComputationInternal(std::move(computation), /*is_entry=*/true,
|
return AddComputationInternal(std::move(computation), /*is_entry=*/true,
|
||||||
/*uniquify_names=*/true);
|
/*uniquify_identifiers=*/true);
|
||||||
}
|
}
|
||||||
|
|
||||||
Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) {
|
Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) {
|
||||||
@ -122,7 +128,7 @@ Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) {
|
|||||||
HloComputation* HloModule::AddEmbeddedComputation(
|
HloComputation* HloModule::AddEmbeddedComputation(
|
||||||
std::unique_ptr<HloComputation> computation) {
|
std::unique_ptr<HloComputation> computation) {
|
||||||
return AddComputationInternal(std::move(computation), /*is_entry=*/false,
|
return AddComputationInternal(std::move(computation), /*is_entry=*/false,
|
||||||
/*uniquify_names=*/true);
|
/*uniquify_identifiers=*/true);
|
||||||
}
|
}
|
||||||
|
|
||||||
void HloModule::ReplaceComputations(
|
void HloModule::ReplaceComputations(
|
||||||
@ -249,6 +255,9 @@ HloModuleProto HloModule::ToProto() const {
|
|||||||
/* static */
|
/* static */
|
||||||
StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
|
StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
|
||||||
const HloModuleProto& proto, const HloModuleConfig& module_config) {
|
const HloModuleProto& proto, const HloModuleConfig& module_config) {
|
||||||
|
VLOG(2) << "CreateFromProto()";
|
||||||
|
XLA_VLOG_LINES(2, proto.DebugString());
|
||||||
|
|
||||||
// The ProgramShape in the passed in module config must match the shapes of
|
// The ProgramShape in the passed in module config must match the shapes of
|
||||||
// the entry parameters and root.
|
// the entry parameters and root.
|
||||||
TF_RET_CHECK(proto.has_program_shape())
|
TF_RET_CHECK(proto.has_program_shape())
|
||||||
@ -312,22 +321,32 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
|
|||||||
// Don't uniquify names because we want names to be stable across
|
// Don't uniquify names because we want names to be stable across
|
||||||
// serialization and deserialization.
|
// serialization and deserialization.
|
||||||
module->AddComputationInternal(std::move(computation), is_entry,
|
module->AddComputationInternal(std::move(computation), is_entry,
|
||||||
/*uniquify_names=*/false);
|
/*uniquify_identifiers=*/false);
|
||||||
}
|
}
|
||||||
TF_RET_CHECK(module->entry_computation_ != nullptr);
|
TF_RET_CHECK(module->entry_computation_ != nullptr);
|
||||||
|
|
||||||
// Because we didn't uniquify the names, double-check that the instruction and
|
// Because we didn't uniquify the names or the ids, double-check that the
|
||||||
// computation names are unique from the proto.
|
// instruction and computation names and ids are unique from the proto.
|
||||||
tensorflow::gtl::FlatSet<string> computation_names;
|
tensorflow::gtl::FlatSet<string> computation_names;
|
||||||
tensorflow::gtl::FlatSet<string> instruction_names;
|
tensorflow::gtl::FlatSet<string> instruction_names;
|
||||||
|
tensorflow::gtl::FlatSet<int> computation_ids;
|
||||||
|
tensorflow::gtl::FlatSet<int> instruction_ids;
|
||||||
for (HloComputation* computation : module->computations()) {
|
for (HloComputation* computation : module->computations()) {
|
||||||
TF_RET_CHECK(!ContainsKey(computation_names, computation->name()))
|
TF_RET_CHECK(!ContainsKey(computation_names, computation->name()))
|
||||||
<< "Computation name is not unique: " << computation->name();
|
<< "Computation name is not unique: " << computation->name();
|
||||||
computation_names.insert(computation->name());
|
computation_names.insert(computation->name());
|
||||||
|
|
||||||
|
TF_RET_CHECK(!ContainsKey(computation_ids, computation->unique_id()))
|
||||||
|
<< "Computation id is not unique: " << computation->unique_id();
|
||||||
|
computation_ids.insert(computation->unique_id());
|
||||||
for (HloInstruction* instruction : computation->instructions()) {
|
for (HloInstruction* instruction : computation->instructions()) {
|
||||||
TF_RET_CHECK(!ContainsKey(instruction_names, instruction->name()))
|
TF_RET_CHECK(!ContainsKey(instruction_names, instruction->name()))
|
||||||
<< "Instruction name is not unique: " << instruction->name();
|
<< "Instruction name is not unique: " << instruction->name();
|
||||||
instruction_names.insert(instruction->name());
|
instruction_names.insert(instruction->name());
|
||||||
|
|
||||||
|
TF_RET_CHECK(!ContainsKey(instruction_ids, instruction->unique_id()))
|
||||||
|
<< "Instruction id is not unique: " << instruction->unique_id();
|
||||||
|
instruction_ids.insert(instruction->unique_id());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -253,7 +253,7 @@ class HloModule {
|
|||||||
private:
|
private:
|
||||||
HloComputation* AddComputationInternal(
|
HloComputation* AddComputationInternal(
|
||||||
std::unique_ptr<HloComputation> computation, bool is_entry,
|
std::unique_ptr<HloComputation> computation, bool is_entry,
|
||||||
bool uniquify_names);
|
bool uniquify_identifiers);
|
||||||
|
|
||||||
const string name_;
|
const string name_;
|
||||||
HloModuleConfig config_;
|
HloModuleConfig config_;
|
||||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
|
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_parser.h"
|
#include "tensorflow/compiler/xla/service/hlo_parser.h"
|
||||||
#include "tensorflow/compiler/xla/shape_util.h"
|
#include "tensorflow/compiler/xla/shape_util.h"
|
||||||
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
|
||||||
@ -253,6 +254,99 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
|
|||||||
op::Broadcast(), op::Multiply(), op::Add()));
|
op::Broadcast(), op::Multiply(), op::Add()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(HloModuleTest, ProtoSerializationPreservesIds) {
|
||||||
|
// Verify that serializing then deserializing an HLO proto preserves the
|
||||||
|
// unique IDs of the instructions and computations.
|
||||||
|
const string text =
|
||||||
|
R"(HloModule ReduceR3ToR2_module
|
||||||
|
|
||||||
|
add_F32.v3 {
|
||||||
|
lhs = f32[] parameter(0)
|
||||||
|
rhs = f32[] parameter(1)
|
||||||
|
ROOT add = f32[] add(lhs, rhs)
|
||||||
|
}
|
||||||
|
|
||||||
|
ENTRY ReduceR3ToR2.v3 {
|
||||||
|
input = f32[8,16,256]{2,1,0} parameter(0)
|
||||||
|
constant = f32[] constant(0)
|
||||||
|
ROOT reduce = f32[8,16]{1,0} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3
|
||||||
|
}
|
||||||
|
)";
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||||
|
ParseHloString(text));
|
||||||
|
|
||||||
|
// Perform various transformations on the graph:
|
||||||
|
//
|
||||||
|
// * clone the reduction function
|
||||||
|
// * replace use of reduction function with the clone.
|
||||||
|
// * add a random instruction to the entry computation.
|
||||||
|
//
|
||||||
|
// This will create instruction and computation IDs which are interesting:
|
||||||
|
// not consecutive and not densely packed.
|
||||||
|
HloComputation* entry = module->entry_computation();
|
||||||
|
HloInstruction* root = entry->root_instruction();
|
||||||
|
HloComputation* reduction = root->to_apply();
|
||||||
|
HloComputation* reduction_clone =
|
||||||
|
module->AddEmbeddedComputation(reduction->Clone());
|
||||||
|
root->set_to_apply(reduction_clone);
|
||||||
|
TF_ASSERT_OK(module->RemoveEmbeddedComputation(reduction));
|
||||||
|
HloInstruction* negate = entry->AddInstruction(
|
||||||
|
HloInstruction::CreateUnary(root->shape(), HloOpcode::kNegate, root));
|
||||||
|
entry->set_root_instruction(negate);
|
||||||
|
|
||||||
|
// Schedule the transformed module, this verifies that the serialized schedule
|
||||||
|
// is robust against non-consecutive IDs as well (b/114712358).
|
||||||
|
auto size_fn = [](const BufferValue& buffer) {
|
||||||
|
return ShapeUtil::ByteSizeOf(buffer.shape());
|
||||||
|
};
|
||||||
|
HloMemoryScheduler scheduler(size_fn);
|
||||||
|
TF_ASSERT_OK(scheduler.Run(module.get()).status());
|
||||||
|
ASSERT_TRUE(module->has_schedule());
|
||||||
|
|
||||||
|
// Serialize and deserialize and verify that the instruction and computation
|
||||||
|
// unique ids are the same.
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
|
std::unique_ptr<HloModule> module_copy,
|
||||||
|
HloModule::CreateFromProto(module->ToProto(), module->config()));
|
||||||
|
|
||||||
|
// The module IDs should *not* be the same because module ids must be globally
|
||||||
|
// unique.
|
||||||
|
EXPECT_NE(module->unique_id(), module_copy->unique_id());
|
||||||
|
|
||||||
|
// Verify that the computations and instructions all have the same unique id.
|
||||||
|
auto computation_copy_it = module_copy->computations().begin();
|
||||||
|
for (const HloComputation* computation_orig : module->computations()) {
|
||||||
|
const HloComputation* computation_copy = *computation_copy_it++;
|
||||||
|
EXPECT_EQ(computation_orig->unique_id(), computation_copy->unique_id())
|
||||||
|
<< absl::StrFormat(
|
||||||
|
"ID of original computation %s != ID of deserialized "
|
||||||
|
"computation %s: %d != %d",
|
||||||
|
computation_orig->name(), computation_copy->name(),
|
||||||
|
computation_orig->unique_id(), computation_copy->unique_id());
|
||||||
|
|
||||||
|
auto instruction_copy_it = computation_copy->instructions().begin();
|
||||||
|
for (const HloInstruction* instruction_orig :
|
||||||
|
computation_orig->instructions()) {
|
||||||
|
const HloInstruction* instruction_copy = *instruction_copy_it++;
|
||||||
|
EXPECT_EQ(instruction_orig->unique_id(), instruction_copy->unique_id())
|
||||||
|
<< absl::StrFormat(
|
||||||
|
"ID of original instruction %s != ID of deserialized "
|
||||||
|
"instruction %s: %d != %d",
|
||||||
|
instruction_orig->name(), instruction_copy->name(),
|
||||||
|
instruction_orig->unique_id(), instruction_copy->unique_id());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify that the next unique ID which the module would have handed out is
|
||||||
|
// greater than the unique id of any instruction.
|
||||||
|
int next_id = module_copy->NewUniqueInstructionId();
|
||||||
|
for (const HloComputation* computation : module_copy->computations()) {
|
||||||
|
for (const HloInstruction* instruction : computation->instructions()) {
|
||||||
|
EXPECT_GT(next_id, instruction->unique_id());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
Loading…
x
Reference in New Issue
Block a user