diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 7418f2d8b11..b9650175f04 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -1968,6 +1968,20 @@ cc_library( alwayslink = 1, ) +tf_cc_test( + name = "hlo_graph_dumper_test", + srcs = ["hlo_graph_dumper_test.cc"], + deps = [ + ":hlo", + ":hlo_graph_dumper", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:xla_proto", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep + "//tensorflow/core:lib", + ], +) + cc_library( name = "transpose_folding", srcs = ["transpose_folding.cc"], diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index bba6fbfae04..39edfffcee5 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -340,11 +340,8 @@ class HloDotDumper { string Header(); string Footer(); - // Maps HloComputations we should dump to their parent instruction in the - // outer computation. 
- std::unordered_map<const HloComputation*, const HloInstruction*> - SubcomputationsToDump(); - + bool ShouldShowSubcomputation(const HloComputation* subcomp); + bool ShouldShowFusionSubcomputation(const HloInstruction* instr); string DumpSubcomputation(const HloComputation* subcomp, const HloInstruction* parent_instr); string DumpComputation(const HloComputation* comp); @@ -401,11 +398,6 @@ class HloDotDumper { string HloDotDumper::Dump() { string body; - for (const auto& kv : SubcomputationsToDump()) { - const HloComputation* subcomp = kv.first; - const HloInstruction* parent = kv.second; - StrAppend(&body, DumpSubcomputation(subcomp, parent)); - } StrAppend(&body, DumpComputation(computation_)); StrAppend(&body, DumpRootTag()); @@ -525,33 +517,36 @@ stylesheet=" string HloDotDumper::Footer() { return StrCat(Join(edges_, "\n"), "\n}"); } -std::unordered_map<const HloComputation*, const HloInstruction*> -HloDotDumper::SubcomputationsToDump() { - // Dump the subcomputations of each instruction that's shown and doesn't have - // its operands omitted. If an instruction has just one subcomputation and - // it's trivial, omit it: We'll display that subcomputation inlined into the - // instruction's node when we draw it.
- std::unordered_map<const HloComputation*, const HloInstruction*> to_dump; - for (const auto& instr : computation_->instructions()) { - if (!filter_.Show(instr.get()) || - filter_.SomeOrAllOperandsOmitted(instr.get())) { - continue; - } - if (instr->opcode() == HloOpcode::kFusion) { - to_dump[instr->fused_instructions_computation()] = instr.get(); - } +bool HloDotDumper::ShouldShowFusionSubcomputation(const HloInstruction* instr) { + CHECK_EQ(instr->opcode(), HloOpcode::kFusion); + return ShouldShowSubcomputation(instr->fused_instructions_computation()); +} - for (const HloComputation* comp : instr->called_computations()) { - if (!MatchTrivialComputation(comp)) { - to_dump[comp] = instr.get(); - } +bool HloDotDumper::ShouldShowSubcomputation(const HloComputation* subcomp) { + if (subcomp->IsFusionComputation()) { + const HloInstruction* fusion = subcomp->FusionInstruction(); + if (!filter_.Show(fusion) || filter_.SomeOrAllOperandsOmitted(fusion)) { + return false; } } - return to_dump; + + // Don't show trivial subcomputations on non-fusion nodes -- these are inlined + // into the graph. + if (!subcomp->IsFusionComputation() && MatchTrivialComputation(subcomp)) { + return false; + } + + // Show the subcomputation if we're showing any of its members. + return std::any_of(computation_->instructions().begin(), + computation_->instructions().end(), + [&](const std::unique_ptr<HloInstruction>& instr) { + return filter_.Show(instr.get()); + }); } string HloDotDumper::DumpSubcomputation(const HloComputation* subcomp, const HloInstruction* parent_instr) { + VLOG(2) << "Dumping subcomputation " << subcomp->name(); const char* computation_fmt = R"(subgraph %s { %s label = <%s>; @@ -593,20 +588,10 @@ tooltip = " "; string comp_body = DumpComputation(subcomp); - if (parent_instr->opcode() == HloOpcode::kFusion) { - // Dump any nested fusion nodes.
- for (const auto& subcomp_instr : subcomp->instructions()) { - if (subcomp_instr->opcode() == HloOpcode::kFusion) { - StrAppend( - &comp_body, - DumpSubcomputation(subcomp_instr->fused_instructions_computation(), - subcomp_instr.get())); - } - } - } else { - // Add an edge from the subcomputation to its parent node. If subcomp - // belongs to a fusion node, it's drawn in place of the fusion instruction, - // so there's no need to link those. + // Add an edge from the subcomputation to its parent node. If subcomp + // belongs to a fusion node, it's drawn in place of the fusion instruction, + // so there's no need to link those. + if (parent_instr->opcode() != HloOpcode::kFusion) { VLOG(2) << "Edge: from " << subcomp->root_instruction()->name() << " to " << parent_instr->name() << " as " << next_edge_id_; edge_ids_.insert( @@ -631,6 +616,14 @@ string HloDotDumper::DumpComputation(const HloComputation* comp) { if (!filter_.Show(instr.get())) { continue; } + + // Dump subcomputations within instr. + for (const HloComputation* subcomp : instr->called_computations()) { + if (ShouldShowSubcomputation(subcomp)) { + StrAppend(&g, DumpSubcomputation(subcomp, instr.get())); + } + } + StrAppend(&g, DumpInstruction(instr.get())); } return g; @@ -638,6 +631,14 @@ string HloDotDumper::DumpComputation(const HloComputation* comp) { string HloDotDumper::DumpRootTag() { HloInstruction* from = computation_->root_instruction(); + + // Fusion nodes are expanded inline, so if root is an expanded fusion node, + // walk up the graph until we find a node that isn't. + while (from->opcode() == HloOpcode::kFusion && + ShouldShowFusionSubcomputation(from)) { + from = from->fused_expression_root(); + } + auto from_id = InstructionId(from); if (!filter_.Show(from)) { @@ -678,7 +679,7 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) { // Omit the fusion node if its subcomputation is drawn, since the // subcomputation will be drawn inline. 
if (instr->opcode() == HloOpcode::kFusion && - filter_.ShowFusionSubcomputation(instr)) { + ShouldShowFusionSubcomputation(instr)) { return ""; } @@ -937,7 +938,7 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { // Show the shape and layout of the instruction, unless it's an inlined fusion // node -- there the shape and layout is present in the output node. if (instr->opcode() != HloOpcode::kFusion || - !filter_.ShowFusionSubcomputation(instr)) { + !ShouldShowFusionSubcomputation(instr)) { string instr_shape = ShapeUtil::HumanString(instr->shape()); // Show layout of non-tuple shapes with more than one dimension. @@ -982,7 +983,7 @@ void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) { // fusion node and the node's subcomputation is shown, we draw our edge // starting at the fusion node's root instead of at the fusion node itself. if (from->opcode() == HloOpcode::kFusion && - filter_.ShowFusionSubcomputation(from)) { + ShouldShowFusionSubcomputation(from)) { from = from->fused_expression_root(); } if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant) { @@ -1147,6 +1148,11 @@ NodeFilter MakeNodeFilter(const HloInstruction* root, int64 radius) { } } + // Traverse into instr's nested computations. + for (const HloComputation* computation : instr->called_computations()) { + worklist.push_back({computation->root_instruction(), depth + 1}); + } + // Traverse into instr's users, unless: // // - there are a ton of them, in which case they're probably not diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc new file mode 100644 index 00000000000..4015ee6cace --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc @@ -0,0 +1,122 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/xla.pb.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace xla { +namespace { + +using ::tensorflow::strings::StrCat; +using ::testing::HasSubstr; + +string TestName() { + return ::testing::UnitTest::GetInstance()->current_test_info()->name(); +} + +class DotRenderer : public hlo_graph_dumper::GraphRendererInterface { + public: + string RenderGraph(const string& graph, GraphKind graph_kind, + const DebugOptions& debug_options) override { + return graph; + } + + private: + string last_graph_; +}; + +XLA_REGISTER_GRAPH_RENDERER(DotRenderer, std::numeric_limits<int>::max()); + +TEST(HloGraphDumperTest, NestedFusion) { + HloComputation::Builder b("b"); + + // Build param0 + param1 + param2 + param3 + param4.
+ auto shape = ShapeUtil::MakeShape(F32, {10, 100}); + std::vector<HloInstruction*> params; + for (int i = 0; i <= 4; ++i) { + params.push_back(b.AddInstruction( + HloInstruction::CreateParameter(i, shape, StrCat("param", i)))); + } + std::vector<HloInstruction*> sums; + sums.push_back(b.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kAdd, params[0], params[1]))); + for (int i = 0; i <= 2; ++i) { + sums.push_back(b.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kAdd, sums[i], params[i + 2]))); + } + + HloModule m(TestName()); + m.AddEntryComputation(b.Build()); + HloComputation* root_computation = m.entry_computation(); + + // Fuse into fusion(param0 + param1 + param2 + param3 + param4). + auto* outer_fusion = root_computation->CreateFusionInstruction( + {sums[3], sums[2], sums[1], sums[0]}, HloInstruction::FusionKind::kLoop); + + // Fusing invalidates the pointers in sums -- the instructions are cloned when + // they're moved to the new computation. Get the updated pointers to sums. + std::vector<HloInstruction*> fused_sums; + for (auto* instr : outer_fusion->fused_instructions_computation() + ->MakeInstructionPostOrder()) { + if (instr->opcode() == HloOpcode::kAdd) { + fused_sums.push_back(instr); + } + } + + // Fuse into fusion(fusion(param0 + param1 + param2) + param3 + param4). + auto* inner_fusion = + outer_fusion->fused_instructions_computation()->CreateFusionInstruction( + {fused_sums[1], fused_sums[0]}, HloInstruction::FusionKind::kLoop); + + // Generate the graph; all nodes should be present. + string graph = hlo_graph_dumper::DumpGraph(*root_computation, /*label=*/"", + DebugOptions()); + for (const HloComputation* computation : + {root_computation, // + inner_fusion->fused_instructions_computation(), + outer_fusion->fused_instructions_computation()}) { + for (const std::unique_ptr<HloInstruction>& instruction : + computation->instructions()) { + EXPECT_THAT(graph, HasSubstr(instruction->name())); + } + } + + // Dump a neighborhood around one of the inner sum nodes.
We don't really + // care that the outer nodes are omitted -- whether they are or not is based + // on fiddly heuristics -- but we do care that the node we asked for is printed. + const HloInstruction* inner_sum = nullptr; + for (const std::unique_ptr<HloInstruction>& instruction : + inner_fusion->fused_instructions_computation()->instructions()) { + if (instruction->opcode() == HloOpcode::kAdd) { + inner_sum = instruction.get(); + break; + } + } + ASSERT_NE(inner_sum, nullptr); + EXPECT_THAT( + hlo_graph_dumper::DumpNeighborhoodAround(*inner_sum, /*radius=*/1), + HasSubstr(inner_sum->name())); +} + +} // anonymous namespace +} // namespace xla