[XLA] Fix hlo_graph_dumper: don't crash if the computation has a constant root instruction.
PiperOrigin-RevId: 180285687
This commit is contained in:
parent
ade8058c51
commit
711b10c280
@ -2072,6 +2072,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:window_util",
|
||||
"//tensorflow/compiler/xla:xla_proto",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:regexp_internal",
|
||||
],
|
||||
alwayslink = 1,
|
||||
|
@ -34,6 +34,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/window_util.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/gtl/map_util.h"
|
||||
#include "tensorflow/core/lib/gtl/optional.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/lib/strings/numbers.h"
|
||||
@ -508,8 +509,17 @@ stylesheet="
|
||||
|
||||
// The "to_node" value may be a NULL, indicating that this points to the
|
||||
// "root" tag rather than a normal node.
|
||||
int64 from_node_id = node_ids_.at(from_node);
|
||||
int64 to_node_id = to_node ? node_ids_.at(to_node) : root_node_id_;
|
||||
int64 from_node_id =
|
||||
tensorflow::gtl::FindWithDefault(node_ids_, from_node, -1);
|
||||
if (from_node_id == -1) {
|
||||
LOG(FATAL) << from_node->name() << " was added to edges but not to nodes";
|
||||
}
|
||||
int64 to_node_id =
|
||||
to_node ? tensorflow::gtl::FindWithDefault(node_ids_, to_node, -1)
|
||||
: root_node_id_;
|
||||
if (to_node != nullptr && to_node_id == -1) {
|
||||
LOG(FATAL) << to_node->name() << " was added to edges but not to nodes";
|
||||
}
|
||||
|
||||
add_hover_css_rule("node", from_node_id, kBlue);
|
||||
add_hover_css_rule("node", to_node_id, kRed);
|
||||
@ -653,12 +663,15 @@ string HloDotDumper::DumpComputation(const HloComputation* comp) {
|
||||
|
||||
string HloDotDumper::DumpRootTag() {
|
||||
const HloInstruction* from = GetNodeForEdge(computation_->root_instruction());
|
||||
auto from_id = InstructionId(from);
|
||||
|
||||
if (!filter_.Show(from)) {
|
||||
// We didn't display constants as separate nodes; so if the root is a
|
||||
// constant, we don't add root tag or edge for it.
|
||||
if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant) {
|
||||
return "";
|
||||
}
|
||||
|
||||
auto from_id = InstructionId(from);
|
||||
|
||||
// The ID of the root computation is otherwise unused, so it makes a good ID
|
||||
// to use for the root-tag node. However, the edge_ids_ map requires a
|
||||
// HloInstruction* pointer for the 'to' value, so we use a NULL value there
|
||||
|
@ -117,5 +117,18 @@ TEST(HloGraphDumperTest, NestedFusion) {
|
||||
HasSubstr(inner_sum->name()));
|
||||
}
|
||||
|
||||
TEST(HloGraphDumperTest, Constant) {
|
||||
HloComputation::Builder b("b");
|
||||
auto instruction = b.AddInstruction(
|
||||
HloInstruction::CreateConstant(Literal::CreateR0<float>(-42)));
|
||||
instruction->set_name("i_am_a_constant_root_instruction");
|
||||
HloModule m(TestName());
|
||||
HloComputation* root_computation = m.AddEntryComputation(b.Build());
|
||||
string graph = hlo_graph_dumper::DumpGraph(
|
||||
*root_computation, /*label=*/"an_empty_graph", DebugOptions());
|
||||
EXPECT_THAT(graph, HasSubstr("an_empty_graph"));
|
||||
EXPECT_THAT(graph, Not(HasSubstr("i_am_a_constant_root_instruction")));
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
} // namespace xla
|
||||
|
Loading…
Reference in New Issue
Block a user