Use the event name as part of the step name for the explicit root events.
PiperOrigin-RevId: 318634056 Change-Id: I2860534f4ebe62e732306a39a6a8fd57f6366b16
This commit is contained in:
parent
4aa879aab5
commit
7585042327
@ -91,6 +91,9 @@ void ConvertXPlaneToTraceEvents(uint32 device_id, const XPlaneVisitor& xplane,
|
||||
xevent.ForEachStat([&](const XStatVisitor& stat) {
|
||||
if (stat.ValueCase() == XStat::VALUE_NOT_SET) return;
|
||||
if (IsInternalStat(stat.Type())) return;
|
||||
if (stat.Type() == StatType::kStepName) {
|
||||
event->set_name(stat.ToString());
|
||||
}
|
||||
args[std::string(stat.Name())] = stat.ToString();
|
||||
});
|
||||
});
|
||||
|
@ -139,12 +139,25 @@ bool HasFunctionRun(EventNode* event_node) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool IsImplicitRootEvent(const XEventVisitor& event) {
|
||||
static const auto* const kImplicitRootEvents = new absl::flat_hash_set<int64>{
|
||||
HostEventType::kFunctionRun, HostEventType::kSessionRun,
|
||||
HostEventType::kRunGraph, HostEventType::kExecutorStateProcess};
|
||||
return event.Type().has_value() &&
|
||||
kImplicitRootEvents->contains(*event.Type());
|
||||
}
|
||||
|
||||
void ProcessRootEvent(int64 group_id, EventNode* root_event,
|
||||
EventGroupNameMap* event_group_name_map) {
|
||||
root_event->PropagateGroupId(group_id);
|
||||
std::string group_name = root_event->GetGroupName();
|
||||
// TODO(jihochoi): change event name instead.
|
||||
root_event->AddStepName(group_name);
|
||||
if (!IsImplicitRootEvent(root_event->GetEventVisitor())) {
|
||||
// Add the `step_name` stat for the user-defined root events only. When an
|
||||
// XEvent is converted to a trace event, the trace event name is set to the
|
||||
// `step_name` stat's value if present.
|
||||
root_event->AddStepName(group_name);
|
||||
}
|
||||
event_group_name_map->emplace(group_id, std::move(group_name));
|
||||
}
|
||||
|
||||
@ -336,6 +349,8 @@ std::string EventNode::GetGroupName() const {
|
||||
if (absl::optional<XStatVisitor> stat =
|
||||
GetContextStat(StatType::kGraphType)) {
|
||||
absl::StrAppend(&name, stat->StrOrRefValue(), " ");
|
||||
} else if (!(IsImplicitRootEvent(visitor_))) {
|
||||
absl::StrAppend(&name, GetEventVisitor().Name(), " ");
|
||||
}
|
||||
int64 step_num = group_id_.value_or(0);
|
||||
if (absl::optional<XStatVisitor> stat = GetContextStat(StatType::kIterNum)) {
|
||||
|
@ -40,8 +40,9 @@ TEST(GroupEventsTest, GroupGpuTraceTest) {
|
||||
host_plane_builder.ReserveLines(2);
|
||||
|
||||
auto main_thread = host_plane_builder.GetOrCreateLine(0);
|
||||
CreateXEvent(&host_plane_builder, &main_thread, HostEventType::kTraceContext,
|
||||
0, 100, {{StatType::kStepNum, kStepNum}});
|
||||
CreateXEvent(
|
||||
&host_plane_builder, &main_thread, HostEventType::kTraceContext, 0, 100,
|
||||
{{StatType::kGraphType, "train"}, {StatType::kStepNum, kStepNum}});
|
||||
CreateXEvent(&host_plane_builder, &main_thread, HostEventType::kFunctionRun,
|
||||
10, 90, {{StatType::kStepId, kStepId}});
|
||||
|
||||
@ -68,7 +69,7 @@ TEST(GroupEventsTest, GroupGpuTraceTest) {
|
||||
device_plane->lines(0).events(0).stats(1)),
|
||||
StatType::kGroupId);
|
||||
EXPECT_EQ(event_group_name_map.size(), 1);
|
||||
EXPECT_EQ(event_group_name_map[0], "123");
|
||||
EXPECT_EQ(event_group_name_map[0], "train 123");
|
||||
}
|
||||
|
||||
TEST(GroupEventsTest, GroupTensorFlowLoopTest) {
|
||||
|
Loading…
Reference in New Issue
Block a user