From fa4de2ca7aecb6b474c336b975b4563f4c7b7486 Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Sun, 7 Jul 2019 22:20:23 -0700 Subject: [PATCH] [XLA] Merge broadcasts of effective scalar constants within fusions into their users in graphical HLO dump. Just as we merge constants, merge broadcasts of effective scalar constants. I don't merge non-scalar constants because then you'd have to specify a broadcast dimension, and that gets complicated. I only do this within fusions because outside of fusions a broadcast of a constant is a "real" operation that you may want to see explicitly. PiperOrigin-RevId: 256907877 --- .../compiler/xla/service/hlo_graph_dumper.cc | 57 ++++++++++++------- 1 file changed, 36 insertions(+), 21 deletions(-) diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 76b842fd582..1e12fa5367c 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -210,6 +210,12 @@ string HtmlLikeStringSanitize(absl::string_view s) { return absl::StrReplaceAll(s, {{"<", "&lt;"}, {">", "&gt;"}}); } +bool IsFusedBroadcastOfConstantEffectiveScalar(const HloInstruction* instr) { + namespace m = match; + return instr->parent()->IsFusionComputation() && + Match(instr, m::Broadcast(m::ConstantEffectiveScalar())); +} + // Tries to generates a human-readable one-word description of the given // computation. // @@ -678,9 +684,11 @@ string HloDotDumper::DumpComputation(const HloComputation* comp) { string HloDotDumper::DumpRootTag() { const HloInstruction* from = GetNodeForEdge(computation_->root_instruction()); - // We didn't display constants as separate nodes; so if the root is a - // constant, we don't add root tag or edge for it. 
- if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant) { + // We didn't display constants or broadcasts of effective scalars within + // fusions as separate nodes; so if the root is a constant/broadcast of + // scalar, we don't add root tag or edge for it. + if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant || + IsFusedBroadcastOfConstantEffectiveScalar(from)) { return ""; } @@ -754,9 +762,10 @@ bool HloDotDumper::ShouldMergeIntoUsers(const HloInstruction* instr) const { } string HloDotDumper::DumpInstruction(const HloInstruction* instr) { - // We don't display constants as separate nodes; they're merged into their - // users. - if (instr->opcode() == HloOpcode::kConstant) { + // We don't display constants or broadcasts of effective scalar constants + // within fusions as separate nodes; they're merged into their users. + if (instr->opcode() == HloOpcode::kConstant || + IsFusedBroadcastOfConstantEffectiveScalar(instr)) { return ""; } // Skip this node if it's merged into its users. @@ -810,9 +819,11 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) { string HloDotDumper::GetInstructionNodeInlinedOperands( const HloInstruction* instr) { - auto stringify_constant = [](const HloConstantInstruction* constant) { - const auto& shape = constant->shape(); - + // The constant's shape is a parameter because, in the case of a broadcasted + // scalar constant, we want to show the broadcasted shape, not the constant's + // scalar shape. + auto stringify_constant = [](const HloConstantInstruction* constant, + const Shape& shape) { // If the shape has a dimension of size zero, print it as e.g. // "{} (f32[42, 0, 10])". The alternative, calling Literal::ToString(), // enumerates all of its empty dimensions (e.g. 
"{ { {}, {} }, ..."), which @@ -821,19 +832,19 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( return StrFormat("{} (%s)", ShapeUtil::HumanString(constant->shape())); } - // Print the literal value of constants with <= K elements. + // Print the literal value of constants with <= K elements. Note that we + // use `constant->shape()` rather than `shape`, because if `constant` is a + // scalar that's broadcasted into `shape`, we want to print the constant. optional<int64> elem_count; if (shape.IsArray()) { - elem_count = 1; - for (int64 dim : shape.dimensions()) { - *elem_count *= dim; - } + elem_count = ShapeUtil::ElementsIn(constant->shape()); } // Allow HloDotDumper to print HloInstruction reconstructed from HloProto // collected from profiling tools. Those constants may not have a valid // literal. if (elem_count.has_value() && *elem_count <= 8 && constant->HasLiteral()) { - return constant->literal().ToString(); + return StrFormat("%s %s", shape.ToString(), + constant->literal().ToStringWithoutShape()); } // Otherwise, print e.g. "%constant.42 (s32[100])". 
@@ -843,17 +854,20 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( } else { constant_name = StrCat("constant ", constant->name()); } - return StrFormat("%s %s", constant_name, - ShapeUtil::HumanString(constant->shape())); + return StrFormat("%s %s", constant_name, ShapeUtil::HumanString(shape)); }; std::vector<string> lines; for (int64 i = 0; i < instr->operand_count(); ++i) { const HloInstruction* operand = instr->operand(i); - const auto* constant_operand = DynCast<HloConstantInstruction>(operand); optional<string> operand_str; - if (constant_operand != nullptr) { - operand_str = stringify_constant(constant_operand); + if (const auto* constant_operand = + DynCast<HloConstantInstruction>(operand)) { + operand_str = + stringify_constant(constant_operand, constant_operand->shape()); + } else if (IsFusedBroadcastOfConstantEffectiveScalar(operand)) { + operand_str = stringify_constant( + Cast<HloConstantInstruction>(operand->operand(0)), operand->shape()); } else if (ShouldMergeIntoUsers(operand)) { // Special case: If the operand is a parameter to a fusion node and it // always has a constant value, display it like a regular constant. @@ -863,7 +877,7 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( if (operand->opcode() == HloOpcode::kParameter) { if (const HloConstantInstruction* constant = TryGetFusionParameterConstant(operand)) { - operand_str = stringify_constant(constant); + operand_str = stringify_constant(constant, constant->shape()); } else { operand_str = StrFormat("Parameter %d", operand->parameter_number()); } @@ -1179,6 +1193,7 @@ void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) { from = GetNodeForEdge(from); if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant || + IsFusedBroadcastOfConstantEffectiveScalar(from) || ShouldMergeIntoUsers(from)) { return; }