Revamp handling of subcomputations in HLO graph dumper.

Before, we relied on a hacky heuristic -- "recurse into nested fusion
nodes" -- that broke down when, for example, a fusion node was nested
inside a while loop.

This change also adds a (very basic) test case for the HLO graph dumper.

PiperOrigin-RevId: 169731958
Author: Justin Lebar (2017-09-22 13:46:29 -07:00), committed by TensorFlower Gardener
Commit: 83066d45ee (parent: 5893f926e8)
3 changed files with 189 additions and 47 deletions
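
To make the new approach concrete, here is a minimal, self-contained C++ sketch. It uses a toy graph model, not XLA's real HloInstruction/HloComputation classes; the point is only that the dumper now recurses into every called computation of each instruction it draws (while bodies and conditions, fusions, etc.) instead of special-casing nested fusion nodes, so a fusion inside a while-loop body is still reached.

// Toy sketch (not XLA's real API): each instruction may call nested
// computations, and the dumper recurses into every called computation
// right where the parent instruction is drawn.
#include <iostream>
#include <string>
#include <vector>

struct Computation;

struct Instruction {
  std::string name;
  std::vector<const Computation*> called_computations;
};

struct Computation {
  std::string name;
  std::vector<Instruction> instructions;
};

void DumpComputation(const Computation& comp, int depth = 0) {
  const std::string indent(2 * depth, ' ');
  std::cout << indent << "subgraph " << comp.name << " {\n";
  for (const Instruction& instr : comp.instructions) {
    // Recurse into *every* called computation (fusion, while body, while
    // condition, ...), not just fusion computations.
    for (const Computation* subcomp : instr.called_computations) {
      DumpComputation(*subcomp, depth + 1);
    }
    std::cout << indent << "  node " << instr.name << ";\n";
  }
  std::cout << indent << "}\n";
}

int main() {
  // The case the old heuristic missed: a fusion nested inside a while body.
  Computation fused{"fused_computation", {{"add", {}}}};
  Computation while_body{"while_body", {{"fusion", {&fused}}}};
  Computation entry{"entry", {{"while", {&while_body}}}};
  DumpComputation(entry);  // Prints all three subgraphs, properly nested.
  return 0;
}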

File: tensorflow/compiler/xla/service/BUILD

@@ -1968,6 +1968,20 @@ cc_library(
alwayslink = 1,
)
tf_cc_test(
name = "hlo_graph_dumper_test",
srcs = ["hlo_graph_dumper_test.cc"],
deps = [
":hlo",
":hlo_graph_dumper",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:xla_proto",
"//tensorflow/compiler/xla/tests:test_utils",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
"//tensorflow/core:lib",
],
)
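With this rule in place, the test should be runnable with the usual Bazel invocation, e.g. bazel test //tensorflow/compiler/xla/service:hlo_graph_dumper_test (the target name comes from the rule above; the package path is an assumption based on the include paths used in the test).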
cc_library(
name = "transpose_folding",
srcs = ["transpose_folding.cc"],

File: tensorflow/compiler/xla/service/hlo_graph_dumper.cc

@@ -340,11 +340,8 @@ class HloDotDumper {
string Header();
string Footer();
// Maps HloComputations we should dump to their parent instruction in the
// outer computation.
std::unordered_map<const HloComputation*, const HloInstruction*>
SubcomputationsToDump();
bool ShouldShowSubcomputation(const HloComputation* subcomp);
bool ShouldShowFusionSubcomputation(const HloInstruction* instr);
string DumpSubcomputation(const HloComputation* subcomp,
const HloInstruction* parent_instr);
string DumpComputation(const HloComputation* comp);
@@ -401,11 +398,6 @@ class HloDotDumper {
string HloDotDumper::Dump() {
string body;
for (const auto& kv : SubcomputationsToDump()) {
const HloComputation* subcomp = kv.first;
const HloInstruction* parent = kv.second;
StrAppend(&body, DumpSubcomputation(subcomp, parent));
}
StrAppend(&body, DumpComputation(computation_));
StrAppend(&body, DumpRootTag());
@@ -525,33 +517,36 @@ stylesheet="
string HloDotDumper::Footer() { return StrCat(Join(edges_, "\n"), "\n}"); }
std::unordered_map<const HloComputation*, const HloInstruction*>
HloDotDumper::SubcomputationsToDump() {
// Dump the subcomputations of each instruction that's shown and doesn't have
// its operands omitted. If an instruction has just one subcomputation and
// it's trivial, omit it: We'll display that subcomputation inlined into the
// instruction's node when we draw it.
std::unordered_map<const HloComputation*, const HloInstruction*> to_dump;
for (const auto& instr : computation_->instructions()) {
if (!filter_.Show(instr.get()) ||
filter_.SomeOrAllOperandsOmitted(instr.get())) {
continue;
}
if (instr->opcode() == HloOpcode::kFusion) {
to_dump[instr->fused_instructions_computation()] = instr.get();
}
bool HloDotDumper::ShouldShowFusionSubcomputation(const HloInstruction* instr) {
CHECK_EQ(instr->opcode(), HloOpcode::kFusion);
return ShouldShowSubcomputation(instr->fused_instructions_computation());
}
for (const HloComputation* comp : instr->called_computations()) {
if (!MatchTrivialComputation(comp)) {
to_dump[comp] = instr.get();
}
bool HloDotDumper::ShouldShowSubcomputation(const HloComputation* subcomp) {
if (subcomp->IsFusionComputation()) {
const HloInstruction* fusion = subcomp->FusionInstruction();
if (!filter_.Show(fusion) || filter_.SomeOrAllOperandsOmitted(fusion)) {
return false;
}
}
return to_dump;
// Don't show trivial subcomputations on non-fusion nodes -- these are inlined
// into the graph.
if (!subcomp->IsFusionComputation() && MatchTrivialComputation(subcomp)) {
return false;
}
// Show the subcomputation if we're showing any of its members.
return std::any_of(subcomp->instructions().begin(),
subcomp->instructions().end(),
[&](const std::unique_ptr<HloInstruction>& instr) {
return filter_.Show(instr.get());
});
}
string HloDotDumper::DumpSubcomputation(const HloComputation* subcomp,
const HloInstruction* parent_instr) {
VLOG(2) << "Dumping subcomputation " << subcomp->name();
const char* computation_fmt = R"(subgraph %s {
%s
label = <%s>;
@@ -593,20 +588,10 @@ tooltip = " ";
string comp_body = DumpComputation(subcomp);
if (parent_instr->opcode() == HloOpcode::kFusion) {
// Dump any nested fusion nodes.
for (const auto& subcomp_instr : subcomp->instructions()) {
if (subcomp_instr->opcode() == HloOpcode::kFusion) {
StrAppend(
&comp_body,
DumpSubcomputation(subcomp_instr->fused_instructions_computation(),
subcomp_instr.get()));
}
}
} else {
// Add an edge from the subcomputation to its parent node. If subcomp
// belongs to a fusion node, it's drawn in place of the fusion instruction,
// so there's no need to link those.
// Add an edge from the subcomputation to its parent node. If subcomp
// belongs to a fusion node, it's drawn in place of the fusion instruction,
// so there's no need to link those.
if (parent_instr->opcode() != HloOpcode::kFusion) {
VLOG(2) << "Edge: from " << subcomp->root_instruction()->name() << " to "
<< parent_instr->name() << " as " << next_edge_id_;
edge_ids_.insert(
@@ -631,6 +616,14 @@ string HloDotDumper::DumpComputation(const HloComputation* comp) {
if (!filter_.Show(instr.get())) {
continue;
}
// Dump subcomputations within instr.
for (const HloComputation* subcomp : instr->called_computations()) {
if (ShouldShowSubcomputation(subcomp)) {
StrAppend(&g, DumpSubcomputation(subcomp, instr.get()));
}
}
StrAppend(&g, DumpInstruction(instr.get()));
}
return g;
@@ -638,6 +631,14 @@ string HloDotDumper::DumpComputation(const HloComputation* comp) {
string HloDotDumper::DumpRootTag() {
HloInstruction* from = computation_->root_instruction();
// Fusion nodes are expanded inline, so if root is an expanded fusion node,
// walk up the graph until we find a node that isn't.
while (from->opcode() == HloOpcode::kFusion &&
ShouldShowFusionSubcomputation(from)) {
from = from->fused_expression_root();
}
auto from_id = InstructionId(from);
if (!filter_.Show(from)) {
@@ -678,7 +679,7 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) {
// Omit the fusion node if its subcomputation is drawn, since the
// subcomputation will be drawn inline.
if (instr->opcode() == HloOpcode::kFusion &&
filter_.ShowFusionSubcomputation(instr)) {
ShouldShowFusionSubcomputation(instr)) {
return "";
}
@@ -937,7 +938,7 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) {
// Show the shape and layout of the instruction, unless it's an inlined fusion
// node -- there the shape and layout are present in the output node.
if (instr->opcode() != HloOpcode::kFusion ||
!filter_.ShowFusionSubcomputation(instr)) {
!ShouldShowFusionSubcomputation(instr)) {
string instr_shape = ShapeUtil::HumanString(instr->shape());
// Show layout of non-tuple shapes with more than one dimension.
@@ -982,7 +983,7 @@ void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) {
// fusion node and the node's subcomputation is shown, we draw our edge
// starting at the fusion node's root instead of at the fusion node itself.
if (from->opcode() == HloOpcode::kFusion &&
filter_.ShowFusionSubcomputation(from)) {
ShouldShowFusionSubcomputation(from)) {
from = from->fused_expression_root();
}
if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant) {
@@ -1147,6 +1148,11 @@ NodeFilter MakeNodeFilter(const HloInstruction* root, int64 radius) {
}
}
// Traverse into instr's nested computations.
for (const HloComputation* computation : instr->called_computations()) {
worklist.push_back({computation->root_instruction(), depth + 1});
}
// Traverse into instr's users, unless:
//
// - there are a ton of them, in which case they're probably not

File: tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc (new file)

@@ -0,0 +1,122 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace xla {
namespace {
using ::tensorflow::strings::StrCat;
using ::testing::HasSubstr;
string TestName() {
return ::testing::UnitTest::GetInstance()->current_test_info()->name();
}
class DotRenderer : public hlo_graph_dumper::GraphRendererInterface {
public:
string RenderGraph(const string& graph, GraphKind graph_kind,
const DebugOptions& debug_options) override {
return graph;
}
private:
string last_graph_;
};
XLA_REGISTER_GRAPH_RENDERER(DotRenderer, std::numeric_limits<int>::max());
TEST(HloGraphDumperTest, NestedFusion) {
HloComputation::Builder b("b");
// Build param0 + param1 + param2 + param3 + param4.
auto shape = ShapeUtil::MakeShape(F32, {10, 100});
std::vector<HloInstruction*> params;
for (int i = 0; i <= 4; ++i) {
params.push_back(b.AddInstruction(
HloInstruction::CreateParameter(i, shape, StrCat("param", i))));
}
std::vector<HloInstruction*> sums;
sums.push_back(b.AddInstruction(HloInstruction::CreateBinary(
shape, HloOpcode::kAdd, params[0], params[1])));
for (int i = 0; i <= 2; ++i) {
sums.push_back(b.AddInstruction(HloInstruction::CreateBinary(
shape, HloOpcode::kAdd, sums[i], params[i + 2])));
}
HloModule m(TestName());
m.AddEntryComputation(b.Build());
HloComputation* root_computation = m.entry_computation();
// Fuse into fusion(param0 + param1 + param2 + param3 + param4).
auto* outer_fusion = root_computation->CreateFusionInstruction(
{sums[3], sums[2], sums[1], sums[0]}, HloInstruction::FusionKind::kLoop);
// Fusing invalidates the pointers in sums -- the instructions are cloned when
// they're moved to the new computation. Get the updated pointers to sums.
std::vector<HloInstruction*> fused_sums;
for (auto* instr : outer_fusion->fused_instructions_computation()
->MakeInstructionPostOrder()) {
if (instr->opcode() == HloOpcode::kAdd) {
fused_sums.push_back(instr);
}
}
// Fuse into fusion(fusion(param0 + param1 + param2) + param3 + param4).
auto* inner_fusion =
outer_fusion->fused_instructions_computation()->CreateFusionInstruction(
{fused_sums[1], fused_sums[0]}, HloInstruction::FusionKind::kLoop);
// Generate the graph; all nodes should be present.
string graph = hlo_graph_dumper::DumpGraph(*root_computation, /*label=*/"",
DebugOptions());
for (const HloComputation* computation :
{root_computation, //
inner_fusion->fused_instructions_computation(),
outer_fusion->fused_instructions_computation()}) {
for (const std::unique_ptr<HloInstruction>& instruction :
computation->instructions()) {
EXPECT_THAT(graph, HasSubstr(instruction->name()));
}
}
// Dump a neighborhood around one of the inner sum nodes. We don't really
// care that the outer nodes are omitted -- whether they are or not is based on
// fiddly heuristics -- but we do care that the node we asked for is printed.
const HloInstruction* inner_sum = nullptr;
for (const std::unique_ptr<HloInstruction>& instruction :
inner_fusion->fused_instructions_computation()->instructions()) {
if (instruction->opcode() == HloOpcode::kAdd) {
inner_sum = instruction.get();
break;
}
}
ASSERT_NE(inner_sum, nullptr);
EXPECT_THAT(
hlo_graph_dumper::DumpNeighborhoodAround(*inner_sum, /*radius=*/1),
HasSubstr(inner_sum->name()));
}
} // anonymous namespace
} // namespace xla