Refactor the code by (1) using EventVisitor in EventNode and (2) copying the first child EventNode when creating the virtual EventNode instead creating a new XEvent.

PiperOrigin-RevId: 313671525
Change-Id: I421f08025c49a9deda5231184d287535486ac13b
This commit is contained in:
Jiho Choi 2020-05-28 15:30:46 -07:00 committed by TensorFlower Gardener
parent d29f1d6fe6
commit 1e22a99527
2 changed files with 46 additions and 59 deletions

View File

@ -103,16 +103,6 @@ int64 GetEventType(const XPlaneVisitor& visitor, const XEvent& event) {
} }
} }
const XStat* GetStat(const XPlaneVisitor& visitor, const XEvent& event,
int64 stat_type) {
for (const auto& stat : event.stats()) {
if (visitor.GetStatType(stat) == stat_type) {
return &stat;
}
}
return nullptr;
}
void SetGroupId(const XPlaneVisitor& visitor, int64 group_id, XEvent* event) { void SetGroupId(const XPlaneVisitor& visitor, int64 group_id, XEvent* event) {
AddOrUpdateIntStat(*visitor.GetStatMetadataId(StatType::kGroupId), group_id, AddOrUpdateIntStat(*visitor.GetStatMetadataId(StatType::kGroupId), group_id,
event); event);
@ -146,8 +136,7 @@ bool NeedsVirtualEventsForAsyncExecutor(
bool HasFunctionRun(EventNode* event_node) { bool HasFunctionRun(EventNode* event_node) {
for (EventNode* child : event_node->GetChildren()) { for (EventNode* child : event_node->GetChildren()) {
if (child->GetPlaneVisitor().GetEventType(child->GetEvent()) == if (child->GetEventVisitor().Type() == HostEventType::kFunctionRun) {
HostEventType::kFunctionRun) {
return true; return true;
} }
} }
@ -156,8 +145,21 @@ bool HasFunctionRun(EventNode* event_node) {
} // namespace } // namespace
EventNode::EventNode(const XPlaneVisitor* plane, XLine* raw_line,
XEvent* raw_event)
: plane_(plane),
visitor_(plane, raw_line, raw_event),
raw_line_(raw_line),
raw_event_(raw_event) {}
EventNode::EventNode(const EventNode& event_node)
: plane_(event_node.plane_),
visitor_(event_node.plane_, event_node.raw_line_, event_node.raw_event_),
raw_line_(event_node.raw_line_),
raw_event_(event_node.raw_event_) {}
const XStat* EventNode::GetContextStat(int64 stat_type) const { const XStat* EventNode::GetContextStat(int64 stat_type) const {
if (const XStat* stat = GetStat(*visitor_, *event_, stat_type)) { if (const XStat* stat = visitor_.GetStats(stat_type)) {
return stat; return stat;
} else if (parent_) { } else if (parent_) {
return parent_->GetContextStat(stat_type); return parent_->GetContextStat(stat_type);
@ -168,7 +170,7 @@ const XStat* EventNode::GetContextStat(int64 stat_type) const {
std::string EventNode::GetGroupName() const { std::string EventNode::GetGroupName() const {
std::vector<std::string> name_parts; std::vector<std::string> name_parts;
if (const XStat* graph_type_stat = GetContextStat(StatType::kGraphType)) { if (const XStat* graph_type_stat = GetContextStat(StatType::kGraphType)) {
XStatVisitor stat(visitor_, graph_type_stat); XStatVisitor stat(plane_, graph_type_stat);
name_parts.push_back(stat.ToString()); name_parts.push_back(stat.ToString());
} }
int64 step_num = group_id_.value_or(0); int64 step_num = group_id_.value_or(0);
@ -184,7 +186,7 @@ std::string EventNode::GetGroupName() const {
void EventNode::PropagateGroupId(int64 group_id) { void EventNode::PropagateGroupId(int64 group_id) {
group_id_ = group_id; group_id_ = group_id;
SetGroupId(*visitor_, group_id, event_); SetGroupId(*plane_, group_id, raw_event_);
for (const auto& child : children_) { for (const auto& child : children_) {
// Skip if it already belongs to a group. Some nodes may be added multiple // Skip if it already belongs to a group. Some nodes may be added multiple
// times as child (e.g., sometimes async ops are executed synchronously and // times as child (e.g., sometimes async ops are executed synchronously and
@ -196,13 +198,13 @@ void EventNode::PropagateGroupId(int64 group_id) {
} }
void EventNode::AddStepName(absl::string_view step_name) { void EventNode::AddStepName(absl::string_view step_name) {
AddOrUpdateStrStat(*visitor_->GetStatMetadataId(StatType::kStepName), AddOrUpdateStrStat(*plane_->GetStatMetadataId(StatType::kStepName), step_name,
step_name, event_); raw_event_);
} }
void EventNode::SetIsEager(bool is_eager) { void EventNode::SetIsEager(bool is_eager) {
AddOrUpdateIntStat(*visitor_->GetStatMetadataId(StatType::kIsEager), AddOrUpdateIntStat(*plane_->GetStatMetadataId(StatType::kIsEager),
is_eager ? 1 : 0, event_); is_eager ? 1 : 0, raw_event_);
} }
bool EventNode::IsEager() { bool EventNode::IsEager() {
@ -213,14 +215,9 @@ bool EventNode::IsEager() {
FindParent(HostEventType::kEagerKernelExecute) != nullptr; FindParent(HostEventType::kEagerKernelExecute) != nullptr;
} }
bool EventNode::IsNestedIn(EventNode* parent) {
return parent && IsNested(GetEvent(), parent->GetEvent());
}
EventNode* EventNode::FindParent(int64 event_type) { EventNode* EventNode::FindParent(int64 event_type) {
if (parent_) { if (parent_) {
if (GetEventType(parent_->GetPlaneVisitor(), parent_->GetEvent()) == if (parent_->GetEventVisitor().Type() == event_type) {
event_type) {
return parent_; return parent_;
} }
return parent_->FindParent(event_type); return parent_->FindParent(event_type);
@ -233,10 +230,11 @@ void EventForest::ConnectIntraThread(const XPlaneVisitor& visitor,
for (auto& line : *plane->mutable_lines()) { for (auto& line : *plane->mutable_lines()) {
std::vector<EventNode*> parent_nodes; std::vector<EventNode*> parent_nodes;
for (auto& event : *line.mutable_events()) { for (auto& event : *line.mutable_events()) {
auto cur_node = absl::make_unique<EventNode>(&visitor, &event); auto cur_node = absl::make_unique<EventNode>(&visitor, &line, &event);
while (!parent_nodes.empty()) { while (!parent_nodes.empty()) {
EventNode* parent_node = parent_nodes.back(); EventNode* parent_node = parent_nodes.back();
if (cur_node->IsNestedIn(parent_node)) { if (parent_node->GetEventVisitor().GetTimespan().Includes(
cur_node->GetEventVisitor().GetTimespan())) {
parent_node->AddChild(cur_node.get()); parent_node->AddChild(cur_node.get());
break; break;
} else { } else {
@ -357,12 +355,8 @@ void EventForest::CreateVirtualEventsForHostTrainingLoop() {
if (!iter_num) continue; if (!iter_num) continue;
EventNode*& virtual_event_node = virtual_event_node_map[step_id][iter_num]; EventNode*& virtual_event_node = virtual_event_node_map[step_id][iter_num];
if (!virtual_event_node) { if (!virtual_event_node) {
std::unique_ptr<XEvent> new_virtual_event = auto new_virtual_event_node =
CreateVirtualEvent(*step_id_stat, *iter_num_stat); absl::make_unique<EventNode>(*executor_event_node);
auto new_virtual_event_node = absl::make_unique<EventNode>(
&executor_event_node->GetPlaneVisitor(), new_virtual_event.get());
// virtual_event_container_ keeps new_virtual_event alive.
virtual_event_container_.push_back(std::move(new_virtual_event));
virtual_event_node = new_virtual_event_node.get(); virtual_event_node = new_virtual_event_node.get();
// event_node_map_ keeps new_virtual_event_node alive. // event_node_map_ keeps new_virtual_event_node alive.
event_node_map_[HostEventType::kHostTrainingLoopIteration].push_back( event_node_map_[HostEventType::kHostTrainingLoopIteration].push_back(
@ -380,12 +374,8 @@ void EventForest::CreateVirtualEventsForAsyncExecutor() {
for (auto& eager_kernel_execute_event_node : for (auto& eager_kernel_execute_event_node :
*eager_kernel_execute_event_node_list) { *eager_kernel_execute_event_node_list) {
if (HasFunctionRun(eager_kernel_execute_event_node.get())) { if (HasFunctionRun(eager_kernel_execute_event_node.get())) {
auto new_virtual_event = absl::make_unique<XEvent>(); auto new_virtual_event_node =
auto new_virtual_event_node = absl::make_unique<EventNode>( absl::make_unique<EventNode>(*eager_kernel_execute_event_node);
&eager_kernel_execute_event_node->GetPlaneVisitor(),
new_virtual_event.get());
// virtual_event_container_ keeps new_virtual_event alive.
virtual_event_container_.push_back(std::move(new_virtual_event));
virtual_event_node = new_virtual_event_node.get(); virtual_event_node = new_virtual_event_node.get();
// event_node_map_ keeps new_virtual_event_node alive. // event_node_map_ keeps new_virtual_event_node alive.
event_node_map_[HostEventType::kAsyncExecutorTraceContext].push_back( event_node_map_[HostEventType::kAsyncExecutorTraceContext].push_back(

View File

@ -47,12 +47,10 @@ struct InterThreadConnectInfo {
// pointers, a tree of EventNode is formed. // pointers, a tree of EventNode is formed.
class EventNode { class EventNode {
public: public:
// REQUIRED: visitor and event should not be nullptr. // REQUIRED: all inputs should not be nullptr.
explicit EventNode(const XPlaneVisitor* visitor, XEvent* event) EventNode(const XPlaneVisitor* plane, XLine* raw_line, XEvent* raw_event);
: visitor_(visitor), event_(event) {
DCHECK(visitor); EventNode(const EventNode& event_node);
DCHECK(event);
}
EventNode* GetParent() const { return parent_; } EventNode* GetParent() const { return parent_; }
@ -70,9 +68,9 @@ class EventNode {
// Sets group_id for this node and its descendants. // Sets group_id for this node and its descendants.
void PropagateGroupId(int64 group_id); void PropagateGroupId(int64 group_id);
const XPlaneVisitor& GetPlaneVisitor() const { return *visitor_; } const XPlaneVisitor& GetPlaneVisitor() const { return *plane_; }
const XEvent& GetEvent() const { return *event_; } const XEventVisitor& GetEventVisitor() const { return visitor_; }
const XStat* GetContextStat(int64 stat_type) const; const XStat* GetContextStat(int64 stat_type) const;
@ -89,8 +87,10 @@ class EventNode {
EventNode* FindParent(int64 event_type); EventNode* FindParent(int64 event_type);
private: private:
const XPlaneVisitor* visitor_; const XPlaneVisitor* plane_;
XEvent* event_; XEventVisitor visitor_;
XLine* raw_line_;
XEvent* raw_event_;
EventNode* parent_ = nullptr; EventNode* parent_ = nullptr;
std::vector<EventNode*> children_; std::vector<EventNode*> children_;
absl::optional<int64> group_id_; absl::optional<int64> group_id_;
@ -100,8 +100,6 @@ using EventNodeMap =
absl::flat_hash_map<int64 /*event_type*/, absl::flat_hash_map<int64 /*event_type*/,
std::vector<std::unique_ptr<EventNode>>>; std::vector<std::unique_ptr<EventNode>>>;
using VirtualEventContainer = std::vector<std::unique_ptr<XEvent>>;
using EventGroupNameMap = absl::flat_hash_map<int64 /*group_id*/, std::string>; using EventGroupNameMap = absl::flat_hash_map<int64 /*group_id*/, std::string>;
// Creates a forest of EventNode by stitching events in space using the nesting // Creates a forest of EventNode by stitching events in space using the nesting
@ -141,20 +139,19 @@ class EventForest {
// Sets the is_eager stat to true for the eagerly executed CPU TF op events. // Sets the is_eager stat to true for the eagerly executed CPU TF op events.
void MarkEagerlyExecutedCpuTfOps(); void MarkEagerlyExecutedCpuTfOps();
// Create virtual events of HostEventType::kHostTrainingLoopIteration and // Create virtual events of HostEventType::kHostTrainingLoopIteration. A
// event nodes for them. A virtual event is created for each iteration of the // virtual event is created for each iteration of the host training loop and
// host training loop and connected to the // connected to the HostEventType::kExecutorStateProcess events of the
// HostEventType::kExecutorStateProcess event nodes of the iteration. // iteration.
void CreateVirtualEventsForHostTrainingLoop(); void CreateVirtualEventsForHostTrainingLoop();
// Create virutal events of HostEventType::kAsyncExecutorTraceContext and // Create virutal events of HostEventType::kAsyncExecutorTraceContext. A
// event nodes for them. A virtual event is created for every FunctionRun and // virtual event is created for every FunctionRun and the following eager ops
// the following eager ops (e.g., for Keras callback). // (e.g., for Keras callback).
void CreateVirtualEventsForAsyncExecutor(); void CreateVirtualEventsForAsyncExecutor();
EventNodeMap event_node_map_; EventNodeMap event_node_map_;
std::vector<XPlaneVisitor> visitors_; std::vector<XPlaneVisitor> visitors_;
VirtualEventContainer virtual_event_container_;
EventGroupNameMap event_group_name_map_; EventGroupNameMap event_group_name_map_;
}; };