Merge HLO module events of the same step.

PiperOrigin-RevId: 304510304
Change-Id: I866405bfdd28d573355fa455c8434553640fc3c8
This commit is contained in:
Jiho Choi 2020-04-02 17:21:41 -07:00 committed by TensorFlower Gardener
parent ec684ae119
commit 9ba3ff2df6
2 changed files with 20 additions and 25 deletions

View File

@ -44,10 +44,8 @@ class DerivedXLineBuilder {
public:
DerivedXLineBuilder(XPlaneBuilder* plane, int64 line_id,
absl::string_view name, int64 timestamp_ns,
std::vector<DerivedXLineBuilder*> dependent_lines,
bool try_expand)
: line_(plane->GetOrCreateLine(line_id)),
try_expand_(try_expand) {
std::vector<DerivedXLineBuilder*> dependent_lines)
: line_(plane->GetOrCreateLine(line_id)) {
line_.SetName(name);
line_.SetTimestampNs(timestamp_ns);
dependent_lines_ = std::move(dependent_lines);
@ -71,12 +69,12 @@ class DerivedXLineBuilder {
}
private:
// If the last event of the given level has the same metadata and try_expand_
// is true, expands it to include the time until the given event's (offset_ps
// + duration_ps). Otherwise, adds a new event and clears last_event_by_level_
// for the levels below the given level and all levels of the dependent lines.
// Clearing last_event_by_level_ prevents a nested event from growing larger
// than the parent event(s).
// If the last event of the given level has the same metadata, expands it to
// include the time until the given event's (offset_ps + duration_ps).
// Otherwise, adds a new event and clears last_event_by_level_ for the levels
// below the given level and all levels of the dependent lines. Clearing
// last_event_by_level_ prevents a nested event from growing larger than the
// parent event(s).
void ExpandOrAddLevelEvent(const XEvent& event, int level) {
int64 offset_ps = event.offset_ps();
int64 duration_ps = event.duration_ps();
@ -84,8 +82,7 @@ class DerivedXLineBuilder {
// If last_event is not nullptr, its offset must be less than or equal to
// the given event's offset.
DCHECK(!last_event || last_event->OffsetPs() <= offset_ps);
if (try_expand_ && last_event &&
last_event->MetadataId() == event.metadata_id()) {
if (last_event && last_event->MetadataId() == event.metadata_id()) {
// If last_event is not nullptr and metadata is same, merge the given
// event into last_event.
last_event->SetDurationPs((offset_ps + duration_ps) -
@ -108,7 +105,6 @@ class DerivedXLineBuilder {
XLineBuilder line_;
absl::flat_hash_map<int, absl::optional<XEventBuilder>> last_event_by_level_;
std::vector<DerivedXLineBuilder*> dependent_lines_;
bool try_expand_;
};
const absl::string_view kDerivedLineSteps = "Steps";
@ -185,19 +181,18 @@ void DeriveEventsFromAnnotations(const SymbolResolver& symbol_resolver,
XPlaneBuilder plane(device_trace);
DerivedXLineBuilder tf_ops(&plane, kThreadIdTfOp, kDerivedLineTensorFlowOps,
start_timestamp_ns, {}, /*try_expand=*/true);
DerivedXLineBuilder tf_name_scope(
&plane, kThreadIdTfNameScope, kDerivedLineTensorFlowNameScope,
start_timestamp_ns, {&tf_ops}, /*try_expand=*/true);
start_timestamp_ns, {});
DerivedXLineBuilder tf_name_scope(&plane, kThreadIdTfNameScope,
kDerivedLineTensorFlowNameScope,
start_timestamp_ns, {&tf_ops});
DerivedXLineBuilder hlo_ops(&plane, kThreadIdHloOp, kDerivedLineXlaOps,
start_timestamp_ns, {}, /*try_expand=*/true);
DerivedXLineBuilder hlo_modules(
&plane, kThreadIdHloModule, kDerivedLineXlaModules, start_timestamp_ns,
{&tf_ops, &tf_name_scope, &hlo_ops}, /*try_expand=*/false);
start_timestamp_ns, {});
DerivedXLineBuilder hlo_modules(&plane, kThreadIdHloModule,
kDerivedLineXlaModules, start_timestamp_ns,
{&tf_ops, &tf_name_scope, &hlo_ops});
DerivedXLineBuilder steps(&plane, kThreadIdStepInfo, kDerivedLineSteps,
start_timestamp_ns,
{&tf_ops, &tf_name_scope, &hlo_ops},
/*try_expand=*/true);
{&tf_ops, &tf_name_scope, &hlo_ops, &hlo_modules});
int64 group_id_stat_metadata_id =
plane.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kGroupId))->id();
int64 step_name_stat_metadata_id =

View File

@ -36,7 +36,7 @@ TEST(DerivedTimelineTest, EmptySpaceTest) {
EXPECT_EQ(space.planes_size(), 0);
}
// Checks that HLO module events are not expanded.
// Checks that HLO module events are expanded.
TEST(DerivedTimelineTest, HloModuleNameTest) {
const absl::string_view kHloModuleName = "hlo_module";
const absl::string_view kKernelDetails = "kernel_details";
@ -69,7 +69,7 @@ TEST(DerivedTimelineTest, HloModuleNameTest) {
plane_visitor.ForEachLine([&](const XLineVisitor& line_visitor) {
if (line_visitor.Id() == 0) return;
EXPECT_EQ(line_visitor.Id(), kThreadIdHloModule);
EXPECT_EQ(line_visitor.NumEvents(), 2);
EXPECT_EQ(line_visitor.NumEvents(), 1);
line_visitor.ForEachEvent([&](const XEventVisitor& event_visitor) {
EXPECT_EQ(event_visitor.Name(), kHloModuleName);
});