Revamp handling of subcomputations in HLO graph dumper.
Before, we relied on a hacky heuristic -- "recurse into nested fusion nodes" -- that didn't work for the case when e.g. a fusion node was nested inside a while loop. This change also adds a (very basic) testcase for the HLO graph dumper. PiperOrigin-RevId: 169731958
This commit is contained in:
parent
5893f926e8
commit
83066d45ee
@ -1968,6 +1968,20 @@ cc_library(
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "hlo_graph_dumper_test",
|
||||
srcs = ["hlo_graph_dumper_test.cc"],
|
||||
deps = [
|
||||
":hlo",
|
||||
":hlo_graph_dumper",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:xla_proto",
|
||||
"//tensorflow/compiler/xla/tests:test_utils",
|
||||
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "transpose_folding",
|
||||
srcs = ["transpose_folding.cc"],
|
||||
|
@ -340,11 +340,8 @@ class HloDotDumper {
|
||||
string Header();
|
||||
string Footer();
|
||||
|
||||
// Maps HloComputations we should dump to their parent instruction in the
|
||||
// outer computation.
|
||||
std::unordered_map<const HloComputation*, const HloInstruction*>
|
||||
SubcomputationsToDump();
|
||||
|
||||
bool ShouldShowSubcomputation(const HloComputation* subcomp);
|
||||
bool ShouldShowFusionSubcomputation(const HloInstruction* instr);
|
||||
string DumpSubcomputation(const HloComputation* subcomp,
|
||||
const HloInstruction* parent_instr);
|
||||
string DumpComputation(const HloComputation* comp);
|
||||
@ -401,11 +398,6 @@ class HloDotDumper {
|
||||
|
||||
string HloDotDumper::Dump() {
|
||||
string body;
|
||||
for (const auto& kv : SubcomputationsToDump()) {
|
||||
const HloComputation* subcomp = kv.first;
|
||||
const HloInstruction* parent = kv.second;
|
||||
StrAppend(&body, DumpSubcomputation(subcomp, parent));
|
||||
}
|
||||
StrAppend(&body, DumpComputation(computation_));
|
||||
StrAppend(&body, DumpRootTag());
|
||||
|
||||
@ -525,33 +517,36 @@ stylesheet="
|
||||
|
||||
string HloDotDumper::Footer() { return StrCat(Join(edges_, "\n"), "\n}"); }
|
||||
|
||||
std::unordered_map<const HloComputation*, const HloInstruction*>
|
||||
HloDotDumper::SubcomputationsToDump() {
|
||||
// Dump the subcomputations of each instruction that's shown and doesn't have
|
||||
// its operands omitted. If an instruction has just one subcomputation and
|
||||
// it's trivial, omit it: We'll display that subcomputation inlined into the
|
||||
// instruction's node when we draw it.
|
||||
std::unordered_map<const HloComputation*, const HloInstruction*> to_dump;
|
||||
for (const auto& instr : computation_->instructions()) {
|
||||
if (!filter_.Show(instr.get()) ||
|
||||
filter_.SomeOrAllOperandsOmitted(instr.get())) {
|
||||
continue;
|
||||
}
|
||||
if (instr->opcode() == HloOpcode::kFusion) {
|
||||
to_dump[instr->fused_instructions_computation()] = instr.get();
|
||||
}
|
||||
bool HloDotDumper::ShouldShowFusionSubcomputation(const HloInstruction* instr) {
|
||||
CHECK_EQ(instr->opcode(), HloOpcode::kFusion);
|
||||
return ShouldShowSubcomputation(instr->fused_instructions_computation());
|
||||
}
|
||||
|
||||
for (const HloComputation* comp : instr->called_computations()) {
|
||||
if (!MatchTrivialComputation(comp)) {
|
||||
to_dump[comp] = instr.get();
|
||||
}
|
||||
bool HloDotDumper::ShouldShowSubcomputation(const HloComputation* subcomp) {
|
||||
if (subcomp->IsFusionComputation()) {
|
||||
const HloInstruction* fusion = subcomp->FusionInstruction();
|
||||
if (!filter_.Show(fusion) || filter_.SomeOrAllOperandsOmitted(fusion)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return to_dump;
|
||||
|
||||
// Don't show trivial subcomputations on non-fusion nodes -- these are inlined
|
||||
// into the graph.
|
||||
if (!subcomp->IsFusionComputation() && MatchTrivialComputation(subcomp)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Show the subcomputation if we're showing any of its members.
|
||||
return std::any_of(computation_->instructions().begin(),
|
||||
computation_->instructions().end(),
|
||||
[&](const std::unique_ptr<HloInstruction>& instr) {
|
||||
return filter_.Show(instr.get());
|
||||
});
|
||||
}
|
||||
|
||||
string HloDotDumper::DumpSubcomputation(const HloComputation* subcomp,
|
||||
const HloInstruction* parent_instr) {
|
||||
VLOG(2) << "Dumping subcomputation " << subcomp->name();
|
||||
const char* computation_fmt = R"(subgraph %s {
|
||||
%s
|
||||
label = <%s>;
|
||||
@ -593,20 +588,10 @@ tooltip = " ";
|
||||
|
||||
string comp_body = DumpComputation(subcomp);
|
||||
|
||||
if (parent_instr->opcode() == HloOpcode::kFusion) {
|
||||
// Dump any nested fusion nodes.
|
||||
for (const auto& subcomp_instr : subcomp->instructions()) {
|
||||
if (subcomp_instr->opcode() == HloOpcode::kFusion) {
|
||||
StrAppend(
|
||||
&comp_body,
|
||||
DumpSubcomputation(subcomp_instr->fused_instructions_computation(),
|
||||
subcomp_instr.get()));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Add an edge from the subcomputation to its parent node. If subcomp
|
||||
// belongs to a fusion node, it's drawn in place of the fusion instruction,
|
||||
// so there's no need to link those.
|
||||
// Add an edge from the subcomputation to its parent node. If subcomp
|
||||
// belongs to a fusion node, it's drawn in place of the fusion instruction,
|
||||
// so there's no need to link those.
|
||||
if (parent_instr->opcode() != HloOpcode::kFusion) {
|
||||
VLOG(2) << "Edge: from " << subcomp->root_instruction()->name() << " to "
|
||||
<< parent_instr->name() << " as " << next_edge_id_;
|
||||
edge_ids_.insert(
|
||||
@ -631,6 +616,14 @@ string HloDotDumper::DumpComputation(const HloComputation* comp) {
|
||||
if (!filter_.Show(instr.get())) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Dump subcomputations within instr.
|
||||
for (const HloComputation* subcomp : instr->called_computations()) {
|
||||
if (ShouldShowSubcomputation(subcomp)) {
|
||||
StrAppend(&g, DumpSubcomputation(subcomp, instr.get()));
|
||||
}
|
||||
}
|
||||
|
||||
StrAppend(&g, DumpInstruction(instr.get()));
|
||||
}
|
||||
return g;
|
||||
@ -638,6 +631,14 @@ string HloDotDumper::DumpComputation(const HloComputation* comp) {
|
||||
|
||||
string HloDotDumper::DumpRootTag() {
|
||||
HloInstruction* from = computation_->root_instruction();
|
||||
|
||||
// Fusion nodes are expanded inline, so if root is an expanded fusion node,
|
||||
// walk up the graph until we find a node that isn't.
|
||||
while (from->opcode() == HloOpcode::kFusion &&
|
||||
ShouldShowFusionSubcomputation(from)) {
|
||||
from = from->fused_expression_root();
|
||||
}
|
||||
|
||||
auto from_id = InstructionId(from);
|
||||
|
||||
if (!filter_.Show(from)) {
|
||||
@ -678,7 +679,7 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) {
|
||||
// Omit the fusion node if its subcomputation is drawn, since the
|
||||
// subcomputation will be drawn inline.
|
||||
if (instr->opcode() == HloOpcode::kFusion &&
|
||||
filter_.ShowFusionSubcomputation(instr)) {
|
||||
ShouldShowFusionSubcomputation(instr)) {
|
||||
return "";
|
||||
}
|
||||
|
||||
@ -937,7 +938,7 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) {
|
||||
// Show the shape and layout of the instruction, unless it's an inlined fusion
|
||||
// node -- there the shape and layout is present in the output node.
|
||||
if (instr->opcode() != HloOpcode::kFusion ||
|
||||
!filter_.ShowFusionSubcomputation(instr)) {
|
||||
!ShouldShowFusionSubcomputation(instr)) {
|
||||
string instr_shape = ShapeUtil::HumanString(instr->shape());
|
||||
|
||||
// Show layout of non-tuple shapes with more than one dimension.
|
||||
@ -982,7 +983,7 @@ void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) {
|
||||
// fusion node and the node's subcomputation is shown, we draw our edge
|
||||
// starting at the fusion node's root instead of at the fusion node itself.
|
||||
if (from->opcode() == HloOpcode::kFusion &&
|
||||
filter_.ShowFusionSubcomputation(from)) {
|
||||
ShouldShowFusionSubcomputation(from)) {
|
||||
from = from->fused_expression_root();
|
||||
}
|
||||
if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant) {
|
||||
@ -1147,6 +1148,11 @@ NodeFilter MakeNodeFilter(const HloInstruction* root, int64 radius) {
|
||||
}
|
||||
}
|
||||
|
||||
// Traverse into instr's nested computations.
|
||||
for (const HloComputation* computation : instr->called_computations()) {
|
||||
worklist.push_back({computation->root_instruction(), depth + 1});
|
||||
}
|
||||
|
||||
// Traverse into instr's users, unless:
|
||||
//
|
||||
// - there are a ton of them, in which case they're probably not
|
||||
|
122
tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc
Normal file
122
tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc
Normal file
@ -0,0 +1,122 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
#include "tensorflow/compiler/xla/tests/test_utils.h"
|
||||
#include "tensorflow/compiler/xla/xla.pb.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
||||
using ::tensorflow::strings::StrCat;
|
||||
using ::testing::HasSubstr;
|
||||
|
||||
string TestName() {
|
||||
return ::testing::UnitTest::GetInstance()->current_test_info()->name();
|
||||
}
|
||||
|
||||
class DotRenderer : public hlo_graph_dumper::GraphRendererInterface {
|
||||
public:
|
||||
string RenderGraph(const string& graph, GraphKind graph_kind,
|
||||
const DebugOptions& debug_options) override {
|
||||
return graph;
|
||||
}
|
||||
|
||||
private:
|
||||
string last_graph_;
|
||||
};
|
||||
|
||||
XLA_REGISTER_GRAPH_RENDERER(DotRenderer, std::numeric_limits<int>::max());
|
||||
|
||||
TEST(HloGraphDumperTest, NestedFusion) {
|
||||
HloComputation::Builder b("b");
|
||||
|
||||
// Build param0 + param1 + param2 + param3 + param4.
|
||||
auto shape = ShapeUtil::MakeShape(F32, {10, 100});
|
||||
std::vector<HloInstruction*> params;
|
||||
for (int i = 0; i <= 4; ++i) {
|
||||
params.push_back(b.AddInstruction(
|
||||
HloInstruction::CreateParameter(i, shape, StrCat("param", i))));
|
||||
}
|
||||
std::vector<HloInstruction*> sums;
|
||||
sums.push_back(b.AddInstruction(HloInstruction::CreateBinary(
|
||||
shape, HloOpcode::kAdd, params[0], params[1])));
|
||||
for (int i = 0; i <= 2; ++i) {
|
||||
sums.push_back(b.AddInstruction(HloInstruction::CreateBinary(
|
||||
shape, HloOpcode::kAdd, sums[i], params[i + 2])));
|
||||
}
|
||||
|
||||
HloModule m(TestName());
|
||||
m.AddEntryComputation(b.Build());
|
||||
HloComputation* root_computation = m.entry_computation();
|
||||
|
||||
// Fuse into fusion(param0 + param1 + param2 + param3 + param4).
|
||||
auto* outer_fusion = root_computation->CreateFusionInstruction(
|
||||
{sums[3], sums[2], sums[1], sums[0]}, HloInstruction::FusionKind::kLoop);
|
||||
|
||||
// Fusing invalidates the pointers in sums -- the instructions are cloned when
|
||||
// they're moved to the new computation. Get the updated pointers to sums.
|
||||
std::vector<HloInstruction*> fused_sums;
|
||||
for (auto* instr : outer_fusion->fused_instructions_computation()
|
||||
->MakeInstructionPostOrder()) {
|
||||
if (instr->opcode() == HloOpcode::kAdd) {
|
||||
fused_sums.push_back(instr);
|
||||
}
|
||||
}
|
||||
|
||||
// Fuse into fusion(fusion(param0 + param1 + param2) + param3 + param4).
|
||||
auto* inner_fusion =
|
||||
outer_fusion->fused_instructions_computation()->CreateFusionInstruction(
|
||||
{fused_sums[1], fused_sums[0]}, HloInstruction::FusionKind::kLoop);
|
||||
|
||||
// Generate the graph; all nodes should be present.
|
||||
string graph = hlo_graph_dumper::DumpGraph(*root_computation, /*label=*/"",
|
||||
DebugOptions());
|
||||
for (const HloComputation* computation :
|
||||
{root_computation, //
|
||||
inner_fusion->fused_instructions_computation(),
|
||||
outer_fusion->fused_instructions_computation()}) {
|
||||
for (const std::unique_ptr<HloInstruction>& instruction :
|
||||
computation->instructions()) {
|
||||
EXPECT_THAT(graph, HasSubstr(instruction->name()));
|
||||
}
|
||||
}
|
||||
|
||||
// Dump a neighborhood around one of the inner sum nodes. We don't really
|
||||
// care that the outer nodes are omitted -- whether they are or not is based
|
||||
// fiddly heuristics -- but we do care that the node we asked for is printed.
|
||||
const HloInstruction* inner_sum = nullptr;
|
||||
for (const std::unique_ptr<HloInstruction>& instruction :
|
||||
inner_fusion->fused_instructions_computation()->instructions()) {
|
||||
if (instruction->opcode() == HloOpcode::kAdd) {
|
||||
inner_sum = instruction.get();
|
||||
break;
|
||||
}
|
||||
}
|
||||
ASSERT_NE(inner_sum, nullptr);
|
||||
EXPECT_THAT(
|
||||
hlo_graph_dumper::DumpNeighborhoodAround(*inner_sum, /*radius=*/1),
|
||||
HasSubstr(inner_sum->name()));
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
} // namespace xla
|
Loading…
Reference in New Issue
Block a user