Show layouts in HLO graph dump.

Layouts are displayed as e.g. "f32[100,200]{0,1}".  But constants used
to be displayed as e.g. "f32[]{42}", where the braces read like a layout
annotation rather than a value.  To avoid that ambiguity, constants are
now displayed as e.g. "42 (f32[])".
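
For illustration, a minimal standalone sketch of the before/after rendering
of an effective scalar constant. The shape and value strings here are
hypothetical stand-ins for ShapeUtil::HumanString() and Literal::GetAsString();
only the format strings match the change below:

    #include <cstdio>

    int main() {
      const char* shape = "f32[]";  // stand-in for ShapeUtil::HumanString()
      const char* value = "42";     // stand-in for Literal::GetAsString()

      // Old rendering: "%s{%s}" puts the value in braces, which reads like a
      // layout annotation such as {0,1}.
      std::printf("%s{%s}\n", shape, value);  // prints: f32[]{42}

      // New rendering: "%s (%s)" puts the value first and parenthesizes the
      // shape, so it cannot be mistaken for a layout.
      std::printf("%s (%s)\n", value, shape);  // prints: 42 (f32[])
      return 0;
    }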

Also gets rid of the xla_hlo_graph_layout flag, which is no longer
necessary since we're now showing layouts unconditionally.
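
A minimal sketch of the layout suffix that is now appended unconditionally,
assuming only standard C++; the real code guards on LayoutUtil::HasLayout()
and ShapeUtil::IsTuple() and uses StrAppend/Join over
shape().layout().minor_to_major():

    #include <cstdint>
    #include <string>
    #include <vector>

    // Append "{m0,m1,...}" (the minor-to-major dimension order) to a shape
    // string. Shapes with fewer than two dimensions are left alone, since
    // their layout carries no information.
    std::string WithLayoutSuffix(std::string shape_str,
                                 const std::vector<int64_t>& minor_to_major) {
      if (minor_to_major.size() < 2) return shape_str;
      shape_str += '{';
      for (size_t i = 0; i < minor_to_major.size(); ++i) {
        if (i > 0) shape_str += ',';
        shape_str += std::to_string(minor_to_major[i]);
      }
      shape_str += '}';
      return shape_str;
    }

    // WithLayoutSuffix("f32[100,200]", {0, 1}) returns "f32[100,200]{0,1}".

Tuple shapes are skipped because a tuple's layout is not representable as a
single minor-to-major list; the old flag-guarded code fell back to
ShapeUtil::HumanStringWithLayout() for that case.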

PiperOrigin-RevId: 163753637
Author:    Justin Lebar
Committer: TensorFlower Gardener
Date:      2017-07-31 15:02:54 -07:00
Parent:    84c2757a66
Commit:    724884f1ca

4 changed files with 13 additions and 30 deletions


@@ -116,12 +116,6 @@ void AllocateFlags() {
           flag_values->xla_hlo_graph_addresses(),
           "With xla_generate_hlo_graph, show addresses of HLO ops in "
           "graph dump."),
-      tensorflow::Flag(
-          "xla_hlo_graph_layout",
-          bool_setter_for(&DebugOptions::set_xla_hlo_graph_layout),
-          flag_values->xla_hlo_graph_layout(),
-          "With xla_generate_hlo_graph, show layout of HLO ops in "
-          "graph dump."),
       tensorflow::Flag(
           "xla_hlo_graph_path", flag_values->mutable_xla_hlo_graph_path(),
           "With xla_generate_hlo_graph, dump the graphs into this path."),


@@ -312,12 +312,11 @@ optional<string> MatchTrivialComputation(const HloComputation* computation) {
 class HloDotDumper {
  public:
   HloDotDumper(const HloComputation* computation, tensorflow::StringPiece label,
-               bool show_addresses, bool show_layouts,
-               const HloExecutionProfile* profile, NodeFilter filter)
+               bool show_addresses, const HloExecutionProfile* profile,
+               NodeFilter filter)
       : computation_(computation),
         label_(label.ToString()),
         show_addresses_(show_addresses),
-        show_layouts_(show_layouts),
         profile_(profile),
         filter_(std::move(filter)) {}

@@ -364,7 +363,6 @@ class HloDotDumper {
   const HloComputation* computation_;  // never null
   const string label_;  // overall name for the graph
   const bool show_addresses_;
-  const bool show_layouts_;
   const HloExecutionProfile* profile_;  // may be null
   const NodeFilter filter_;

@@ -642,8 +640,8 @@ string HloDotDumper::GetInstructionNodeInlinedConstants(
     if (ShapeUtil::IsEffectiveScalar(constant->shape())) {
       auto elem_idx = IndexUtil::LinearIndexToMultidimensionalIndex(
           constant->shape(), /*linear_index=*/0);
-      return Printf("%s{%s}", ShapeUtil::HumanString(constant->shape()),
-                    constant->literal().GetAsString(elem_idx));
+      return Printf("%s (%s)", constant->literal().GetAsString(elem_idx),
+                    ShapeUtil::HumanString(constant->shape()));
     }
     if (tensorflow::StringPiece(constant->name()).starts_with("%constant")) {
       return constant->name();

@@ -659,7 +657,7 @@ string HloDotDumper::GetInstructionNodeInlinedConstants(
     if (operand->opcode() != HloOpcode::kConstant) {
       return "";
     }
-    return stringify_constant(operand);
+    return StrCat("<b>constant</b> ", stringify_constant(operand));
   }

   std::vector<string> lines;

@@ -827,6 +825,14 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) {
   // shape to kMaxShapeLen characters.
   constexpr int kMaxShapeLen = 64;
   string instr_shape = ShapeUtil::HumanString(instr->shape());
+  // Show layout of non-tuple shapes with more than one dimension.
+  if (LayoutUtil::HasLayout(instr->shape()) &&
+      instr->shape().dimensions_size() > 1 &&
+      !ShapeUtil::IsTuple(instr->shape())) {
+    StrAppend(&instr_shape, "{",
+              Join(instr->shape().layout().minor_to_major(), ","), "}");
+  }
   if (instr_shape.length() > kMaxShapeLen) {
     instr_shape =
         StrCat(tensorflow::StringPiece(instr_shape).substr(0, kMaxShapeLen - 3),

@@ -837,17 +843,6 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) {
   if (show_addresses_) {
     lines.push_back(Printf("[%p]", instr));
   }
-  if (show_layouts_ && LayoutUtil::HasLayout(instr->shape())) {
-    string layout_str;
-    if (ShapeUtil::IsTuple(instr->shape())) {
-      // For tuples, emit the full shape because the layout of a tuple is not
-      // represented in a single Layout field.
-      layout_str = ShapeUtil::HumanStringWithLayout(instr->shape());
-    } else {
-      layout_str = Join(instr->shape().layout().minor_to_major(), ",");
-    }
-    lines.push_back(Printf("layout={%s}", layout_str));
-  }
   if (profile_ != nullptr) {
     double hlo_cycles_executed = profile_->GetProfileResult(*instr);
     double total_cycles_executed =

@@ -1115,7 +1110,6 @@ string DumpGraph(const HloComputation& computation, const string& label,
     graph =
         HloDotDumper(&computation, label,
                      /*show_addresses=*/debug_options.xla_hlo_graph_addresses(),
-                     /*show_layouts=*/debug_options.xla_hlo_graph_layout(),
                      hlo_execution_profile, NodeFilter())
             .Dump();
     graph_url = GetGraphRenderer()->RenderGraph(

@@ -1134,7 +1128,6 @@ string DumpNeighborhoodAround(const HloInstruction& node, int radius) {
   string graph =
       HloDotDumper(node.parent(), label,
                    /*show_addresses=*/debug_options.xla_hlo_graph_addresses(),
-                   /*show_layouts=*/debug_options.xla_hlo_graph_layout(),
                    /*profile=*/nullptr, filter)
           .Dump();
   return GetGraphRenderer()->RenderGraph(


@@ -55,7 +55,6 @@ void RealMain(tensorflow::gtl::ArraySlice<char*> args) {
     Computation computation = client->LoadSnapshot(module).ConsumeValueOrDie();
     DebugOptions debug_options = legacy_flags::GetDebugOptionsFromFlags();
     debug_options.set_xla_generate_hlo_graph(".*");
-    debug_options.set_xla_hlo_graph_layout(true);
     ComputationStats stats =
         client->GetComputationStats(computation, debug_options)
             .ConsumeValueOrDie();


@@ -49,9 +49,6 @@ message DebugOptions {
   // Show addresses of HLO ops in graph dump.
   bool xla_hlo_graph_addresses = 2;

-  // Show layout of HLO ops in graph dump.
-  bool xla_hlo_graph_layout = 3;
-
   // Path to dump HLO graphs to.
   string xla_hlo_graph_path = 4;