Merge HLO module events of the same step.
PiperOrigin-RevId: 304510304 Change-Id: I866405bfdd28d573355fa455c8434553640fc3c8
This commit is contained in:
parent
ec684ae119
commit
9ba3ff2df6
@ -44,10 +44,8 @@ class DerivedXLineBuilder {
|
||||
public:
|
||||
DerivedXLineBuilder(XPlaneBuilder* plane, int64 line_id,
|
||||
absl::string_view name, int64 timestamp_ns,
|
||||
std::vector<DerivedXLineBuilder*> dependent_lines,
|
||||
bool try_expand)
|
||||
: line_(plane->GetOrCreateLine(line_id)),
|
||||
try_expand_(try_expand) {
|
||||
std::vector<DerivedXLineBuilder*> dependent_lines)
|
||||
: line_(plane->GetOrCreateLine(line_id)) {
|
||||
line_.SetName(name);
|
||||
line_.SetTimestampNs(timestamp_ns);
|
||||
dependent_lines_ = std::move(dependent_lines);
|
||||
@ -71,12 +69,12 @@ class DerivedXLineBuilder {
|
||||
}
|
||||
|
||||
private:
|
||||
// If the last event of the given level has the same metadata and try_expand_
|
||||
// is true, expands it to include the time until the given event's (offset_ps
|
||||
// + duration_ps). Otherwise, adds a new event and clears last_event_by_level_
|
||||
// for the levels below the given level and all levels of the dependent lines.
|
||||
// Clearing last_event_by_level_ prevents a nested event from growing larger
|
||||
// than the parent event(s).
|
||||
// If the last event of the given level has the same metadata, expands it to
|
||||
// include the time until the given event's (offset_ps + duration_ps).
|
||||
// Otherwise, adds a new event and clears last_event_by_level_ for the levels
|
||||
// below the given level and all levels of the dependent lines. Clearing
|
||||
// last_event_by_level_ prevents a nested event from growing larger than the
|
||||
// parent event(s).
|
||||
void ExpandOrAddLevelEvent(const XEvent& event, int level) {
|
||||
int64 offset_ps = event.offset_ps();
|
||||
int64 duration_ps = event.duration_ps();
|
||||
@ -84,8 +82,7 @@ class DerivedXLineBuilder {
|
||||
// If last_event is not nullptr, its offset must be less than or equal to
|
||||
// the given event's offset.
|
||||
DCHECK(!last_event || last_event->OffsetPs() <= offset_ps);
|
||||
if (try_expand_ && last_event &&
|
||||
last_event->MetadataId() == event.metadata_id()) {
|
||||
if (last_event && last_event->MetadataId() == event.metadata_id()) {
|
||||
// If last_event is not nullptr and metadata is same, merge the given
|
||||
// event into last_event.
|
||||
last_event->SetDurationPs((offset_ps + duration_ps) -
|
||||
@ -108,7 +105,6 @@ class DerivedXLineBuilder {
|
||||
XLineBuilder line_;
|
||||
absl::flat_hash_map<int, absl::optional<XEventBuilder>> last_event_by_level_;
|
||||
std::vector<DerivedXLineBuilder*> dependent_lines_;
|
||||
bool try_expand_;
|
||||
};
|
||||
|
||||
const absl::string_view kDerivedLineSteps = "Steps";
|
||||
@ -185,19 +181,18 @@ void DeriveEventsFromAnnotations(const SymbolResolver& symbol_resolver,
|
||||
|
||||
XPlaneBuilder plane(device_trace);
|
||||
DerivedXLineBuilder tf_ops(&plane, kThreadIdTfOp, kDerivedLineTensorFlowOps,
|
||||
start_timestamp_ns, {}, /*try_expand=*/true);
|
||||
DerivedXLineBuilder tf_name_scope(
|
||||
&plane, kThreadIdTfNameScope, kDerivedLineTensorFlowNameScope,
|
||||
start_timestamp_ns, {&tf_ops}, /*try_expand=*/true);
|
||||
start_timestamp_ns, {});
|
||||
DerivedXLineBuilder tf_name_scope(&plane, kThreadIdTfNameScope,
|
||||
kDerivedLineTensorFlowNameScope,
|
||||
start_timestamp_ns, {&tf_ops});
|
||||
DerivedXLineBuilder hlo_ops(&plane, kThreadIdHloOp, kDerivedLineXlaOps,
|
||||
start_timestamp_ns, {}, /*try_expand=*/true);
|
||||
DerivedXLineBuilder hlo_modules(
|
||||
&plane, kThreadIdHloModule, kDerivedLineXlaModules, start_timestamp_ns,
|
||||
{&tf_ops, &tf_name_scope, &hlo_ops}, /*try_expand=*/false);
|
||||
start_timestamp_ns, {});
|
||||
DerivedXLineBuilder hlo_modules(&plane, kThreadIdHloModule,
|
||||
kDerivedLineXlaModules, start_timestamp_ns,
|
||||
{&tf_ops, &tf_name_scope, &hlo_ops});
|
||||
DerivedXLineBuilder steps(&plane, kThreadIdStepInfo, kDerivedLineSteps,
|
||||
start_timestamp_ns,
|
||||
{&tf_ops, &tf_name_scope, &hlo_ops},
|
||||
/*try_expand=*/true);
|
||||
{&tf_ops, &tf_name_scope, &hlo_ops, &hlo_modules});
|
||||
int64 group_id_stat_metadata_id =
|
||||
plane.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kGroupId))->id();
|
||||
int64 step_name_stat_metadata_id =
|
||||
|
@ -36,7 +36,7 @@ TEST(DerivedTimelineTest, EmptySpaceTest) {
|
||||
EXPECT_EQ(space.planes_size(), 0);
|
||||
}
|
||||
|
||||
// Checks that HLO module events are not expanded.
|
||||
// Checks that HLO module events are expanded.
|
||||
TEST(DerivedTimelineTest, HloModuleNameTest) {
|
||||
const absl::string_view kHloModuleName = "hlo_module";
|
||||
const absl::string_view kKernelDetails = "kernel_details";
|
||||
@ -69,7 +69,7 @@ TEST(DerivedTimelineTest, HloModuleNameTest) {
|
||||
plane_visitor.ForEachLine([&](const XLineVisitor& line_visitor) {
|
||||
if (line_visitor.Id() == 0) return;
|
||||
EXPECT_EQ(line_visitor.Id(), kThreadIdHloModule);
|
||||
EXPECT_EQ(line_visitor.NumEvents(), 2);
|
||||
EXPECT_EQ(line_visitor.NumEvents(), 1);
|
||||
line_visitor.ForEachEvent([&](const XEventVisitor& event_visitor) {
|
||||
EXPECT_EQ(event_visitor.Name(), kHloModuleName);
|
||||
});
|
||||
|
Loading…
x
Reference in New Issue
Block a user