diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD
index 8926bbed2b5..99b32c19a52 100644
--- a/tensorflow/compiler/xla/tools/BUILD
+++ b/tensorflow/compiler/xla/tools/BUILD
@@ -14,7 +14,7 @@ filegroup(
     visibility = ["//tensorflow/compiler/xla:internal"],
 )
 
-load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
+load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test")
 
 tf_cc_binary(
     name = "hex_floats_to_packed_literal",
@@ -234,3 +234,53 @@ tf_cc_binary(
         "//tensorflow/core:lib",
     ],
 )
+
+# Library that extracts a sub-graph of an HLO module into a new module.
+cc_library(
+    name = "hlo_extractor",
+    srcs = ["hlo_extractor.cc"],
+    hdrs = ["hlo_extractor.h"],
+    deps = [
+        "//tensorflow/compiler/xla:status",
+        "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/compiler/xla/service:hlo_verifier",
+        "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/memory",
+    ],
+)
+
+# Unit tests for :hlo_extractor.
+tf_cc_test(
+    name = "hlo_extractor_test",
+    srcs = ["hlo_extractor_test.cc"],
+    deps = [
+        ":hlo_extractor",
+        "//tensorflow/compiler/xla/service:hlo_matchers",
+        "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/core:test",
+    ],
+)
+
+# Interactive command-line tool for exploring graphviz renderings of HLO.
+tf_cc_binary(
+    name = "interactive_graphviz",
+    srcs = ["interactive_graphviz.cc"],
+    deps = [
+        ":hlo_extractor",
+        "//tensorflow/compiler/xla/client:client_library",
+        "//tensorflow/compiler/xla/client:local_client",
+        "//tensorflow/compiler/xla/service:compiler",
+        "//tensorflow/compiler/xla/service:cpu_plugin",
+        "//tensorflow/compiler/xla/service:gpu_plugin",
+        "//tensorflow/compiler/xla/service:hlo_graph_dumper",
+        "//tensorflow/compiler/xla/service:hlo_proto",
+        "//tensorflow/compiler/xla/service:hlo_runner",
+        "//tensorflow/compiler/xla/service:local_service",
+        "//tensorflow/compiler/xla/service:platform_util",
+        "//tensorflow/core:framework_internal",
+        "//tensorflow/core:lib",
+        "@com_google_absl//absl/algorithm:container",
+        "@com_google_absl//absl/strings",
+    ],
+)
diff --git 
a/tensorflow/compiler/xla/tools/hlo_extractor.cc b/tensorflow/compiler/xla/tools/hlo_extractor.cc
new file mode 100644
index 00000000000..f3ce5f99b0c
--- /dev/null
+++ b/tensorflow/compiler/xla/tools/hlo_extractor.cc
@@ -0,0 +1,166 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/tools/hlo_extractor.h"
+
+#include <deque>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/memory/memory.h"
+#include "tensorflow/compiler/xla/service/hlo_clone_context.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_verifier.h"
+#include "tensorflow/compiler/xla/status.h"
+
+namespace xla {
+namespace {
+
+// Visitor that builds a new HLO module with an entry computation whose root is
+// the instruction the visit starts from. Only HLOs that are reachable from
+// that root are included in the new module.
+//
+// The constructor takes an optional set of boundary HLOs used to prune the
+// graph. HLOs in that set are replaced with parameters; constants are not
+// special-cased. Passing nullptr means no boundary, i.e. no HLOs are replaced.
+class ExtractionVisitor : public ConstDfsHloVisitorWithDefault {
+ public:
+  explicit ExtractionVisitor(
+      const HloModule& old_module,
+      absl::flat_hash_set<const HloInstruction*>* boundary)
+      : old_module_(old_module),
+        module_(absl::make_unique<HloModule>("extracted", config_)),
+        clone_context_(module_.get()),
+        builder_("entry_computation"),
+        boundary_(boundary) {}
+
+  Status HandleParameter(const HloInstruction* parameter) override {
+    // Entry parameters need renumbering.
+    AddParameterFor(parameter);
+    return Status::OK();
+  }
+
+  Status DefaultAction(const HloInstruction* hlo) override {
+    // Instructions at the boundary are replaced with parameters. Constants
+    // get no special treatment here.
+    if (boundary_ != nullptr && boundary_->count(hlo) > 0) {
+      AddParameterFor(hlo);
+      return Status::OK();
+    }
+    std::vector<HloInstruction*> new_operands;
+    new_operands.reserve(hlo->operands().size());
+    for (const HloInstruction* operand : hlo->operands()) {
+      new_operands.push_back(clone_context_.GetInstruction(operand));
+    }
+    builder_.AddInstruction(
+        hlo->CloneWithNewOperands(hlo->shape(), new_operands, &clone_context_));
+    return Status::OK();
+  }
+
+  Status FinishVisit(const HloInstruction* /*root*/) override {
+    module_->AddEntryComputation(builder_.Build());
+    // Rename HLOs so that their name matches the original. By default,
+    // HLOs get new unique names when adding a new entry computation to
+    // a module.
+    for (auto computation : old_module_.MakeComputationPostOrder()) {
+      for (auto old_instruction : computation->MakeInstructionPostOrder()) {
+        if (auto new_instruction =
+                clone_context_.FindInstruction(old_instruction)) {
+          new_instruction->SetAndSanitizeName(old_instruction->name());
+        }
+      }
+    }
+    return Status::OK();
+  }
+
+  HloModule* module() { return module_.get(); }
+
+  std::unique_ptr<HloModule> ConsumeModule() { return std::move(module_); }
+
+ private:
+  // Adds a parameter standing in for `hlo` to the new entry computation, using
+  // the next free parameter number, and records the old-to-new mapping.
+  void AddParameterFor(const HloInstruction* hlo) {
+    auto new_parameter = HloInstruction::CreateParameter(
+        parameter_number_++, hlo->shape(), hlo->name());
+    clone_context_.MapInstruction(hlo, new_parameter.get());
+    builder_.AddInstruction(std::move(new_parameter));
+  }
+
+  const HloModule& old_module_;
+  HloModuleConfig config_;
+  std::unique_ptr<HloModule> module_;
+  HloCloneContext clone_context_;
+  HloComputation::Builder builder_;
+  absl::flat_hash_set<const HloInstruction*>* boundary_;
+  int64 parameter_number_ = 0;
+};
+
+// Breadth-first walk over the operands of `root`. Instructions first reached
+// more than `limit` hops from `root` are inserted into `boundary`, and the
+// walk does not continue past them.
+void ComputeBoundary(const HloInstruction* root, int64 limit,
+                     absl::flat_hash_set<const HloInstruction*>* boundary) {
+  std::deque<const HloInstruction*> worklist;
+  absl::flat_hash_map<const HloInstruction*, int64> visited;
+  worklist.push_back(root);
+  visited.emplace(root, 0);
+  while (!worklist.empty()) {
+    const HloInstruction* hlo = worklist.front();
+    worklist.pop_front();
+    const int64 hops = visited[hlo];
+    if (hops > limit) {
+      boundary->insert(hlo);
+      continue;
+    }
+    for (const HloInstruction* operand : hlo->operands()) {
+      // emplace() does nothing for operands that were already discovered.
+      if (visited.emplace(operand, hops + 1).second) {
+        worklist.push_back(operand);
+      }
+    }
+  }
+}
+
+}  // namespace
+
+std::unique_ptr<HloModule> ExtractModule(HloInstruction* instruction,
+                                         int64 height) {
+  absl::flat_hash_set<const HloInstruction*> boundary;
+  if (height != -1) {
+    ComputeBoundary(instruction, height, &boundary);
+  }
+  ExtractionVisitor visitor(*instruction->GetModule(), &boundary);
+  // TF_CHECK_OK reports the failing status, unlike CHECK(status.ok()).
+  TF_CHECK_OK(instruction->Accept(&visitor));
+
+  // The first pass may leave unused parameter instructions. Do another
+  // extraction pass to remove unused parameters. This is done because
+  // HloComputation does not allow removing parameters after the computation has
+  // been built.
+  ExtractionVisitor cleanup_visitor(*visitor.module(), /*boundary=*/nullptr);
+  TF_CHECK_OK(visitor.module()->entry_computation()->root_instruction()->Accept(
+      &cleanup_visitor));
+
+  HloVerifier verifier(/*layout_sensitive=*/false,
+                       /*allow_mixed_precision=*/true);
+  TF_CHECK_OK(verifier.Run(cleanup_visitor.module()).status());
+  return cleanup_visitor.ConsumeModule();
+}
+
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/tools/hlo_extractor.h b/tensorflow/compiler/xla/tools/hlo_extractor.h
new file mode 100644
index 00000000000..bc13dc7e438
--- /dev/null
+++ b/tensorflow/compiler/xla/tools/hlo_extractor.h
@@ -0,0 +1,36 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_HLO_EXTRACTOR_H_
+#define TENSORFLOW_COMPILER_XLA_TOOLS_HLO_EXTRACTOR_H_
+
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+
+namespace xla {
+
+// Creates a new HLO module with an entry computation rooted at the given
+// instruction.
+//
+// By default (height == -1), the new computation includes all transitive
+// operands of `instruction`. If you specify a different height, the new
+// computation will include all instructions <= `height` hops away from
+// `instruction`. Instructions at the boundary are replaced by parameters.
+std::unique_ptr<HloModule> ExtractModule(HloInstruction* instruction,
+                                         int64 height = -1);
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_TOOLS_HLO_EXTRACTOR_H_
diff --git a/tensorflow/compiler/xla/tools/hlo_extractor_test.cc b/tensorflow/compiler/xla/tools/hlo_extractor_test.cc
new file mode 100644
index 00000000000..c187222a11e
--- /dev/null
+++ b/tensorflow/compiler/xla/tools/hlo_extractor_test.cc
@@ -0,0 +1,145 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/tools/hlo_extractor.h"
+
+#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+
+namespace xla {
+namespace {
+
+namespace op = testing::opcode_matchers;
+
+using HloExtractorTest = HloTestBase;
+
+// A linear chain: full extraction, and height 0 from two different roots.
+TEST_F(HloExtractorTest, ExtractTopLevel) {
+  const string hlo_string = R"(
+HloModule test
+
+ENTRY %entry {
+  param.0 = f32[4]{0} parameter(0)
+  negate = f32[4]{0} negate(f32[4]{0} param.0)
+  ROOT exp = f32[4]{0} exponential(f32[4]{0} negate)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<HloModule> hlo_module,
+      HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()));
+
+  {
+    auto extracted_module =
+        ExtractModule(FindInstruction(hlo_module.get(), "exp"));
+    EXPECT_THAT(extracted_module->entry_computation()->root_instruction(),
+                op::Exp(op::Negate(op::Parameter(0))));
+  }
+
+  {
+    auto extracted_module =
+        ExtractModule(FindInstruction(hlo_module.get(), "exp"), /*height=*/0);
+    EXPECT_THAT(extracted_module->entry_computation()->root_instruction(),
+                op::Exp(op::Parameter(0)));
+  }
+
+  {
+    auto extracted_module = ExtractModule(
+        FindInstruction(hlo_module.get(), "negate"), /*height=*/0);
+    EXPECT_THAT(extracted_module->entry_computation()->root_instruction(),
+                op::Negate(op::Parameter(0)));
+  }
+}
+
+// A DAG in which negate feeds both exp and add, at heights 0 through 2.
+TEST_F(HloExtractorTest, ExtractDag) {
+  const string hlo_string = R"(
+HloModule test
+
+ENTRY %entry {
+  param.0 = f32[4]{0} parameter(0)
+  tanh = f32[4]{0} tanh(f32[4]{0} param.0)
+  negate = f32[4]{0} negate(f32[4]{0} tanh)
+  exp = f32[4]{0} exponential(f32[4]{0} negate)
+  ROOT add = f32[4]{0} add(f32[4]{0} negate, f32[4]{0} exp)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<HloModule> hlo_module,
+      HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()));
+
+  {
+    auto extracted_module =
+        ExtractModule(FindInstruction(hlo_module.get(), "exp"));
+    EXPECT_THAT(extracted_module->entry_computation()->root_instruction(),
+                op::Exp(op::Negate(op::Tanh(op::Parameter(0)))));
+  }
+
+  {
+    auto extracted_module =
+        ExtractModule(FindInstruction(hlo_module.get(), "add"), /*height=*/0);
+    EXPECT_THAT(extracted_module->entry_computation()->root_instruction(),
+                op::Add(op::Parameter(0), op::Parameter(1)));
+  }
+  {
+    auto extracted_module =
+        ExtractModule(FindInstruction(hlo_module.get(), "add"), /*height=*/1);
+    EXPECT_THAT(extracted_module->entry_computation()->root_instruction(),
+                op::Add(op::Negate(op::Parameter(0)),
+                        op::Exp(op::Negate(op::Parameter(0)))));
+  }
+  {
+    auto extracted_module =
+        ExtractModule(FindInstruction(hlo_module.get(), "add"), /*height=*/2);
+    EXPECT_THAT(extracted_module->entry_computation()->root_instruction(),
+                op::Add(op::Negate(op::Tanh(op::Parameter(0))),
+                        op::Exp(op::Negate(op::Tanh(op::Parameter(0))))));
+  }
+}
+
+// With height 0 the constant operand is replaced by a parameter as well.
+TEST_F(HloExtractorTest, ExtractWithConstant) {
+  const string hlo_string = R"(
+HloModule test
+
+ENTRY %entry {
+  p = f32[4]{0} parameter(0)
+  tanh = f32[4]{0} tanh(p)
+  c = f32[4]{0} constant({1, 2, 3, 4})
+  ROOT add = f32[4]{0} add(tanh, c)
+}
+)";
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<HloModule> hlo_module,
+      HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()));
+
+  {
+    auto extracted_module =
+        ExtractModule(FindInstruction(hlo_module.get(), "add"), /*height=*/0);
+    EXPECT_THAT(extracted_module->entry_computation()->root_instruction(),
+                op::Add(op::Parameter(0), op::Parameter(1)));
+  }
+  {
+    auto extracted_module =
+        ExtractModule(FindInstruction(hlo_module.get(), "add"), /*height=*/1);
+    EXPECT_THAT(extracted_module->entry_computation()->root_instruction(),
+                op::Add(op::Tanh(op::Parameter(0)), op::Constant()));
+  }
+}
+
+}  // namespace
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/tools/interactive_graphviz.cc b/tensorflow/compiler/xla/tools/interactive_graphviz.cc
new file mode 100644
index 00000000000..6c90cde5a75
--- /dev/null
+++ 
b/tensorflow/compiler/xla/tools/interactive_graphviz.cc @@ -0,0 +1,652 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// A tool for interactively exploring graphviz dumps of HLO graphs. +// +// Input can be a binary HloSnapshot proto, a binary HloProto proto, or a +// textual HLO string. +// +// Generated visualization is opened in a new default browser window using +// /usr/bin/sensible-browser. 
+
+#include <stdio.h>
+#include <unistd.h>
+
+#include <iostream>
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/strings/match.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "absl/strings/string_view_utils.h"
+#include "absl/strings/util.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/service/compiler.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
+#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
+#include "tensorflow/compiler/xla/service/hlo_runner.h"
+#include "tensorflow/compiler/xla/service/local_service.h"
+#include "tensorflow/compiler/xla/service/platform_util.h"
+#include "tensorflow/compiler/xla/tools/hlo_extractor.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/subprocess.h"
+#include "tensorflow/core/util/command_line_flags.h"
+#if defined(PLATFORM_GOOGLE)
+#include "util/readline/readline.h"
+#endif
+
+namespace xla {
+namespace tools {
+namespace {
+
+// Shows `prompt` and reads one line of user input into `line`. Outside of
+// PLATFORM_GOOGLE builds this returns false when std::getline fails, e.g. at
+// end of input.
+bool ReadLine(const char* prompt, string* line) {
+#if defined(PLATFORM_GOOGLE)
+  return util::ReadLine(prompt, line);
+#else
+  std::cout << prompt;
+  // The stream-to-bool conversion is explicit, so it has to be spelled out in
+  // a return statement.
+  return static_cast<bool>(std::getline(std::cin, *line));
+#endif
+}
+
+// Command-line opts to this tool.  See main() for descriptions of these
+// fields.
+struct Options {
+  string hlo_snapshot;
+  string hlo_proto;
+  string hlo_text;
+  string platform;
+  string browser;
+};
+
+const char* const kUsage = R"(
+This tool lets you load an XLA dump and then interactively explore its graphical
+representation.
+
+Most models are too large to visualize in their entirety using graphviz, but
+it's still useful to be able to look at the nodes "near" a particular node of
+interest.
+
+If you pass --platform, this tool will compile the HloModule for the given
+platform. This means that if you acquired your proto from a binary running at a
+particular CL, the HLO graph it ran isn't necessarily the same as the one shown
+here, unless this program was built at the same CL (and our compiler is
+deterministic :).
+
+Be patient when starting this program if you give it a large input; it has to
+compile the whole thing.
+
+Usage:
+
+  interactive_graphviz -- \
+    --{hlo_snapshot,hlo_proto,hlo_text}=path/to/binary_proto
+    --platform={CUDA,CPU,...}
+)";
+
+// Unless an explicit width is specified, we will render a neighborhood of
+// kDefaultWidth nodes around the requested instruction.
+constexpr int64 kDefaultWidth = 2;
+
+// When printing all paths between two nodes, we print out only this many nodes
+// by default, truncating the graph if there are more nodes than this in the
+// all-paths set.
+constexpr int64 kDefaultMaxNumNodesInAllPaths = 100;
+
+using absl::EqualsIgnoreCase;
+
+// A global control for whether backend configuration display is enabled.
+bool show_backend_config = true;
+
+// Returns the first instruction in `module` whose name matches `node_name`,
+// ignoring case and an optional leading "%", or nullptr if there is none.
+HloInstruction* FindInstruction(const HloModule& module, string node_name) {
+  if (absl::StartsWith(node_name, "%")) {
+    node_name.erase(node_name.begin());
+  }
+  for (const auto& computation : module.computations()) {
+    auto instrs = computation->instructions();
+    auto it = absl::c_find_if(instrs, [&](const HloInstruction* instr) {
+      // Try with and without "%" at the beginning of the node name.
+      return EqualsIgnoreCase(instr->name(), node_name) ||
+             EqualsIgnoreCase(instr->name(), absl::StrCat("%", node_name));
+    });
+    if (it != instrs.end()) {
+      return *it;
+    }
+  }
+  return nullptr;
+}
+
+// Returns the first computation in `module` whose name equals `comp_name`,
+// ignoring case, or nullptr if there is none.
+HloComputation* FindComputation(const HloModule& module,
+                                const string& comp_name) {
+  for (auto* computation : module.computations()) {
+    if (EqualsIgnoreCase(computation->name(), comp_name)) {
+      return computation;
+    }
+  }
+  return nullptr;
+}
+
+// Print a help message describing the various available commands.
+void DoHelpCommand() {
+  std::cout << R"(Commands:
+  <instruction> [<width>]
+    Renders a neighborhood of <width> nodes around <instruction>. If <width>
+    is not provided, the default value is )"
+            << kDefaultWidth << R"(.
+  allpaths <instruction> <instruction> [<n>]
+    Renders a subset of all paths from one instruction to the other. Either
+    order of nodes is accepted. Shows the <n> nodes in the all-paths set on the
+    shortest paths; default is )"
+            << kDefaultMaxNumNodesInAllPaths << R"(.
+  <computation>
+    Renders all nodes in <computation>.
+  backend_config [on|off]
+    Controls whether backend operation configuration information is printed.
+  list [name|op_name|op_type] <pattern>
+    Lists all instructions whose name, metadata op_name, or metadata op_type
+    contains <pattern> as a substring.
+  list computations
+    Lists all computations in the module.
+  info <instruction>
+  info <computation>
+    Prints information about <instruction> or <computation>.
+  extract <instruction> [<height>]
+    Creates a new HLO module with <instruction> as entry computation root. If
+    <height> is specified, the new computation contains nodes up to <height>
+    nodes above the root.
+  help
+    Prints this usage information.)"
+            << std::endl;
+}
+
+// Turns display of backend configuration on or off ("backend_config on|off").
+// With no argument the setting is left alone; any other argument is reported
+// as illegal. The resulting state is printed in every case.
+void DoBackendConfigCommand(const std::vector<string>& tokens) {
+  if (tokens.size() == 2 && tokens[1] == "on") {
+    show_backend_config = true;
+  } else if (tokens.size() == 2 && tokens[1] == "off") {
+    show_backend_config = false;
+  } else if (tokens.size() != 1) {
+    std::cerr << "(Illegal backend_config value. Use either 'on' or 'off'.)"
+              << std::endl;
+  }
+  std::cout << "Backend configuration display "
+            << (show_backend_config ? "ON" : "OFF") << std::endl;
+}
+
+// List all computations in the module.
+void DoListComputationsCommand(const HloModule& module,
+                               const std::vector<string>& tokens) {
+  if (tokens.size() > 2) {
+    std::cout << R"(Illegal syntax; "list computations" takes no arguments.)"
+              << std::endl;
+    return;
+  }
+  if (module.entry_computation() != nullptr) {
+    std::cout << "Entry computation:" << std::endl;
+    std::cout << " " << module.entry_computation()->name() << std::endl
+              << std::endl;
+  }
+  std::cout << "Subcomputations:" << std::endl;
+  for (const auto& computation : module.computations()) {
+    if (computation == module.entry_computation()) {
+      continue;
+    }
+    std::cout << " " << computation->name() << std::endl;
+  }
+}
+
+// List all instructions matching a pattern. `tokens` is either
+// {"list", pattern} or {"list", type, pattern}, where type is one of "name",
+// "op_name" or "op_type" and defaults to "name".
+void DoListCommand(const HloModule& module, const std::vector<string>& tokens) {
+  const auto print_usage = [] {
+    std::cout << "Illegal list query syntax. Use "
+              << R"("list [name|op_name|op_type] pattern".)" << std::endl;
+  };
+
+  string pattern = "";
+  string type = "name";
+  if (tokens.size() == 2) {
+    pattern = tokens[1];
+  } else if (tokens.size() == 3) {
+    type = tokens[1];
+    pattern = tokens[2];
+  } else {
+    print_usage();
+    return;
+  }
+  // An unknown type would otherwise just yield an empty result list.
+  if (type != "name" && type != "op_name" && type != "op_type") {
+    print_usage();
+    return;
+  }
+
+  std::cout << "Query results:" << std::endl;
+  for (const auto& computation : module.computations()) {
+    for (const auto& instr : computation->instructions()) {
+      if ((type == "name" && instr->name().find(pattern) != string::npos) ||
+          (type == "op_name" &&
+           instr->metadata().op_name().find(pattern) != string::npos) ||
+          (type == "op_type" &&
+           instr->metadata().op_type().find(pattern) != string::npos)) {
+        std::cout << " " << instr->name();
+        std::cout << ", op_name '" << instr->metadata().op_name() << "'";
+        std::cout << ", op_type '" << instr->metadata().op_type() << "'";
+        std::cout << std::endl;
+      }
+    }
+  }
+}
+
+// Print info about an instruction or computation. If the name matches both,
+// the computation is described.
+void DoInfoCommand(const HloModule& module, const std::vector<string>& tokens) {
+  if (tokens.size() != 2) {
+    std::cerr << "Illegal info query syntax. Use "
+              << R"("info name".)" << std::endl;
+    return;
+  }
+  string node_name = tokens[1];
+
+  const HloInstruction* instr = FindInstruction(module, node_name);
+  const HloComputation* comp = FindComputation(module, node_name);
+  if (!instr && !comp) {
+    std::cerr << "Couldn't find HloInstruction or HloComputation named "
+              << node_name << std::endl;
+    return;
+  }
+
+  if (comp != nullptr) {
+    std::cout << "HloComputation " << comp->name() << std::endl;
+    if (comp->IsFusionComputation()) {
+      std::cout << " Fusion instruction: " << comp->FusionInstruction()->name()
+                << std::endl;
+    }
+    std::cout << " Parameters:" << std::endl;
+    for (const auto& param : comp->parameter_instructions()) {
+      std::cout << " " << param->name() << " ("
+                << ShapeUtil::HumanStringWithLayout(param->shape()) << ")"
+                << std::endl;
+    }
+    HloInstruction* root = comp->root_instruction();
+    std::cout << " Root instruction: " << root->name() << " ("
+              << ShapeUtil::HumanStringWithLayout(root->shape()) << ")"
+              << std::endl;
+
+    auto embedded_computations = comp->MakeEmbeddedComputationsList();
+    std::cout << " " << embedded_computations.size() << " embedded computation"
+              << (embedded_computations.size() != 1 ? "s" : "")
+              << (!embedded_computations.empty() ? ":" : ".") << std::endl;
+    for (const HloComputation* c : embedded_computations) {
+      std::cout << " " << c->name() << std::endl;
+    }
+
+    // Find which computations reference comp as an embedded computation.
+    std::vector<const HloComputation*> users;
+    for (const HloComputation* c : module.computations()) {
+      if (absl::c_linear_search(c->MakeEmbeddedComputationsList(), comp)) {
+        users.push_back(c);
+      }
+    }
+    std::cout << " Used by " << users.size() << " computation"
+              << (users.size() != 1 ? "s" : "") << (!users.empty() ? ":" : ".")
+              << std::endl;
+    for (const HloComputation* c : users) {
+      std::cout << " " << c->name() << std::endl;
+    }
+  } else {
+    std::cout << "HloInstruction " << instr->name() << std::endl;
+    std::cout << " Parent computation: " << instr->parent()->name()
+              << std::endl;
+    std::cout << " Opcode: " << HloOpcodeString(instr->opcode()) << std::endl;
+    std::cout << " Shape: " << ShapeUtil::HumanStringWithLayout(instr->shape())
+              << std::endl;
+    std::cout << " Metadata:" << std::endl;
+    if (!instr->metadata().op_name().empty()) {
+      std::cout << " Name: " << instr->metadata().op_name() << std::endl;
+    }
+    if (!instr->metadata().op_type().empty()) {
+      std::cout << " Type: " << instr->metadata().op_type() << std::endl;
+    }
+    if (!instr->raw_backend_config_string().empty()) {
+      std::cout << " Backend configuration: "
+                << instr->raw_backend_config_string() << std::endl;
+    }
+    if (instr->opcode() == HloOpcode::kFusion) {
+      std::cout << " Fusion kind: " << xla::ToString(instr->fusion_kind())
+                << std::endl;
+      std::cout << " Fusion computation: "
+                << instr->fused_instructions_computation()->name() << std::endl;
+      std::cout << " Fused computation root: "
+                << instr->fused_expression_root()->name() << std::endl;
+    }
+    std::cout << " Operands:" << std::endl;
+    for (HloInstruction* operand : instr->operands()) {
+      std::cout << " " << operand->name() << " ("
+                << ShapeUtil::HumanStringWithLayout(operand->shape()) << ")"
+                << std::endl;
+    }
+    std::cout << " Users:" << std::endl;
+    for (HloInstruction* user : instr->users()) {
+      std::cout << " " << user->name() << std::endl;
+    }
+    if (instr->parent()->root_instruction() == instr) {
+      std::cout << " Root instruction of " << instr->parent()->name()
+                << std::endl;
+    }
+  }
+}
+
+// Handles "extract <instruction> [<height>]": prints a new module rooted at
+// the named instruction.
+void DoExtractCommand(const HloModule& module,
+                      absl::Span<const string> tokens) {
+  // tokens[1] is read below, so a bare "extract" has to be rejected as well.
+  if (tokens.size() < 2 || tokens.size() > 3) {
+    std::cerr << R"(Illegal input. Enter e.g. "extract %fusion.1 2")"
+              << std::endl;
+    return;
+  }
+
+  // Find the node with the given name.
+  string node_name = tokens[1];
+  HloInstruction* instr = FindInstruction(module, node_name);
+  if (!instr) {
+    std::cerr << "Couldn't find HloInstruction named " << node_name << "."
+              << std::endl;
+    return;
+  }
+
+  int64 height = -1;
+  if (tokens.size() == 3) {
+    if (!absl::SimpleAtoi(tokens[2], &height)) {
+      std::cerr << "Can't parse '" << tokens[2] << "' as an integer."
+                << std::endl;
+      return;
+    }
+  }
+
+  auto extracted_module = ExtractModule(instr, height);
+  std::cout << extracted_module->ToString(
+                   HloPrintOptions::ShortParsable().set_print_backend_config(
+                       show_backend_config))
+            << std::endl;
+}
+
+// Checks if there is a use-def path from `from` to `to`.
+bool ExistsPathFromTo(const HloInstruction* from, const HloInstruction* to) {
+  std::unordered_set<const HloInstruction*> visited;
+  std::vector<const HloInstruction*> to_visit = {from};
+  while (!to_visit.empty()) {
+    auto* n = to_visit.back();
+    if (n == to) {
+      return true;
+    }
+    to_visit.pop_back();
+    // A node can be pushed once per operand edge; expand it only once.
+    if (!visited.insert(n).second) {
+      continue;
+    }
+    for (auto* user : n->users()) {
+      if (!visited.count(user)) {
+        to_visit.push_back(user);
+      }
+    }
+  }
+  return false;
+}
+
+// Prints `handle` and, if it looks like a URL, tries to open it in a browser.
+void DisplayGraphHandle(const Options& opts, const string& handle) {
+  std::cout << handle << std::endl;
+
+  // If it is a url, try to open it up in the user's browser too.
+  if (strings::StartsWithIgnoreCase(handle, "http://") ||
+      strings::StartsWithIgnoreCase(handle, "https://") ||
+      strings::StartsWithIgnoreCase(handle, "file://")) {
+    const char* browser_bin = opts.browser.empty() ? "/usr/bin/sensible-browser"
+                                                   : opts.browser.c_str();
+    tensorflow::SubProcess p;
+    p.SetProgram(browser_bin, {browser_bin, handle});
+    if (!p.Start()) {
+      std::cerr << "Unable to start " << browser_bin << std::endl;
+    }
+  } else if (handle.empty()) {
+    std::cerr << "Unable to render graph, perhaps due to graphviz server "
+                 "timeout. Run with --logtostderr to see."
+              << std::endl;
+  } else {
+    std::cerr << "\nExpected a URL, but got strange graph result (dumped "
+                 "above). If this isn't what you expected, maybe file a bug?"
+              << std::endl;
+  }
+}
+
+// Handles "allpaths <instruction> <instruction> [<n>]".
+void DoAllPathsCommand(const Options& opts, const HloModule& module,
+                       const std::vector<string>& tokens) {
+  // tokens[1] and tokens[2] are read below, so both names are required.
+  if (tokens.size() < 3 || tokens.size() > 4) {
+    std::cerr << R"(Illegal input. Enter e.g. "allpaths %add.4 %subtract.2" or
+"allpaths add.4 subtract.2 42".)"
+              << std::endl;
+    return;
+  }
+
+  int64 max_nodes = kDefaultMaxNumNodesInAllPaths;
+  if (tokens.size() == 4 && !absl::SimpleAtoi(tokens[3], &max_nodes)) {
+    std::cerr << "Can't parse '" << tokens[3] << "' as an integer."
+              << std::endl;
+    return;
+  }
+
+  const HloInstruction* n1 = FindInstruction(module, tokens[1]);
+  if (!n1) {
+    std::cerr << "Couldn't find HloInstruction named " << tokens[1]
+              << std::endl;
+    return;
+  }
+  const HloInstruction* n2 = FindInstruction(module, tokens[2]);
+  if (!n2) {
+    std::cerr << "Couldn't find HloInstruction named " << tokens[2]
+              << std::endl;
+    return;
+  }
+
+  // Is there a path from n1 to n2, or vice versa?
+  const HloInstruction* from;
+  const HloInstruction* to;
+  if (ExistsPathFromTo(n1, n2)) {
+    from = n1;
+    to = n2;
+  } else if (ExistsPathFromTo(n2, n1)) {
+    from = n2;
+    to = n1;
+  } else {
+    std::cerr << "No path from/to " << tokens[1] << " to/from " << tokens[2]
+              << std::endl;
+    return;
+  }
+  DisplayGraphHandle(opts, hlo_graph_dumper::DumpAllPathsFromTo(
+      *from, *to, max_nodes, /*show_backend_config=*/show_backend_config));
+}
+
+// Plot a given instruction neighborhood or computation with graphviz.
+void DoPlotCommand(const Options& opts, const HloModule& module,
+                   const std::vector<string>& tokens) {
+  if (tokens.size() > 2) {
+    std::cerr << R"(Illegal input. Enter e.g. "%fusion.1 42" or "%fusion.1".)"
+              << std::endl;
+    return;
+  }
+
+  string node_name = tokens[0];
+
+  // Find the node with the given name.
+  const HloInstruction* instr = FindInstruction(module, node_name);
+  const HloComputation* comp = FindComputation(module, node_name);
+  if (!instr && !comp) {
+    std::cerr << "Couldn't find HloInstruction or HloComputation named "
+              << node_name << "." << std::endl;
+    return;
+  }
+
+  uint64 graph_width = kDefaultWidth;
+  if (tokens.size() == 2) {
+    if (comp) {
+      std::cerr << "Can only use graph-size parameter with instructions, but "
+                << node_name << " is a computation." << std::endl;
+      return;
+    }
+    if (!absl::SimpleAtoi(tokens[1], &graph_width)) {
+      std::cerr << "Can't parse '" << tokens[1] << "' as an integer."
+                << std::endl;
+      return;
+    }
+  }
+
+  // Generate the graph and print the resulting string, which should be a
+  // graphviz url.
+  if (comp) {
+    DisplayGraphHandle(opts, hlo_graph_dumper::DumpGraph(
+        *comp, "", comp->parent()->config().debug_options(), nullptr,
+        /*show_backend_config=*/show_backend_config));
+  } else {
+    DisplayGraphHandle(opts, hlo_graph_dumper::DumpNeighborhoodAround(
+        *instr, graph_width, /*show_backend_config=*/show_backend_config));
+  }
+}
+
+// Run the main event loop, reading user commands and processing them.
+void InteractiveDumpGraphs(const Options& opts, const HloModule& module) {
+  // This is an interactive tool, but some may use `extract` in non-tty
+  // environment anyway. Give them a clean hlo dump.
+  if (isatty(fileno(stdin))) {
+    std::cout << "\n\nLoaded module " << module.name() << "." << std::endl;
+    DoHelpCommand();
+  }
+  for (string line; ReadLine("\ncommand: ", &line);) {
+    if (line.empty()) {
+      std::cout << R"(Enter e.g. "fusion.1 3" or "add.8".)" << std::endl
+                << R"(Enter "help" for help; ^D, "quit", or "exit" to exit.)"
+                << std::endl;
+      continue;
+    }
+    std::vector<string> tokens = strings::Split(line, ' ');
+    if (tokens[0] == "quit" || tokens[0] == "exit") {
+      break;
+    } else if (tokens[0] == "help") {
+      DoHelpCommand();
+    } else if (tokens[0] == "backend_config") {
+      DoBackendConfigCommand(tokens);
+    } else if (tokens[0] == "list") {
+      if (tokens.size() > 1 && tokens[1] == "computations") {
+        DoListComputationsCommand(module, tokens);
+      } else {
+        DoListCommand(module, tokens);
+      }
+    } else if (tokens[0] == "info") {
+      DoInfoCommand(module, tokens);
+    } else if (tokens[0] == "extract") {
+      DoExtractCommand(module, tokens);
+    } else if (tokens[0] == "allpaths") {
+      DoAllPathsCommand(opts, module, tokens);
+    } else {
+      // Anything else is taken to be the name of a node to plot.
+      DoPlotCommand(opts, module, tokens);
+    }
+  }
+}
+
+// Dies unless exactly one of --hlo_proto, --hlo_snapshot, --hlo_text is set.
+void CheckFlags(const Options& opts) {
+  std::vector<string> nonempty_proto_flags;
+  if (!opts.hlo_proto.empty()) {
+    nonempty_proto_flags.push_back("--hlo_proto");
+  }
+  if (!opts.hlo_snapshot.empty()) {
+    nonempty_proto_flags.push_back("--hlo_snapshot");
+  }
+  if (!opts.hlo_text.empty()) {
+    nonempty_proto_flags.push_back("--hlo_text");
+  }
+  if (nonempty_proto_flags.empty()) {
+    // Nothing was given, so the candidates have to be spelled out here;
+    // joining the (empty) list of given flags would print nothing.
+    LOG(FATAL) << "Need one of the following options: "
+               << "--hlo_proto, --hlo_snapshot, --hlo_text";
+  }
+  if (nonempty_proto_flags.size() > 1) {
+    LOG(FATAL) << "Can only specify one of "
+               << absl::StrJoin(nonempty_proto_flags, ", ");
+  }
+}
+
+// Loads the module named by the flags, optionally compiles it for
+// --platform, and then runs the interactive loop on it.
+void RealMain(const Options& opts) {
+  if (!isatty(fileno(stdin))) {
+    LOG(ERROR) << "\n\n*****************************************\n"
+               << "This is an interactive tool, but stdin is not a tty.\n"
+               << "*****************************************\n\n";
+  }
+
+  CheckFlags(opts);
+
+  std::unique_ptr<HloModule> module;
+  if (!opts.hlo_snapshot.empty()) {
+    HloSnapshot snapshot;
+    TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(),
+                                            opts.hlo_snapshot, &snapshot))
+        << "Can't open, read, or parse HloSnapshot proto at "
+        << opts.hlo_snapshot;
+    auto config =
+        HloModule::CreateModuleConfigFromProto(snapshot.hlo().hlo_module(),
+                                               xla::GetDebugOptionsFromFlags())
+            .ValueOrDie();
+    module = HloModule::CreateFromProto(snapshot.hlo().hlo_module(), config)
+                 .ValueOrDie();
+  } else if (!opts.hlo_proto.empty()) {
+    module = HloRunner::ReadModuleFromBinaryProtoFile(
+                 opts.hlo_proto, xla::GetDebugOptionsFromFlags())
+                 .ValueOrDie();
+  } else if (!opts.hlo_text.empty()) {
+    module = HloRunner::ReadModuleFromHloTextFile(
+                 opts.hlo_text, xla::GetDebugOptionsFromFlags())
+                 .ValueOrDie();
+  }
+
+  // If a platform was specified, compile the module for that platform.
+  if (!opts.platform.empty()) {
+    se::Platform* platform =
+        PlatformUtil::GetPlatform(opts.platform).ValueOrDie();
+    LOG(INFO) << "Compiling module for " << platform->Name();
+
+    se::StreamExecutor* executor =
+        platform->ExecutorForDevice(/*ordinal=*/0).ValueOrDie();
+    auto compiler = Compiler::GetForPlatform(platform).ValueOrDie();
+    module = compiler
+                 ->RunHloPasses(std::move(module), executor,
+                                /*device_allocator=*/nullptr)
+                 .ValueOrDie();
+    auto executable = compiler
+                          ->RunBackend(std::move(module), executor,
+                                       /*device_allocator=*/nullptr)
+                          .ValueOrDie();
+    InteractiveDumpGraphs(opts, executable->module());
+  } else {
+    InteractiveDumpGraphs(opts, *module);
+  }
+}
+
+}  // namespace
+}  // namespace tools
+}  // namespace xla
+
+int main(int argc, char** argv) {
+  xla::tools::Options opts;
+  opts.browser = "/usr/bin/sensible-browser";
+  bool need_help = false;
+  const std::vector<tensorflow::Flag> flag_list = {
+      tensorflow::Flag("hlo_snapshot", &opts.hlo_snapshot,
+                       "HloSnapshot proto to interactively dump to graphviz"),
+      tensorflow::Flag("hlo_proto", &opts.hlo_proto,
+                       "XLA hlo proto to interactively dump to graphviz"),
+      tensorflow::Flag("hlo_text", &opts.hlo_text,
+                       "XLA hlo text to interactively dump to graphviz"),
+      tensorflow::Flag("platform", &opts.platform,
+                       "Platform to compile for: CPU, CUDA, etc"),
+      tensorflow::Flag("browser", &opts.browser,
+                       "Path to web browser used to display produced graphs."),
+      tensorflow::Flag("help", &need_help,
+                       "Prints this help message"),
+  };
+  // Put the tool description in front of the generated per-flag help.
+  xla::string usage =
+      absl::StrCat(xla::tools::kUsage, "\n",
+                   tensorflow::Flags::Usage(argv[0], flag_list));
+  bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list);
+  tensorflow::port::InitMain(argv[0], &argc, &argv);
+  if (argc != 1 || !parse_ok || need_help) {
+    LOG(QFATAL) << usage;
+  }
+  xla::tools::RealMain(opts);
+  return 0;
+}