Refactor the code by (1) using EventVisitor in EventNode and (2) copying the first child EventNode when creating the virtual EventNode instead creating a new XEvent.
PiperOrigin-RevId: 313671525 Change-Id: I421f08025c49a9deda5231184d287535486ac13b
This commit is contained in:
parent
d29f1d6fe6
commit
1e22a99527
@ -103,16 +103,6 @@ int64 GetEventType(const XPlaneVisitor& visitor, const XEvent& event) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const XStat* GetStat(const XPlaneVisitor& visitor, const XEvent& event,
|
|
||||||
int64 stat_type) {
|
|
||||||
for (const auto& stat : event.stats()) {
|
|
||||||
if (visitor.GetStatType(stat) == stat_type) {
|
|
||||||
return &stat;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
void SetGroupId(const XPlaneVisitor& visitor, int64 group_id, XEvent* event) {
|
void SetGroupId(const XPlaneVisitor& visitor, int64 group_id, XEvent* event) {
|
||||||
AddOrUpdateIntStat(*visitor.GetStatMetadataId(StatType::kGroupId), group_id,
|
AddOrUpdateIntStat(*visitor.GetStatMetadataId(StatType::kGroupId), group_id,
|
||||||
event);
|
event);
|
||||||
@ -146,8 +136,7 @@ bool NeedsVirtualEventsForAsyncExecutor(
|
|||||||
|
|
||||||
bool HasFunctionRun(EventNode* event_node) {
|
bool HasFunctionRun(EventNode* event_node) {
|
||||||
for (EventNode* child : event_node->GetChildren()) {
|
for (EventNode* child : event_node->GetChildren()) {
|
||||||
if (child->GetPlaneVisitor().GetEventType(child->GetEvent()) ==
|
if (child->GetEventVisitor().Type() == HostEventType::kFunctionRun) {
|
||||||
HostEventType::kFunctionRun) {
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -156,8 +145,21 @@ bool HasFunctionRun(EventNode* event_node) {
|
|||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
EventNode::EventNode(const XPlaneVisitor* plane, XLine* raw_line,
|
||||||
|
XEvent* raw_event)
|
||||||
|
: plane_(plane),
|
||||||
|
visitor_(plane, raw_line, raw_event),
|
||||||
|
raw_line_(raw_line),
|
||||||
|
raw_event_(raw_event) {}
|
||||||
|
|
||||||
|
EventNode::EventNode(const EventNode& event_node)
|
||||||
|
: plane_(event_node.plane_),
|
||||||
|
visitor_(event_node.plane_, event_node.raw_line_, event_node.raw_event_),
|
||||||
|
raw_line_(event_node.raw_line_),
|
||||||
|
raw_event_(event_node.raw_event_) {}
|
||||||
|
|
||||||
const XStat* EventNode::GetContextStat(int64 stat_type) const {
|
const XStat* EventNode::GetContextStat(int64 stat_type) const {
|
||||||
if (const XStat* stat = GetStat(*visitor_, *event_, stat_type)) {
|
if (const XStat* stat = visitor_.GetStats(stat_type)) {
|
||||||
return stat;
|
return stat;
|
||||||
} else if (parent_) {
|
} else if (parent_) {
|
||||||
return parent_->GetContextStat(stat_type);
|
return parent_->GetContextStat(stat_type);
|
||||||
@ -168,7 +170,7 @@ const XStat* EventNode::GetContextStat(int64 stat_type) const {
|
|||||||
std::string EventNode::GetGroupName() const {
|
std::string EventNode::GetGroupName() const {
|
||||||
std::vector<std::string> name_parts;
|
std::vector<std::string> name_parts;
|
||||||
if (const XStat* graph_type_stat = GetContextStat(StatType::kGraphType)) {
|
if (const XStat* graph_type_stat = GetContextStat(StatType::kGraphType)) {
|
||||||
XStatVisitor stat(visitor_, graph_type_stat);
|
XStatVisitor stat(plane_, graph_type_stat);
|
||||||
name_parts.push_back(stat.ToString());
|
name_parts.push_back(stat.ToString());
|
||||||
}
|
}
|
||||||
int64 step_num = group_id_.value_or(0);
|
int64 step_num = group_id_.value_or(0);
|
||||||
@ -184,7 +186,7 @@ std::string EventNode::GetGroupName() const {
|
|||||||
|
|
||||||
void EventNode::PropagateGroupId(int64 group_id) {
|
void EventNode::PropagateGroupId(int64 group_id) {
|
||||||
group_id_ = group_id;
|
group_id_ = group_id;
|
||||||
SetGroupId(*visitor_, group_id, event_);
|
SetGroupId(*plane_, group_id, raw_event_);
|
||||||
for (const auto& child : children_) {
|
for (const auto& child : children_) {
|
||||||
// Skip if it already belongs to a group. Some nodes may be added multiple
|
// Skip if it already belongs to a group. Some nodes may be added multiple
|
||||||
// times as child (e.g., sometimes async ops are executed synchronously and
|
// times as child (e.g., sometimes async ops are executed synchronously and
|
||||||
@ -196,13 +198,13 @@ void EventNode::PropagateGroupId(int64 group_id) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
void EventNode::AddStepName(absl::string_view step_name) {
|
void EventNode::AddStepName(absl::string_view step_name) {
|
||||||
AddOrUpdateStrStat(*visitor_->GetStatMetadataId(StatType::kStepName),
|
AddOrUpdateStrStat(*plane_->GetStatMetadataId(StatType::kStepName), step_name,
|
||||||
step_name, event_);
|
raw_event_);
|
||||||
}
|
}
|
||||||
|
|
||||||
void EventNode::SetIsEager(bool is_eager) {
|
void EventNode::SetIsEager(bool is_eager) {
|
||||||
AddOrUpdateIntStat(*visitor_->GetStatMetadataId(StatType::kIsEager),
|
AddOrUpdateIntStat(*plane_->GetStatMetadataId(StatType::kIsEager),
|
||||||
is_eager ? 1 : 0, event_);
|
is_eager ? 1 : 0, raw_event_);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool EventNode::IsEager() {
|
bool EventNode::IsEager() {
|
||||||
@ -213,14 +215,9 @@ bool EventNode::IsEager() {
|
|||||||
FindParent(HostEventType::kEagerKernelExecute) != nullptr;
|
FindParent(HostEventType::kEagerKernelExecute) != nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool EventNode::IsNestedIn(EventNode* parent) {
|
|
||||||
return parent && IsNested(GetEvent(), parent->GetEvent());
|
|
||||||
}
|
|
||||||
|
|
||||||
EventNode* EventNode::FindParent(int64 event_type) {
|
EventNode* EventNode::FindParent(int64 event_type) {
|
||||||
if (parent_) {
|
if (parent_) {
|
||||||
if (GetEventType(parent_->GetPlaneVisitor(), parent_->GetEvent()) ==
|
if (parent_->GetEventVisitor().Type() == event_type) {
|
||||||
event_type) {
|
|
||||||
return parent_;
|
return parent_;
|
||||||
}
|
}
|
||||||
return parent_->FindParent(event_type);
|
return parent_->FindParent(event_type);
|
||||||
@ -233,10 +230,11 @@ void EventForest::ConnectIntraThread(const XPlaneVisitor& visitor,
|
|||||||
for (auto& line : *plane->mutable_lines()) {
|
for (auto& line : *plane->mutable_lines()) {
|
||||||
std::vector<EventNode*> parent_nodes;
|
std::vector<EventNode*> parent_nodes;
|
||||||
for (auto& event : *line.mutable_events()) {
|
for (auto& event : *line.mutable_events()) {
|
||||||
auto cur_node = absl::make_unique<EventNode>(&visitor, &event);
|
auto cur_node = absl::make_unique<EventNode>(&visitor, &line, &event);
|
||||||
while (!parent_nodes.empty()) {
|
while (!parent_nodes.empty()) {
|
||||||
EventNode* parent_node = parent_nodes.back();
|
EventNode* parent_node = parent_nodes.back();
|
||||||
if (cur_node->IsNestedIn(parent_node)) {
|
if (parent_node->GetEventVisitor().GetTimespan().Includes(
|
||||||
|
cur_node->GetEventVisitor().GetTimespan())) {
|
||||||
parent_node->AddChild(cur_node.get());
|
parent_node->AddChild(cur_node.get());
|
||||||
break;
|
break;
|
||||||
} else {
|
} else {
|
||||||
@ -357,12 +355,8 @@ void EventForest::CreateVirtualEventsForHostTrainingLoop() {
|
|||||||
if (!iter_num) continue;
|
if (!iter_num) continue;
|
||||||
EventNode*& virtual_event_node = virtual_event_node_map[step_id][iter_num];
|
EventNode*& virtual_event_node = virtual_event_node_map[step_id][iter_num];
|
||||||
if (!virtual_event_node) {
|
if (!virtual_event_node) {
|
||||||
std::unique_ptr<XEvent> new_virtual_event =
|
auto new_virtual_event_node =
|
||||||
CreateVirtualEvent(*step_id_stat, *iter_num_stat);
|
absl::make_unique<EventNode>(*executor_event_node);
|
||||||
auto new_virtual_event_node = absl::make_unique<EventNode>(
|
|
||||||
&executor_event_node->GetPlaneVisitor(), new_virtual_event.get());
|
|
||||||
// virtual_event_container_ keeps new_virtual_event alive.
|
|
||||||
virtual_event_container_.push_back(std::move(new_virtual_event));
|
|
||||||
virtual_event_node = new_virtual_event_node.get();
|
virtual_event_node = new_virtual_event_node.get();
|
||||||
// event_node_map_ keeps new_virtual_event_node alive.
|
// event_node_map_ keeps new_virtual_event_node alive.
|
||||||
event_node_map_[HostEventType::kHostTrainingLoopIteration].push_back(
|
event_node_map_[HostEventType::kHostTrainingLoopIteration].push_back(
|
||||||
@ -380,12 +374,8 @@ void EventForest::CreateVirtualEventsForAsyncExecutor() {
|
|||||||
for (auto& eager_kernel_execute_event_node :
|
for (auto& eager_kernel_execute_event_node :
|
||||||
*eager_kernel_execute_event_node_list) {
|
*eager_kernel_execute_event_node_list) {
|
||||||
if (HasFunctionRun(eager_kernel_execute_event_node.get())) {
|
if (HasFunctionRun(eager_kernel_execute_event_node.get())) {
|
||||||
auto new_virtual_event = absl::make_unique<XEvent>();
|
auto new_virtual_event_node =
|
||||||
auto new_virtual_event_node = absl::make_unique<EventNode>(
|
absl::make_unique<EventNode>(*eager_kernel_execute_event_node);
|
||||||
&eager_kernel_execute_event_node->GetPlaneVisitor(),
|
|
||||||
new_virtual_event.get());
|
|
||||||
// virtual_event_container_ keeps new_virtual_event alive.
|
|
||||||
virtual_event_container_.push_back(std::move(new_virtual_event));
|
|
||||||
virtual_event_node = new_virtual_event_node.get();
|
virtual_event_node = new_virtual_event_node.get();
|
||||||
// event_node_map_ keeps new_virtual_event_node alive.
|
// event_node_map_ keeps new_virtual_event_node alive.
|
||||||
event_node_map_[HostEventType::kAsyncExecutorTraceContext].push_back(
|
event_node_map_[HostEventType::kAsyncExecutorTraceContext].push_back(
|
||||||
|
@ -47,12 +47,10 @@ struct InterThreadConnectInfo {
|
|||||||
// pointers, a tree of EventNode is formed.
|
// pointers, a tree of EventNode is formed.
|
||||||
class EventNode {
|
class EventNode {
|
||||||
public:
|
public:
|
||||||
// REQUIRED: visitor and event should not be nullptr.
|
// REQUIRED: all inputs should not be nullptr.
|
||||||
explicit EventNode(const XPlaneVisitor* visitor, XEvent* event)
|
EventNode(const XPlaneVisitor* plane, XLine* raw_line, XEvent* raw_event);
|
||||||
: visitor_(visitor), event_(event) {
|
|
||||||
DCHECK(visitor);
|
EventNode(const EventNode& event_node);
|
||||||
DCHECK(event);
|
|
||||||
}
|
|
||||||
|
|
||||||
EventNode* GetParent() const { return parent_; }
|
EventNode* GetParent() const { return parent_; }
|
||||||
|
|
||||||
@ -70,9 +68,9 @@ class EventNode {
|
|||||||
// Sets group_id for this node and its descendants.
|
// Sets group_id for this node and its descendants.
|
||||||
void PropagateGroupId(int64 group_id);
|
void PropagateGroupId(int64 group_id);
|
||||||
|
|
||||||
const XPlaneVisitor& GetPlaneVisitor() const { return *visitor_; }
|
const XPlaneVisitor& GetPlaneVisitor() const { return *plane_; }
|
||||||
|
|
||||||
const XEvent& GetEvent() const { return *event_; }
|
const XEventVisitor& GetEventVisitor() const { return visitor_; }
|
||||||
|
|
||||||
const XStat* GetContextStat(int64 stat_type) const;
|
const XStat* GetContextStat(int64 stat_type) const;
|
||||||
|
|
||||||
@ -89,8 +87,10 @@ class EventNode {
|
|||||||
EventNode* FindParent(int64 event_type);
|
EventNode* FindParent(int64 event_type);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const XPlaneVisitor* visitor_;
|
const XPlaneVisitor* plane_;
|
||||||
XEvent* event_;
|
XEventVisitor visitor_;
|
||||||
|
XLine* raw_line_;
|
||||||
|
XEvent* raw_event_;
|
||||||
EventNode* parent_ = nullptr;
|
EventNode* parent_ = nullptr;
|
||||||
std::vector<EventNode*> children_;
|
std::vector<EventNode*> children_;
|
||||||
absl::optional<int64> group_id_;
|
absl::optional<int64> group_id_;
|
||||||
@ -100,8 +100,6 @@ using EventNodeMap =
|
|||||||
absl::flat_hash_map<int64 /*event_type*/,
|
absl::flat_hash_map<int64 /*event_type*/,
|
||||||
std::vector<std::unique_ptr<EventNode>>>;
|
std::vector<std::unique_ptr<EventNode>>>;
|
||||||
|
|
||||||
using VirtualEventContainer = std::vector<std::unique_ptr<XEvent>>;
|
|
||||||
|
|
||||||
using EventGroupNameMap = absl::flat_hash_map<int64 /*group_id*/, std::string>;
|
using EventGroupNameMap = absl::flat_hash_map<int64 /*group_id*/, std::string>;
|
||||||
|
|
||||||
// Creates a forest of EventNode by stitching events in space using the nesting
|
// Creates a forest of EventNode by stitching events in space using the nesting
|
||||||
@ -141,20 +139,19 @@ class EventForest {
|
|||||||
// Sets the is_eager stat to true for the eagerly executed CPU TF op events.
|
// Sets the is_eager stat to true for the eagerly executed CPU TF op events.
|
||||||
void MarkEagerlyExecutedCpuTfOps();
|
void MarkEagerlyExecutedCpuTfOps();
|
||||||
|
|
||||||
// Create virtual events of HostEventType::kHostTrainingLoopIteration and
|
// Create virtual events of HostEventType::kHostTrainingLoopIteration. A
|
||||||
// event nodes for them. A virtual event is created for each iteration of the
|
// virtual event is created for each iteration of the host training loop and
|
||||||
// host training loop and connected to the
|
// connected to the HostEventType::kExecutorStateProcess events of the
|
||||||
// HostEventType::kExecutorStateProcess event nodes of the iteration.
|
// iteration.
|
||||||
void CreateVirtualEventsForHostTrainingLoop();
|
void CreateVirtualEventsForHostTrainingLoop();
|
||||||
|
|
||||||
// Create virutal events of HostEventType::kAsyncExecutorTraceContext and
|
// Create virutal events of HostEventType::kAsyncExecutorTraceContext. A
|
||||||
// event nodes for them. A virtual event is created for every FunctionRun and
|
// virtual event is created for every FunctionRun and the following eager ops
|
||||||
// the following eager ops (e.g., for Keras callback).
|
// (e.g., for Keras callback).
|
||||||
void CreateVirtualEventsForAsyncExecutor();
|
void CreateVirtualEventsForAsyncExecutor();
|
||||||
|
|
||||||
EventNodeMap event_node_map_;
|
EventNodeMap event_node_map_;
|
||||||
std::vector<XPlaneVisitor> visitors_;
|
std::vector<XPlaneVisitor> visitors_;
|
||||||
VirtualEventContainer virtual_event_container_;
|
|
||||||
EventGroupNameMap event_group_name_map_;
|
EventGroupNameMap event_group_name_map_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user