[XLA] Merge broadcasts of effective scalar constants within fusions into their users in graphical HLO dump.

Just as we merge constants, merge broadcasts of effective scalar constants.

I don't merge non-scalar constants because then you'd have to specify a
broadcast dimension, and that gets complicated.

I only do this within fusions because outside of fusions a broadcast of a
constant is a "real" operation that you may want to see explicitly.

PiperOrigin-RevId: 256907877
This commit is contained in:
Justin Lebar 2019-07-07 22:20:23 -07:00 committed by TensorFlower Gardener
parent 780530fd75
commit fa4de2ca7a

View File

@ -210,6 +210,12 @@ string HtmlLikeStringSanitize(absl::string_view s) {
return absl::StrReplaceAll(s, {{"<", "&lt;"}, {">", "&gt;"}});
}
bool IsFusedBroadcastOfConstantEffectiveScalar(const HloInstruction* instr) {
namespace m = match;
return instr->parent()->IsFusionComputation() &&
Match(instr, m::Broadcast(m::ConstantEffectiveScalar()));
}
// Tries to generates a human-readable one-word description of the given
// computation.
//
@ -678,9 +684,11 @@ string HloDotDumper::DumpComputation(const HloComputation* comp) {
string HloDotDumper::DumpRootTag() {
const HloInstruction* from = GetNodeForEdge(computation_->root_instruction());
// We didn't display constants as separate nodes; so if the root is a
// constant, we don't add root tag or edge for it.
if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant) {
// We didn't display constants or broadcasts of effective scalars within
// fusions as separate nodes; so if the root is a constant/broadcast of
// scalar, we don't add root tag or edge for it.
if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant ||
IsFusedBroadcastOfConstantEffectiveScalar(from)) {
return "";
}
@ -754,9 +762,10 @@ bool HloDotDumper::ShouldMergeIntoUsers(const HloInstruction* instr) const {
}
string HloDotDumper::DumpInstruction(const HloInstruction* instr) {
// We don't display constants as separate nodes; they're merged into their
// users.
if (instr->opcode() == HloOpcode::kConstant) {
// We don't display constants or broadcasts of effective scalar constants
// within fusions as separate nodes; they're merged into their users.
if (instr->opcode() == HloOpcode::kConstant ||
IsFusedBroadcastOfConstantEffectiveScalar(instr)) {
return "";
}
// Skip this node if it's merged into its users.
@ -810,9 +819,11 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) {
string HloDotDumper::GetInstructionNodeInlinedOperands(
const HloInstruction* instr) {
auto stringify_constant = [](const HloConstantInstruction* constant) {
const auto& shape = constant->shape();
// The constant's shape is a parameter because, in the case of a broadcasted
// scalar constant, we want to show the broadcasted shape, not the constant's
// scalar shape.
auto stringify_constant = [](const HloConstantInstruction* constant,
const Shape& shape) {
// If the shape has a dimension of size zero, print it as e.g.
// "{} (f32[42, 0, 10])". The alternative, calling Literal::ToString(),
// enumerates all of its empty dimensions (e.g. "{ { {}, {} }, ..."), which
@ -821,19 +832,19 @@ string HloDotDumper::GetInstructionNodeInlinedOperands(
return StrFormat("{} (%s)", ShapeUtil::HumanString(constant->shape()));
}
// Print the literal value of constants with <= K elements.
// Print the literal value of constants with <= K elements. Note that we
// use `constant->shape()` rather than `shape`, because if `constant` is a
// scalar that's broadcasted into `shape`, we want to print the constant.
optional<int64> elem_count;
if (shape.IsArray()) {
elem_count = 1;
for (int64 dim : shape.dimensions()) {
*elem_count *= dim;
}
elem_count = ShapeUtil::ElementsIn(constant->shape());
}
// Allow HloDotDumper to print HloInstruction reconstructed from HloProto
// collected from profiling tools. Those constants may not have a valid
// literal.
if (elem_count.has_value() && *elem_count <= 8 && constant->HasLiteral()) {
return constant->literal().ToString();
return StrFormat("%s %s", shape.ToString(),
constant->literal().ToStringWithoutShape());
}
// Otherwise, print e.g. "%constant.42 (s32[100])".
@ -843,17 +854,20 @@ string HloDotDumper::GetInstructionNodeInlinedOperands(
} else {
constant_name = StrCat("constant ", constant->name());
}
return StrFormat("%s %s", constant_name,
ShapeUtil::HumanString(constant->shape()));
return StrFormat("%s %s", constant_name, ShapeUtil::HumanString(shape));
};
std::vector<string> lines;
for (int64 i = 0; i < instr->operand_count(); ++i) {
const HloInstruction* operand = instr->operand(i);
const auto* constant_operand = DynCast<HloConstantInstruction>(operand);
optional<string> operand_str;
if (constant_operand != nullptr) {
operand_str = stringify_constant(constant_operand);
if (const auto* constant_operand =
DynCast<HloConstantInstruction>(operand)) {
operand_str =
stringify_constant(constant_operand, constant_operand->shape());
} else if (IsFusedBroadcastOfConstantEffectiveScalar(operand)) {
operand_str = stringify_constant(
Cast<HloConstantInstruction>(operand->operand(0)), operand->shape());
} else if (ShouldMergeIntoUsers(operand)) {
// Special case: If the operand is a parameter to a fusion node and it
// always has a constant value, display it like a regular constant.
@ -863,7 +877,7 @@ string HloDotDumper::GetInstructionNodeInlinedOperands(
if (operand->opcode() == HloOpcode::kParameter) {
if (const HloConstantInstruction* constant =
TryGetFusionParameterConstant(operand)) {
operand_str = stringify_constant(constant);
operand_str = stringify_constant(constant, constant->shape());
} else {
operand_str = StrFormat("Parameter %d", operand->parameter_number());
}
@ -1179,6 +1193,7 @@ void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) {
from = GetNodeForEdge(from);
if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant ||
IsFusedBroadcastOfConstantEffectiveScalar(from) ||
ShouldMergeIntoUsers(from)) {
return;
}