Refactor the code by (1) using EventVisitor in EventNode and (2) copying the first child EventNode when creating the virtual EventNode instead creating a new XEvent.

PiperOrigin-RevId: 313671525
Change-Id: I421f08025c49a9deda5231184d287535486ac13b
This commit is contained in:
Jiho Choi 2020-05-28 15:30:46 -07:00 committed by TensorFlower Gardener
parent d29f1d6fe6
commit 1e22a99527
2 changed files with 46 additions and 59 deletions

View File

@ -103,16 +103,6 @@ int64 GetEventType(const XPlaneVisitor& visitor, const XEvent& event) {
}
}
const XStat* GetStat(const XPlaneVisitor& visitor, const XEvent& event,
int64 stat_type) {
for (const auto& stat : event.stats()) {
if (visitor.GetStatType(stat) == stat_type) {
return &stat;
}
}
return nullptr;
}
void SetGroupId(const XPlaneVisitor& visitor, int64 group_id, XEvent* event) {
AddOrUpdateIntStat(*visitor.GetStatMetadataId(StatType::kGroupId), group_id,
event);
@ -146,8 +136,7 @@ bool NeedsVirtualEventsForAsyncExecutor(
bool HasFunctionRun(EventNode* event_node) {
for (EventNode* child : event_node->GetChildren()) {
if (child->GetPlaneVisitor().GetEventType(child->GetEvent()) ==
HostEventType::kFunctionRun) {
if (child->GetEventVisitor().Type() == HostEventType::kFunctionRun) {
return true;
}
}
@ -156,8 +145,21 @@ bool HasFunctionRun(EventNode* event_node) {
} // namespace
EventNode::EventNode(const XPlaneVisitor* plane, XLine* raw_line,
XEvent* raw_event)
: plane_(plane),
visitor_(plane, raw_line, raw_event),
raw_line_(raw_line),
raw_event_(raw_event) {}
EventNode::EventNode(const EventNode& event_node)
: plane_(event_node.plane_),
visitor_(event_node.plane_, event_node.raw_line_, event_node.raw_event_),
raw_line_(event_node.raw_line_),
raw_event_(event_node.raw_event_) {}
const XStat* EventNode::GetContextStat(int64 stat_type) const {
if (const XStat* stat = GetStat(*visitor_, *event_, stat_type)) {
if (const XStat* stat = visitor_.GetStats(stat_type)) {
return stat;
} else if (parent_) {
return parent_->GetContextStat(stat_type);
@ -168,7 +170,7 @@ const XStat* EventNode::GetContextStat(int64 stat_type) const {
std::string EventNode::GetGroupName() const {
std::vector<std::string> name_parts;
if (const XStat* graph_type_stat = GetContextStat(StatType::kGraphType)) {
XStatVisitor stat(visitor_, graph_type_stat);
XStatVisitor stat(plane_, graph_type_stat);
name_parts.push_back(stat.ToString());
}
int64 step_num = group_id_.value_or(0);
@ -184,7 +186,7 @@ std::string EventNode::GetGroupName() const {
void EventNode::PropagateGroupId(int64 group_id) {
group_id_ = group_id;
SetGroupId(*visitor_, group_id, event_);
SetGroupId(*plane_, group_id, raw_event_);
for (const auto& child : children_) {
// Skip if it already belongs to a group. Some nodes may be added multiple
// times as child (e.g., sometimes async ops are executed synchronously and
@ -196,13 +198,13 @@ void EventNode::PropagateGroupId(int64 group_id) {
}
void EventNode::AddStepName(absl::string_view step_name) {
AddOrUpdateStrStat(*visitor_->GetStatMetadataId(StatType::kStepName),
step_name, event_);
AddOrUpdateStrStat(*plane_->GetStatMetadataId(StatType::kStepName), step_name,
raw_event_);
}
void EventNode::SetIsEager(bool is_eager) {
AddOrUpdateIntStat(*visitor_->GetStatMetadataId(StatType::kIsEager),
is_eager ? 1 : 0, event_);
AddOrUpdateIntStat(*plane_->GetStatMetadataId(StatType::kIsEager),
is_eager ? 1 : 0, raw_event_);
}
bool EventNode::IsEager() {
@ -213,14 +215,9 @@ bool EventNode::IsEager() {
FindParent(HostEventType::kEagerKernelExecute) != nullptr;
}
bool EventNode::IsNestedIn(EventNode* parent) {
return parent && IsNested(GetEvent(), parent->GetEvent());
}
EventNode* EventNode::FindParent(int64 event_type) {
if (parent_) {
if (GetEventType(parent_->GetPlaneVisitor(), parent_->GetEvent()) ==
event_type) {
if (parent_->GetEventVisitor().Type() == event_type) {
return parent_;
}
return parent_->FindParent(event_type);
@ -233,10 +230,11 @@ void EventForest::ConnectIntraThread(const XPlaneVisitor& visitor,
for (auto& line : *plane->mutable_lines()) {
std::vector<EventNode*> parent_nodes;
for (auto& event : *line.mutable_events()) {
auto cur_node = absl::make_unique<EventNode>(&visitor, &event);
auto cur_node = absl::make_unique<EventNode>(&visitor, &line, &event);
while (!parent_nodes.empty()) {
EventNode* parent_node = parent_nodes.back();
if (cur_node->IsNestedIn(parent_node)) {
if (parent_node->GetEventVisitor().GetTimespan().Includes(
cur_node->GetEventVisitor().GetTimespan())) {
parent_node->AddChild(cur_node.get());
break;
} else {
@ -357,12 +355,8 @@ void EventForest::CreateVirtualEventsForHostTrainingLoop() {
if (!iter_num) continue;
EventNode*& virtual_event_node = virtual_event_node_map[step_id][iter_num];
if (!virtual_event_node) {
std::unique_ptr<XEvent> new_virtual_event =
CreateVirtualEvent(*step_id_stat, *iter_num_stat);
auto new_virtual_event_node = absl::make_unique<EventNode>(
&executor_event_node->GetPlaneVisitor(), new_virtual_event.get());
// virtual_event_container_ keeps new_virtual_event alive.
virtual_event_container_.push_back(std::move(new_virtual_event));
auto new_virtual_event_node =
absl::make_unique<EventNode>(*executor_event_node);
virtual_event_node = new_virtual_event_node.get();
// event_node_map_ keeps new_virtual_event_node alive.
event_node_map_[HostEventType::kHostTrainingLoopIteration].push_back(
@ -380,12 +374,8 @@ void EventForest::CreateVirtualEventsForAsyncExecutor() {
for (auto& eager_kernel_execute_event_node :
*eager_kernel_execute_event_node_list) {
if (HasFunctionRun(eager_kernel_execute_event_node.get())) {
auto new_virtual_event = absl::make_unique<XEvent>();
auto new_virtual_event_node = absl::make_unique<EventNode>(
&eager_kernel_execute_event_node->GetPlaneVisitor(),
new_virtual_event.get());
// virtual_event_container_ keeps new_virtual_event alive.
virtual_event_container_.push_back(std::move(new_virtual_event));
auto new_virtual_event_node =
absl::make_unique<EventNode>(*eager_kernel_execute_event_node);
virtual_event_node = new_virtual_event_node.get();
// event_node_map_ keeps new_virtual_event_node alive.
event_node_map_[HostEventType::kAsyncExecutorTraceContext].push_back(

View File

@ -47,12 +47,10 @@ struct InterThreadConnectInfo {
// pointers, a tree of EventNode is formed.
class EventNode {
public:
// REQUIRED: visitor and event should not be nullptr.
explicit EventNode(const XPlaneVisitor* visitor, XEvent* event)
: visitor_(visitor), event_(event) {
DCHECK(visitor);
DCHECK(event);
}
// REQUIRED: all inputs should not be nullptr.
EventNode(const XPlaneVisitor* plane, XLine* raw_line, XEvent* raw_event);
EventNode(const EventNode& event_node);
EventNode* GetParent() const { return parent_; }
@ -70,9 +68,9 @@ class EventNode {
// Sets group_id for this node and its descendants.
void PropagateGroupId(int64 group_id);
const XPlaneVisitor& GetPlaneVisitor() const { return *visitor_; }
const XPlaneVisitor& GetPlaneVisitor() const { return *plane_; }
const XEvent& GetEvent() const { return *event_; }
const XEventVisitor& GetEventVisitor() const { return visitor_; }
const XStat* GetContextStat(int64 stat_type) const;
@ -89,8 +87,10 @@ class EventNode {
EventNode* FindParent(int64 event_type);
private:
const XPlaneVisitor* visitor_;
XEvent* event_;
const XPlaneVisitor* plane_;
XEventVisitor visitor_;
XLine* raw_line_;
XEvent* raw_event_;
EventNode* parent_ = nullptr;
std::vector<EventNode*> children_;
absl::optional<int64> group_id_;
@ -100,8 +100,6 @@ using EventNodeMap =
absl::flat_hash_map<int64 /*event_type*/,
std::vector<std::unique_ptr<EventNode>>>;
using VirtualEventContainer = std::vector<std::unique_ptr<XEvent>>;
using EventGroupNameMap = absl::flat_hash_map<int64 /*group_id*/, std::string>;
// Creates a forest of EventNode by stitching events in space using the nesting
@ -141,20 +139,19 @@ class EventForest {
// Sets the is_eager stat to true for the eagerly executed CPU TF op events.
void MarkEagerlyExecutedCpuTfOps();
// Create virtual events of HostEventType::kHostTrainingLoopIteration and
// event nodes for them. A virtual event is created for each iteration of the
// host training loop and connected to the
// HostEventType::kExecutorStateProcess event nodes of the iteration.
// Create virtual events of HostEventType::kHostTrainingLoopIteration. A
// virtual event is created for each iteration of the host training loop and
// connected to the HostEventType::kExecutorStateProcess events of the
// iteration.
void CreateVirtualEventsForHostTrainingLoop();
// Create virutal events of HostEventType::kAsyncExecutorTraceContext and
// event nodes for them. A virtual event is created for every FunctionRun and
// the following eager ops (e.g., for Keras callback).
// Create virutal events of HostEventType::kAsyncExecutorTraceContext. A
// virtual event is created for every FunctionRun and the following eager ops
// (e.g., for Keras callback).
void CreateVirtualEventsForAsyncExecutor();
EventNodeMap event_node_map_;
std::vector<XPlaneVisitor> visitors_;
VirtualEventContainer virtual_event_container_;
EventGroupNameMap event_group_name_map_;
};