Merge HLO module events of the same step.
PiperOrigin-RevId: 304510304 Change-Id: I866405bfdd28d573355fa455c8434553640fc3c8
This commit is contained in:
parent
ec684ae119
commit
9ba3ff2df6
@ -44,10 +44,8 @@ class DerivedXLineBuilder {
|
|||||||
public:
|
public:
|
||||||
DerivedXLineBuilder(XPlaneBuilder* plane, int64 line_id,
|
DerivedXLineBuilder(XPlaneBuilder* plane, int64 line_id,
|
||||||
absl::string_view name, int64 timestamp_ns,
|
absl::string_view name, int64 timestamp_ns,
|
||||||
std::vector<DerivedXLineBuilder*> dependent_lines,
|
std::vector<DerivedXLineBuilder*> dependent_lines)
|
||||||
bool try_expand)
|
: line_(plane->GetOrCreateLine(line_id)) {
|
||||||
: line_(plane->GetOrCreateLine(line_id)),
|
|
||||||
try_expand_(try_expand) {
|
|
||||||
line_.SetName(name);
|
line_.SetName(name);
|
||||||
line_.SetTimestampNs(timestamp_ns);
|
line_.SetTimestampNs(timestamp_ns);
|
||||||
dependent_lines_ = std::move(dependent_lines);
|
dependent_lines_ = std::move(dependent_lines);
|
||||||
@ -71,12 +69,12 @@ class DerivedXLineBuilder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// If the last event of the given level has the same metadata and try_expand_
|
// If the last event of the given level has the same metadata, expands it to
|
||||||
// is true, expands it to include the time until the given event's (offset_ps
|
// include the time until the given event's (offset_ps + duration_ps).
|
||||||
// + duration_ps). Otherwise, adds a new event and clears last_event_by_level_
|
// Otherwise, adds a new event and clears last_event_by_level_ for the levels
|
||||||
// for the levels below the given level and all levels of the dependent lines.
|
// below the given level and all levels of the dependent lines. Clearing
|
||||||
// Clearing last_event_by_level_ prevents a nested event from growing larger
|
// last_event_by_level_ prevents a nested event from growing larger than the
|
||||||
// than the parent event(s).
|
// parent event(s).
|
||||||
void ExpandOrAddLevelEvent(const XEvent& event, int level) {
|
void ExpandOrAddLevelEvent(const XEvent& event, int level) {
|
||||||
int64 offset_ps = event.offset_ps();
|
int64 offset_ps = event.offset_ps();
|
||||||
int64 duration_ps = event.duration_ps();
|
int64 duration_ps = event.duration_ps();
|
||||||
@ -84,8 +82,7 @@ class DerivedXLineBuilder {
|
|||||||
// If last_event is not nullptr, its offset must be less than or equal to
|
// If last_event is not nullptr, its offset must be less than or equal to
|
||||||
// the given event's offset.
|
// the given event's offset.
|
||||||
DCHECK(!last_event || last_event->OffsetPs() <= offset_ps);
|
DCHECK(!last_event || last_event->OffsetPs() <= offset_ps);
|
||||||
if (try_expand_ && last_event &&
|
if (last_event && last_event->MetadataId() == event.metadata_id()) {
|
||||||
last_event->MetadataId() == event.metadata_id()) {
|
|
||||||
// If last_event is not nullptr and metadata is same, merge the given
|
// If last_event is not nullptr and metadata is same, merge the given
|
||||||
// event into last_event.
|
// event into last_event.
|
||||||
last_event->SetDurationPs((offset_ps + duration_ps) -
|
last_event->SetDurationPs((offset_ps + duration_ps) -
|
||||||
@ -108,7 +105,6 @@ class DerivedXLineBuilder {
|
|||||||
XLineBuilder line_;
|
XLineBuilder line_;
|
||||||
absl::flat_hash_map<int, absl::optional<XEventBuilder>> last_event_by_level_;
|
absl::flat_hash_map<int, absl::optional<XEventBuilder>> last_event_by_level_;
|
||||||
std::vector<DerivedXLineBuilder*> dependent_lines_;
|
std::vector<DerivedXLineBuilder*> dependent_lines_;
|
||||||
bool try_expand_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const absl::string_view kDerivedLineSteps = "Steps";
|
const absl::string_view kDerivedLineSteps = "Steps";
|
||||||
@ -185,19 +181,18 @@ void DeriveEventsFromAnnotations(const SymbolResolver& symbol_resolver,
|
|||||||
|
|
||||||
XPlaneBuilder plane(device_trace);
|
XPlaneBuilder plane(device_trace);
|
||||||
DerivedXLineBuilder tf_ops(&plane, kThreadIdTfOp, kDerivedLineTensorFlowOps,
|
DerivedXLineBuilder tf_ops(&plane, kThreadIdTfOp, kDerivedLineTensorFlowOps,
|
||||||
start_timestamp_ns, {}, /*try_expand=*/true);
|
start_timestamp_ns, {});
|
||||||
DerivedXLineBuilder tf_name_scope(
|
DerivedXLineBuilder tf_name_scope(&plane, kThreadIdTfNameScope,
|
||||||
&plane, kThreadIdTfNameScope, kDerivedLineTensorFlowNameScope,
|
kDerivedLineTensorFlowNameScope,
|
||||||
start_timestamp_ns, {&tf_ops}, /*try_expand=*/true);
|
start_timestamp_ns, {&tf_ops});
|
||||||
DerivedXLineBuilder hlo_ops(&plane, kThreadIdHloOp, kDerivedLineXlaOps,
|
DerivedXLineBuilder hlo_ops(&plane, kThreadIdHloOp, kDerivedLineXlaOps,
|
||||||
start_timestamp_ns, {}, /*try_expand=*/true);
|
start_timestamp_ns, {});
|
||||||
DerivedXLineBuilder hlo_modules(
|
DerivedXLineBuilder hlo_modules(&plane, kThreadIdHloModule,
|
||||||
&plane, kThreadIdHloModule, kDerivedLineXlaModules, start_timestamp_ns,
|
kDerivedLineXlaModules, start_timestamp_ns,
|
||||||
{&tf_ops, &tf_name_scope, &hlo_ops}, /*try_expand=*/false);
|
{&tf_ops, &tf_name_scope, &hlo_ops});
|
||||||
DerivedXLineBuilder steps(&plane, kThreadIdStepInfo, kDerivedLineSteps,
|
DerivedXLineBuilder steps(&plane, kThreadIdStepInfo, kDerivedLineSteps,
|
||||||
start_timestamp_ns,
|
start_timestamp_ns,
|
||||||
{&tf_ops, &tf_name_scope, &hlo_ops},
|
{&tf_ops, &tf_name_scope, &hlo_ops, &hlo_modules});
|
||||||
/*try_expand=*/true);
|
|
||||||
int64 group_id_stat_metadata_id =
|
int64 group_id_stat_metadata_id =
|
||||||
plane.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kGroupId))->id();
|
plane.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kGroupId))->id();
|
||||||
int64 step_name_stat_metadata_id =
|
int64 step_name_stat_metadata_id =
|
||||||
|
@ -36,7 +36,7 @@ TEST(DerivedTimelineTest, EmptySpaceTest) {
|
|||||||
EXPECT_EQ(space.planes_size(), 0);
|
EXPECT_EQ(space.planes_size(), 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Checks that HLO module events are not expanded.
|
// Checks that HLO module events are expanded.
|
||||||
TEST(DerivedTimelineTest, HloModuleNameTest) {
|
TEST(DerivedTimelineTest, HloModuleNameTest) {
|
||||||
const absl::string_view kHloModuleName = "hlo_module";
|
const absl::string_view kHloModuleName = "hlo_module";
|
||||||
const absl::string_view kKernelDetails = "kernel_details";
|
const absl::string_view kKernelDetails = "kernel_details";
|
||||||
@ -69,7 +69,7 @@ TEST(DerivedTimelineTest, HloModuleNameTest) {
|
|||||||
plane_visitor.ForEachLine([&](const XLineVisitor& line_visitor) {
|
plane_visitor.ForEachLine([&](const XLineVisitor& line_visitor) {
|
||||||
if (line_visitor.Id() == 0) return;
|
if (line_visitor.Id() == 0) return;
|
||||||
EXPECT_EQ(line_visitor.Id(), kThreadIdHloModule);
|
EXPECT_EQ(line_visitor.Id(), kThreadIdHloModule);
|
||||||
EXPECT_EQ(line_visitor.NumEvents(), 2);
|
EXPECT_EQ(line_visitor.NumEvents(), 1);
|
||||||
line_visitor.ForEachEvent([&](const XEventVisitor& event_visitor) {
|
line_visitor.ForEachEvent([&](const XEventVisitor& event_visitor) {
|
||||||
EXPECT_EQ(event_visitor.Name(), kHloModuleName);
|
EXPECT_EQ(event_visitor.Name(), kHloModuleName);
|
||||||
});
|
});
|
||||||
|
Loading…
x
Reference in New Issue
Block a user