Add grouping support for inference batching profiles (group_id is assigned for each request and batch separately) and allow EventNode to have multiple parents.

PiperOrigin-RevId: 331631036
Change-Id: I2fc242388b20a6af95345c1233980c90591b7ca5
This commit is contained in:
Jiho Choi 2020-09-14 14:37:02 -07:00 committed by TensorFlower Gardener
parent e1a7896aaf
commit 93f43920d5
6 changed files with 152 additions and 47 deletions

View File

@ -323,6 +323,7 @@ tf_cc_test(
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/profiler/lib:connected_traceme",
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/types:optional",

View File

@ -20,6 +20,7 @@ limitations under the License.
#include <iterator>
#include <map>
#include <memory>
#include <queue>
#include <string>
#include <utility>
#include <vector>
@ -27,6 +28,7 @@ limitations under the License.
#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
@ -55,36 +57,34 @@ void CreateStatMetadata(XPlane* plane) {
// Returns event type if it is a KernelLaunch or KernelExecute event.
absl::optional<int64> GetKernelEventType(bool is_host_plane,
                                         const EventNode& event) {
  // A correlation-id stat marks a kernel event; whether it is the launch or
  // the execution side depends on the plane the event lives on.
  if (event.GetEventVisitor().GetStat(StatType::kCorrelationId).has_value()) {
    return is_host_plane ? HostEventType::kKernelLaunch
                         : HostEventType::kKernelExecute;
  }
  return absl::nullopt;
}
// Derives the TF op category by parsing the event's full op name.
Category GetTfEventCategory(const XPlaneVisitor& visitor, const XEvent& event) {
  const absl::string_view op_full_name =
      visitor.GetEventMetadata(event.metadata_id())->name();
  return ParseTfOpFullname(op_full_name).category;
}
int64 GetEventType(bool is_host_plane, const XPlaneVisitor& visitor,
const XEvent& event) {
if (absl::optional<int64> event_type = visitor.GetEventType(event)) {
int64 GetEventType(bool is_host_plane, const EventNode& event) {
if (absl::optional<int64> event_type = event.GetEventVisitor().Type()) {
return *event_type;
} else if (absl::optional<int64> kernel_event_type =
GetKernelEventType(is_host_plane, visitor, event)) {
GetKernelEventType(is_host_plane, event)) {
// KernelLaunch and KernelExecute event types are not supported by
// XPlaneVisitor and should be checked separately.
// TODO(b/148346217): Make XPlaneVisitor support KernelLaunch and
// KernelExecute event types.
return *kernel_event_type;
} else {
Category category = GetTfEventCategory(visitor, event);
absl::string_view name = event.GetEventVisitor().Name();
// Legacy event names appended with arguments.
if (absl::StartsWith(name, "BatchingSession")) {
return HostEventType::kBatchingSession;
} else if (absl::StartsWith(name, "ProcessBatch")) {
return HostEventType::kProcessBatch;
}
// TF op names.
Category category = ParseTfOpFullname(name).category;
switch (category) {
case Category::kTensorFlow:
return HostEventType::kTfOpRun;
@ -104,9 +104,8 @@ void SetGroupId(const XPlaneVisitor& visitor, int64 group_id, XEvent* event) {
void SetContextGroup(EventNode* event, ContextGroupMap* context_groups) {
auto producer = event->GetProducerContext();
if (producer.has_value()) {
DCHECK_EQ(((*context_groups)[producer->type][producer->id]).producer,
nullptr);
((*context_groups)[producer->type][producer->id]).producer = event;
((*context_groups)[producer->type][producer->id])
.producers.push_back(event);
}
auto consumer = event->GetConsumerContext();
if (consumer.has_value()) {
@ -119,7 +118,7 @@ void ConnectContextGroups(const ContextGroupMap& context_groups) {
for (auto& type_id_group : context_groups) {
for (auto& id_group : type_id_group.second) {
const ContextGroup& group = id_group.second;
if (EventNode* parent = group.producer) {
for (EventNode* parent : group.producers) {
for (EventNode* child : group.consumers) {
parent->AddChild(child);
}
@ -153,9 +152,15 @@ bool IsImplicitRootEvent(const XEventVisitor& event) {
kImplicitRootEvents->contains(*event.Type());
}
void ProcessRootEvent(int64 group_id, EventNode* root_event,
void ProcessRootEvent(int64 group_id, bool set_step_name, EventNode* root_event,
GroupMetadataMap* group_metadata_map) {
root_event->PropagateGroupId(group_id);
if (!set_step_name) {
// Step names are not necessary for inference profiles but add group_id to
// group_metadata_map to count the number of groups.
group_metadata_map->emplace(group_id, GroupMetadata());
return;
}
std::string group_name = root_event->GetGroupName();
// TODO(jihochoi): change event name instead.
if (!IsImplicitRootEvent(root_event->GetEventVisitor())) {
@ -258,14 +263,42 @@ bool IsLegacyRootEvent(const XEventVisitor& event) {
return event.Type().has_value() && kRootEvents->contains(*event.Type());
}
using Comparator = std::function<bool(const EventNode*)>;

// Breadth-first search upward through the ancestor graph (an event may have
// multiple parents) for the first node satisfying `comparator`. When
// `include_self` is false, the search starts at the node's parents. Returns
// nullptr if no ancestor matches.
const EventNode* FindParentWithComparator(const Comparator& comparator,
                                          const EventNode* node,
                                          bool include_self) {
  std::queue<const EventNode*> frontier;
  absl::flat_hash_set<const EventNode*> visited = {node};
  if (include_self) {
    frontier.push(node);
  } else {
    for (const EventNode* parent : node->GetParents()) {
      frontier.push(parent);
      visited.insert(parent);
    }
  }
  while (!frontier.empty()) {
    const EventNode* current = frontier.front();
    frontier.pop();
    if (comparator(current)) return current;
    for (const EventNode* parent : current->GetParents()) {
      // insert() returns {it, true} only for unseen nodes, so shared
      // ancestors are enqueued at most once.
      if (visited.insert(parent).second) frontier.push(parent);
    }
  }
  return nullptr;
}
// Returns true if none of its ancestors is a root event.
bool IsTopRoot(const EventNode* event) {
  // If it is already grouped, it is not a top root.
  if (event->GetGroupId().has_value()) return false;
  // Search all ancestors (not just one chain — events may have multiple
  // parents) for any root event.
  const EventNode* root_parent = FindParentWithComparator(
      [](const EventNode* node) { return node->IsRoot(); }, event,
      /*include_self=*/false);
  return root_parent == nullptr;
}
void SortEventList(EventList* event_list) {
@ -347,10 +380,20 @@ EventNode::EventNode(const EventNode& event_node)
event_node.raw_event_) {}
absl::optional<XStatVisitor> EventNode::GetContextStat(int64 stat_type) const {
  // BFS over this event and all of its ancestors; the first stat of
  // `stat_type` found wins. `seen` prevents revisiting shared ancestors now
  // that a node may have multiple parents.
  std::queue<const EventNode*> nodes;
  absl::flat_hash_set<const EventNode*> seen = {this};
  nodes.push(this);
  while (!nodes.empty()) {
    const EventNode* node = nodes.front();
    nodes.pop();
    if (absl::optional<XStatVisitor> stat = node->visitor_.GetStat(stat_type)) {
      return stat;
    }
    for (const EventNode* parent : node->GetParents()) {
      if (seen.contains(parent)) continue;
      nodes.push(parent);
      seen.insert(parent);
    }
  }
  return absl::nullopt;
}
@ -383,7 +426,7 @@ void EventNode::PropagateGroupId(int64 group_id) {
// their nodes are added as child both in ConnectIntraThread and
// ConnectInterThread).
if (child->GetGroupId()) continue;
child->PropagateGroupId(*group_id_);
child->PropagateGroupId(group_id);
}
}
@ -406,15 +449,11 @@ bool EventNode::IsEager() {
}
// Returns this node or the nearest ancestor whose event type matches
// `event_type`, or nullptr if none exists.
const EventNode* EventNode::FindParent(int64 event_type) const {
  return FindParentWithComparator(
      [event_type](const EventNode* node) {
        return node->GetEventVisitor().Type() == event_type;
      },
      this, /*include_self=*/true);
}
bool EventNode::StartsBefore(const EventNode& other) const {
@ -449,7 +488,7 @@ void EventForest::ConnectIntraThread(const XPlaneVisitor& visitor,
}
parent_nodes.push_back(cur_node.get());
// event_node_map_ keeps cur_node alive.
event_node_map_[GetEventType(is_host_plane, visitor, event)].push_back(
event_node_map_[GetEventType(is_host_plane, *cur_node)].push_back(
std::move(cur_node));
}
}
@ -513,10 +552,31 @@ void EventForest::ProcessLegacyRootEvents(
}
void EventForest::CreateEventGroup() {
// Handle inference batching profiles.
if (event_node_map_.contains(HostEventType::kProcessBatch)) {
for (const auto& process_batch_node :
event_node_map_[HostEventType::kProcessBatch]) {
ProcessRootEvent(next_group_id_++, /*set_step_name=*/false,
process_batch_node.get(), &group_metadata_map_);
}
HostEventType request_event_type =
event_node_map_.contains(HostEventType::kBatchingSession)
? HostEventType::kBatchingSession
: HostEventType::kSessionRun;
if (auto request_events =
gtl::FindOrNull(event_node_map_, request_event_type)) {
for (const auto& request_event : *request_events) {
ProcessRootEvent(next_group_id_++, /*set_step_name=*/false,
request_event.get(), &group_metadata_map_);
}
}
return;
}
// Create a group for each TF loop iteration in non-JAX profiles.
if (!HasJaxEvent(event_node_map_) && !tf_loop_root_events_.empty()) {
for (EventNode* root_event : tf_loop_root_events_) {
ProcessRootEvent(next_group_id_++, root_event, &group_metadata_map_);
ProcessRootEvent(next_group_id_++, /*set_step_name=*/true, root_event,
&group_metadata_map_);
}
return;
}
@ -527,7 +587,8 @@ void EventForest::CreateEventGroup() {
if (IsTopRoot(root_event) &&
(!HasJaxEvent(event_node_map_) ||
!IsLegacyRootEvent(root_event->GetEventVisitor()))) {
ProcessRootEvent(next_group_id_++, root_event, &group_metadata_map_);
ProcessRootEvent(next_group_id_++, /*set_step_name=*/true, root_event,
&group_metadata_map_);
}
}
}
@ -697,7 +758,10 @@ void EventForest::ProcessTfDataEvents() {
absl::optional<XStatVisitor> element_id =
consume_event->GetEventVisitor().GetStat(StatType::kElementId);
if (!element_id.has_value()) continue;
EventNode* consume_iterator = consume_event->GetParent();
if (consume_event->GetParents().empty()) continue;
// consume_event is nested by consume_iterator and does not have other
// parents.
EventNode* consume_iterator = consume_event->GetParents().at(0);
if (!consume_iterator ||
!IsDatasetOp(
ParseTfOpFullname(consume_iterator->GetEventVisitor().Name()))) {

View File

@ -58,13 +58,13 @@ class EventNode {
EventNode(const EventNode& event_node);
EventNode* GetParent() const { return parent_; }
const std::vector<EventNode*>& GetParents() const { return parents_; }
const std::vector<EventNode*>& GetChildren() const { return children_; }
void AddChild(EventNode* child) {
  // Records the link in both directions; an event may have multiple parents.
  children_.push_back(child);
  child->parents_.push_back(this);
}
absl::optional<int64> GetGroupId() const { return group_id_; }
@ -113,7 +113,7 @@ class EventNode {
XEventVisitor visitor_;
XLine* raw_line_;
XEvent* raw_event_;
EventNode* parent_ = nullptr;
std::vector<EventNode*> parents_;
std::vector<EventNode*> children_;
absl::optional<int64> group_id_;
absl::optional<ContextInfo> producer_context_;
@ -136,7 +136,7 @@ using GroupMetadataMap = absl::flat_hash_map<int64 /*group_id*/, GroupMetadata>;
using EventList = std::vector<EventNode*>;
// Events linked through a (context type, context id) pair. When context
// groups are connected, every producer becomes a parent of every consumer.
struct ContextGroup {
  std::vector<EventNode*> producers;
  std::vector<EventNode*> consumers;
};

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "absl/types/optional.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/connected_traceme.h"
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
#include "tensorflow/core/profiler/utils/xplane_builder.h"
@ -563,6 +564,39 @@ TEST(GroupEventsTest, WorkerTest) {
});
}
TEST(GroupEventsTest, BatchingSessionTest) {
  constexpr absl::string_view kSchedule = "Schedule";
  constexpr int64 kBatchContextType =
      static_cast<int64>(ContextType::kSharedBatchScheduler);
  constexpr int64 kBatchContextId = 123;
  XSpace raw_space;
  XPlane* raw_plane = raw_space.add_planes();
  XPlaneBuilder plane(raw_plane);
  // One line for the request thread and one for the batch thread.
  plane.ReserveLines(2);
  auto request_thread = plane.GetOrCreateLine(0);
  // First request.
  CreateXEvent(&plane, &request_thread, HostEventType::kBatchingSession, 0,
               100);
  CreateXEvent(&plane, &request_thread, kSchedule, 0, 100,
               {{StatType::kProducerType, kBatchContextType},
                {StatType::kProducerId, kBatchContextId}});
  // Second request.
  CreateXEvent(&plane, &request_thread, HostEventType::kBatchingSession, 200,
               100);
  CreateXEvent(&plane, &request_thread, kSchedule, 200, 100,
               {{StatType::kProducerType, kBatchContextType},
                {StatType::kProducerId, kBatchContextId}});
  // The batch runs on its own line; putting it on the request line would make
  // kProcessBatch nest under the second request (same [200, 300) interval)
  // instead of being linked through the producer/consumer context.
  auto batch_thread = plane.GetOrCreateLine(1);
  CreateXEvent(&plane, &batch_thread, HostEventType::kProcessBatch, 200, 100,
               {{StatType::kConsumerType, kBatchContextType},
                {StatType::kConsumerId, kBatchContextId}});
  GroupMetadataMap group_metadata_map;
  GroupTfEvents(&raw_space, &group_metadata_map);
  // Two requests plus one batch => three groups.
  EXPECT_EQ(group_metadata_map.size(), 3);
}
} // namespace
} // namespace profiler
} // namespace tensorflow

View File

@ -108,6 +108,9 @@ const HostEventTypeMap& GetHostEventTypeMap() {
{"MapAndBatchConsume", kMapAndBatchConsume},
{"ParseExampleProduce", kParseExampleProduce},
{"ParseExampleConsume", kParseExampleConsume},
// Batching related.
{"BatchingSession", kBatchingSession},
{"ProcessBatch", kProcessBatch},
// JAX related.
{"LocalExecutable::ExecuteOnLocalDevices", kExecuteOnLocalDevices},
// GPU related.

View File

@ -99,6 +99,9 @@ enum HostEventType {
kMapAndBatchConsume,
kParseExampleProduce,
kParseExampleConsume,
// Batching related.
kBatchingSession,
kProcessBatch,
// JAX related.
kExecuteOnLocalDevices,
// GPU related.