[XLA] Merge broadcasts of effective scalar constants within fusions into their users in graphical HLO dump.
Just as we merge constants, merge broadcasts of effective scalar constants. I don't merge non-scalar constants because then you'd have to specify a broadcast dimension, and that gets complicated. I only do this within fusions because outside of fusions a broadcast of a constant is a "real" operation that you may want to see explicitly. PiperOrigin-RevId: 256907877
This commit is contained in:
parent
780530fd75
commit
fa4de2ca7a
@ -210,6 +210,12 @@ string HtmlLikeStringSanitize(absl::string_view s) {
|
||||
return absl::StrReplaceAll(s, {{"<", "<"}, {">", ">"}});
|
||||
}
|
||||
|
||||
bool IsFusedBroadcastOfConstantEffectiveScalar(const HloInstruction* instr) {
|
||||
namespace m = match;
|
||||
return instr->parent()->IsFusionComputation() &&
|
||||
Match(instr, m::Broadcast(m::ConstantEffectiveScalar()));
|
||||
}
|
||||
|
||||
// Tries to generates a human-readable one-word description of the given
|
||||
// computation.
|
||||
//
|
||||
@ -678,9 +684,11 @@ string HloDotDumper::DumpComputation(const HloComputation* comp) {
|
||||
string HloDotDumper::DumpRootTag() {
|
||||
const HloInstruction* from = GetNodeForEdge(computation_->root_instruction());
|
||||
|
||||
// We didn't display constants as separate nodes; so if the root is a
|
||||
// constant, we don't add root tag or edge for it.
|
||||
if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant) {
|
||||
// We didn't display constants or broadcasts of effective scalars within
|
||||
// fusions as separate nodes; so if the root is a constant/broadcast of
|
||||
// scalar, we don't add root tag or edge for it.
|
||||
if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant ||
|
||||
IsFusedBroadcastOfConstantEffectiveScalar(from)) {
|
||||
return "";
|
||||
}
|
||||
|
||||
@ -754,9 +762,10 @@ bool HloDotDumper::ShouldMergeIntoUsers(const HloInstruction* instr) const {
|
||||
}
|
||||
|
||||
string HloDotDumper::DumpInstruction(const HloInstruction* instr) {
|
||||
// We don't display constants as separate nodes; they're merged into their
|
||||
// users.
|
||||
if (instr->opcode() == HloOpcode::kConstant) {
|
||||
// We don't display constants or broadcasts of effective scalar constants
|
||||
// within fusions as separate nodes; they're merged into their users.
|
||||
if (instr->opcode() == HloOpcode::kConstant ||
|
||||
IsFusedBroadcastOfConstantEffectiveScalar(instr)) {
|
||||
return "";
|
||||
}
|
||||
// Skip this node if it's merged into its users.
|
||||
@ -810,9 +819,11 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) {
|
||||
|
||||
string HloDotDumper::GetInstructionNodeInlinedOperands(
|
||||
const HloInstruction* instr) {
|
||||
auto stringify_constant = [](const HloConstantInstruction* constant) {
|
||||
const auto& shape = constant->shape();
|
||||
|
||||
// The constant's shape is a parameter because, in the case of a broadcasted
|
||||
// scalar constant, we want to show the broadcasted shape, not the constant's
|
||||
// scalar shape.
|
||||
auto stringify_constant = [](const HloConstantInstruction* constant,
|
||||
const Shape& shape) {
|
||||
// If the shape has a dimension of size zero, print it as e.g.
|
||||
// "{} (f32[42, 0, 10])". The alternative, calling Literal::ToString(),
|
||||
// enumerates all of its empty dimensions (e.g. "{ { {}, {} }, ..."), which
|
||||
@ -821,19 +832,19 @@ string HloDotDumper::GetInstructionNodeInlinedOperands(
|
||||
return StrFormat("{} (%s)", ShapeUtil::HumanString(constant->shape()));
|
||||
}
|
||||
|
||||
// Print the literal value of constants with <= K elements.
|
||||
// Print the literal value of constants with <= K elements. Note that we
|
||||
// use `constant->shape()` rather than `shape`, because if `constant` is a
|
||||
// scalar that's broadcasted into `shape`, we want to print the constant.
|
||||
optional<int64> elem_count;
|
||||
if (shape.IsArray()) {
|
||||
elem_count = 1;
|
||||
for (int64 dim : shape.dimensions()) {
|
||||
*elem_count *= dim;
|
||||
}
|
||||
elem_count = ShapeUtil::ElementsIn(constant->shape());
|
||||
}
|
||||
// Allow HloDotDumper to print HloInstruction reconstructed from HloProto
|
||||
// collected from profiling tools. Those constants may not have a valid
|
||||
// literal.
|
||||
if (elem_count.has_value() && *elem_count <= 8 && constant->HasLiteral()) {
|
||||
return constant->literal().ToString();
|
||||
return StrFormat("%s %s", shape.ToString(),
|
||||
constant->literal().ToStringWithoutShape());
|
||||
}
|
||||
|
||||
// Otherwise, print e.g. "%constant.42 (s32[100])".
|
||||
@ -843,17 +854,20 @@ string HloDotDumper::GetInstructionNodeInlinedOperands(
|
||||
} else {
|
||||
constant_name = StrCat("constant ", constant->name());
|
||||
}
|
||||
return StrFormat("%s %s", constant_name,
|
||||
ShapeUtil::HumanString(constant->shape()));
|
||||
return StrFormat("%s %s", constant_name, ShapeUtil::HumanString(shape));
|
||||
};
|
||||
|
||||
std::vector<string> lines;
|
||||
for (int64 i = 0; i < instr->operand_count(); ++i) {
|
||||
const HloInstruction* operand = instr->operand(i);
|
||||
const auto* constant_operand = DynCast<HloConstantInstruction>(operand);
|
||||
optional<string> operand_str;
|
||||
if (constant_operand != nullptr) {
|
||||
operand_str = stringify_constant(constant_operand);
|
||||
if (const auto* constant_operand =
|
||||
DynCast<HloConstantInstruction>(operand)) {
|
||||
operand_str =
|
||||
stringify_constant(constant_operand, constant_operand->shape());
|
||||
} else if (IsFusedBroadcastOfConstantEffectiveScalar(operand)) {
|
||||
operand_str = stringify_constant(
|
||||
Cast<HloConstantInstruction>(operand->operand(0)), operand->shape());
|
||||
} else if (ShouldMergeIntoUsers(operand)) {
|
||||
// Special case: If the operand is a parameter to a fusion node and it
|
||||
// always has a constant value, display it like a regular constant.
|
||||
@ -863,7 +877,7 @@ string HloDotDumper::GetInstructionNodeInlinedOperands(
|
||||
if (operand->opcode() == HloOpcode::kParameter) {
|
||||
if (const HloConstantInstruction* constant =
|
||||
TryGetFusionParameterConstant(operand)) {
|
||||
operand_str = stringify_constant(constant);
|
||||
operand_str = stringify_constant(constant, constant->shape());
|
||||
} else {
|
||||
operand_str = StrFormat("Parameter %d", operand->parameter_number());
|
||||
}
|
||||
@ -1179,6 +1193,7 @@ void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) {
|
||||
from = GetNodeForEdge(from);
|
||||
|
||||
if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant ||
|
||||
IsFusedBroadcastOfConstantEffectiveScalar(from) ||
|
||||
ShouldMergeIntoUsers(from)) {
|
||||
return;
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user