Add grouping support for inference batching profiles (group_id is assigned for each request and batch separately) and allow EventNode to have multiple parents.
PiperOrigin-RevId: 331631036 Change-Id: I2fc242388b20a6af95345c1233980c90591b7ca5
This commit is contained in:
parent
e1a7896aaf
commit
93f43920d5
@ -323,6 +323,7 @@ tf_cc_test(
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/profiler/lib:connected_traceme",
|
||||
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
|
||||
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include <iterator>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <queue>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
@ -27,6 +28,7 @@ limitations under the License.
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/strings/match.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/types/optional.h"
|
||||
@ -55,36 +57,34 @@ void CreateStatMetadata(XPlane* plane) {
|
||||
|
||||
// Returns event type if it is a KernelLaunch or KernelExecute event.
|
||||
absl::optional<int64> GetKernelEventType(bool is_host_plane,
|
||||
const XPlaneVisitor& visitor,
|
||||
const XEvent& event) {
|
||||
for (const auto& stat : event.stats()) {
|
||||
if (visitor.GetStatType(stat) == StatType::kCorrelationId) {
|
||||
return is_host_plane ? HostEventType::kKernelLaunch
|
||||
: HostEventType::kKernelExecute;
|
||||
}
|
||||
const EventNode& event) {
|
||||
if (event.GetEventVisitor().GetStat(StatType::kCorrelationId).has_value()) {
|
||||
return is_host_plane ? HostEventType::kKernelLaunch
|
||||
: HostEventType::kKernelExecute;
|
||||
}
|
||||
return absl::nullopt;
|
||||
}
|
||||
|
||||
Category GetTfEventCategory(const XPlaneVisitor& visitor, const XEvent& event) {
|
||||
TfOp tf_op =
|
||||
ParseTfOpFullname(visitor.GetEventMetadata(event.metadata_id())->name());
|
||||
return tf_op.category;
|
||||
}
|
||||
|
||||
int64 GetEventType(bool is_host_plane, const XPlaneVisitor& visitor,
|
||||
const XEvent& event) {
|
||||
if (absl::optional<int64> event_type = visitor.GetEventType(event)) {
|
||||
int64 GetEventType(bool is_host_plane, const EventNode& event) {
|
||||
if (absl::optional<int64> event_type = event.GetEventVisitor().Type()) {
|
||||
return *event_type;
|
||||
} else if (absl::optional<int64> kernel_event_type =
|
||||
GetKernelEventType(is_host_plane, visitor, event)) {
|
||||
GetKernelEventType(is_host_plane, event)) {
|
||||
// KernelLaunch and KernelExecute event types are not supported by
|
||||
// XPlaneVisitor and should be checked separately.
|
||||
// TODO(b/148346217): Make XPlaneVisitor support KernelLaunch and
|
||||
// KernelExecute event types.
|
||||
return *kernel_event_type;
|
||||
} else {
|
||||
Category category = GetTfEventCategory(visitor, event);
|
||||
absl::string_view name = event.GetEventVisitor().Name();
|
||||
// Legacy event names appended with arguments.
|
||||
if (absl::StartsWith(name, "BatchingSession")) {
|
||||
return HostEventType::kBatchingSession;
|
||||
} else if (absl::StartsWith(name, "ProcessBatch")) {
|
||||
return HostEventType::kProcessBatch;
|
||||
}
|
||||
// TF op names.
|
||||
Category category = ParseTfOpFullname(name).category;
|
||||
switch (category) {
|
||||
case Category::kTensorFlow:
|
||||
return HostEventType::kTfOpRun;
|
||||
@ -104,9 +104,8 @@ void SetGroupId(const XPlaneVisitor& visitor, int64 group_id, XEvent* event) {
|
||||
void SetContextGroup(EventNode* event, ContextGroupMap* context_groups) {
|
||||
auto producer = event->GetProducerContext();
|
||||
if (producer.has_value()) {
|
||||
DCHECK_EQ(((*context_groups)[producer->type][producer->id]).producer,
|
||||
nullptr);
|
||||
((*context_groups)[producer->type][producer->id]).producer = event;
|
||||
((*context_groups)[producer->type][producer->id])
|
||||
.producers.push_back(event);
|
||||
}
|
||||
auto consumer = event->GetConsumerContext();
|
||||
if (consumer.has_value()) {
|
||||
@ -119,7 +118,7 @@ void ConnectContextGroups(const ContextGroupMap& context_groups) {
|
||||
for (auto& type_id_group : context_groups) {
|
||||
for (auto& id_group : type_id_group.second) {
|
||||
const ContextGroup& group = id_group.second;
|
||||
if (EventNode* parent = group.producer) {
|
||||
for (EventNode* parent : group.producers) {
|
||||
for (EventNode* child : group.consumers) {
|
||||
parent->AddChild(child);
|
||||
}
|
||||
@ -153,9 +152,15 @@ bool IsImplicitRootEvent(const XEventVisitor& event) {
|
||||
kImplicitRootEvents->contains(*event.Type());
|
||||
}
|
||||
|
||||
void ProcessRootEvent(int64 group_id, EventNode* root_event,
|
||||
void ProcessRootEvent(int64 group_id, bool set_step_name, EventNode* root_event,
|
||||
GroupMetadataMap* group_metadata_map) {
|
||||
root_event->PropagateGroupId(group_id);
|
||||
if (!set_step_name) {
|
||||
// Step names are not necessary for inference profiles but add group_id to
|
||||
// group_metadata_map to count the number of groups.
|
||||
group_metadata_map->emplace(group_id, GroupMetadata());
|
||||
return;
|
||||
}
|
||||
std::string group_name = root_event->GetGroupName();
|
||||
// TODO(jihochoi): change event name instead.
|
||||
if (!IsImplicitRootEvent(root_event->GetEventVisitor())) {
|
||||
@ -258,14 +263,42 @@ bool IsLegacyRootEvent(const XEventVisitor& event) {
|
||||
return event.Type().has_value() && kRootEvents->contains(*event.Type());
|
||||
}
|
||||
|
||||
using Comparator = std::function<bool(const EventNode*)>;
|
||||
|
||||
const EventNode* FindParentWithComparator(const Comparator& comparator,
|
||||
const EventNode* node,
|
||||
bool include_self) {
|
||||
std::queue<const EventNode*> nodes;
|
||||
absl::flat_hash_set<const EventNode*> seen = {node};
|
||||
if (include_self) {
|
||||
nodes.push(node);
|
||||
} else {
|
||||
for (const EventNode* parent : node->GetParents()) {
|
||||
nodes.push(parent);
|
||||
seen.insert(parent);
|
||||
}
|
||||
}
|
||||
while (!nodes.empty()) {
|
||||
const EventNode* node = nodes.front();
|
||||
nodes.pop();
|
||||
if (comparator(node)) return node;
|
||||
for (const EventNode* parent : node->GetParents()) {
|
||||
if (seen.contains(parent)) continue;
|
||||
nodes.push(parent);
|
||||
seen.insert(parent);
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Returns true if none of its ancestors is a root event.
|
||||
bool IsTopRoot(const EventNode* event) {
|
||||
// If it is already grouped, it is not a top root.
|
||||
if (event->GetGroupId().has_value()) return false;
|
||||
for (EventNode* cur = event->GetParent(); cur != nullptr;
|
||||
cur = cur->GetParent()) {
|
||||
if (cur->IsRoot()) return false;
|
||||
}
|
||||
return true;
|
||||
const EventNode* root_parent = FindParentWithComparator(
|
||||
[](const EventNode* node) { return node->IsRoot(); }, event,
|
||||
/*include_self=*/false);
|
||||
return root_parent == nullptr;
|
||||
}
|
||||
|
||||
void SortEventList(EventList* event_list) {
|
||||
@ -347,10 +380,20 @@ EventNode::EventNode(const EventNode& event_node)
|
||||
event_node.raw_event_) {}
|
||||
|
||||
absl::optional<XStatVisitor> EventNode::GetContextStat(int64 stat_type) const {
|
||||
for (const EventNode* node = this; node != nullptr; node = node->parent_) {
|
||||
std::queue<const EventNode*> nodes;
|
||||
absl::flat_hash_set<const EventNode*> seen = {this};
|
||||
nodes.push(this);
|
||||
while (!nodes.empty()) {
|
||||
const EventNode* node = nodes.front();
|
||||
nodes.pop();
|
||||
if (absl::optional<XStatVisitor> stat = node->visitor_.GetStat(stat_type)) {
|
||||
return stat;
|
||||
}
|
||||
for (const EventNode* parent : node->GetParents()) {
|
||||
if (seen.contains(parent)) continue;
|
||||
nodes.push(parent);
|
||||
seen.insert(parent);
|
||||
}
|
||||
}
|
||||
return absl::nullopt;
|
||||
}
|
||||
@ -383,7 +426,7 @@ void EventNode::PropagateGroupId(int64 group_id) {
|
||||
// their nodes are added as child both in ConnectIntraThread and
|
||||
// ConnectInterThread).
|
||||
if (child->GetGroupId()) continue;
|
||||
child->PropagateGroupId(*group_id_);
|
||||
child->PropagateGroupId(group_id);
|
||||
}
|
||||
}
|
||||
|
||||
@ -406,15 +449,11 @@ bool EventNode::IsEager() {
|
||||
}
|
||||
|
||||
const EventNode* EventNode::FindParent(int64 event_type) const {
|
||||
absl::flat_hash_set<const EventNode*> seen;
|
||||
const EventNode* node = this;
|
||||
while (node) {
|
||||
if (seen.contains(node)) break;
|
||||
if (node->GetEventVisitor().Type() == event_type) return node;
|
||||
seen.insert(node);
|
||||
node = node->GetParent();
|
||||
}
|
||||
return nullptr;
|
||||
return FindParentWithComparator(
|
||||
[event_type](const EventNode* node) {
|
||||
return node->GetEventVisitor().Type() == event_type;
|
||||
},
|
||||
this, /*include_self=*/true);
|
||||
}
|
||||
|
||||
bool EventNode::StartsBefore(const EventNode& other) const {
|
||||
@ -449,7 +488,7 @@ void EventForest::ConnectIntraThread(const XPlaneVisitor& visitor,
|
||||
}
|
||||
parent_nodes.push_back(cur_node.get());
|
||||
// event_node_map_ keeps cur_node alive.
|
||||
event_node_map_[GetEventType(is_host_plane, visitor, event)].push_back(
|
||||
event_node_map_[GetEventType(is_host_plane, *cur_node)].push_back(
|
||||
std::move(cur_node));
|
||||
}
|
||||
}
|
||||
@ -513,10 +552,31 @@ void EventForest::ProcessLegacyRootEvents(
|
||||
}
|
||||
|
||||
void EventForest::CreateEventGroup() {
|
||||
// Handle inference batching profiles.
|
||||
if (event_node_map_.contains(HostEventType::kProcessBatch)) {
|
||||
for (const auto& process_batch_node :
|
||||
event_node_map_[HostEventType::kProcessBatch]) {
|
||||
ProcessRootEvent(next_group_id_++, /*set_step_name=*/false,
|
||||
process_batch_node.get(), &group_metadata_map_);
|
||||
}
|
||||
HostEventType request_event_type =
|
||||
event_node_map_.contains(HostEventType::kBatchingSession)
|
||||
? HostEventType::kBatchingSession
|
||||
: HostEventType::kSessionRun;
|
||||
if (auto request_events =
|
||||
gtl::FindOrNull(event_node_map_, request_event_type)) {
|
||||
for (const auto& request_event : *request_events) {
|
||||
ProcessRootEvent(next_group_id_++, /*set_step_name=*/false,
|
||||
request_event.get(), &group_metadata_map_);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
// Create a group for each TF loop iteration in non-JAX profiles.
|
||||
if (!HasJaxEvent(event_node_map_) && !tf_loop_root_events_.empty()) {
|
||||
for (EventNode* root_event : tf_loop_root_events_) {
|
||||
ProcessRootEvent(next_group_id_++, root_event, &group_metadata_map_);
|
||||
ProcessRootEvent(next_group_id_++, /*set_step_name=*/true, root_event,
|
||||
&group_metadata_map_);
|
||||
}
|
||||
return;
|
||||
}
|
||||
@ -527,7 +587,8 @@ void EventForest::CreateEventGroup() {
|
||||
if (IsTopRoot(root_event) &&
|
||||
(!HasJaxEvent(event_node_map_) ||
|
||||
!IsLegacyRootEvent(root_event->GetEventVisitor()))) {
|
||||
ProcessRootEvent(next_group_id_++, root_event, &group_metadata_map_);
|
||||
ProcessRootEvent(next_group_id_++, /*set_step_name=*/true, root_event,
|
||||
&group_metadata_map_);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -697,7 +758,10 @@ void EventForest::ProcessTfDataEvents() {
|
||||
absl::optional<XStatVisitor> element_id =
|
||||
consume_event->GetEventVisitor().GetStat(StatType::kElementId);
|
||||
if (!element_id.has_value()) continue;
|
||||
EventNode* consume_iterator = consume_event->GetParent();
|
||||
if (consume_event->GetParents().empty()) continue;
|
||||
// consume_event is nested by consumer_iterator and does not have other
|
||||
// parents.
|
||||
EventNode* consume_iterator = consume_event->GetParents().at(0);
|
||||
if (!consume_iterator ||
|
||||
!IsDatasetOp(
|
||||
ParseTfOpFullname(consume_iterator->GetEventVisitor().Name()))) {
|
||||
|
||||
@ -58,13 +58,13 @@ class EventNode {
|
||||
|
||||
EventNode(const EventNode& event_node);
|
||||
|
||||
EventNode* GetParent() const { return parent_; }
|
||||
const std::vector<EventNode*>& GetParents() const { return parents_; }
|
||||
|
||||
const std::vector<EventNode*>& GetChildren() const { return children_; }
|
||||
|
||||
void AddChild(EventNode* child) {
|
||||
children_.push_back(child);
|
||||
child->parent_ = this;
|
||||
child->parents_.push_back(this);
|
||||
}
|
||||
|
||||
absl::optional<int64> GetGroupId() const { return group_id_; }
|
||||
@ -113,7 +113,7 @@ class EventNode {
|
||||
XEventVisitor visitor_;
|
||||
XLine* raw_line_;
|
||||
XEvent* raw_event_;
|
||||
EventNode* parent_ = nullptr;
|
||||
std::vector<EventNode*> parents_;
|
||||
std::vector<EventNode*> children_;
|
||||
absl::optional<int64> group_id_;
|
||||
absl::optional<ContextInfo> producer_context_;
|
||||
@ -136,7 +136,7 @@ using GroupMetadataMap = absl::flat_hash_map<int64 /*group_id*/, GroupMetadata>;
|
||||
using EventList = std::vector<EventNode*>;
|
||||
|
||||
struct ContextGroup {
|
||||
EventNode* producer = nullptr;
|
||||
std::vector<EventNode*> producers;
|
||||
std::vector<EventNode*> consumers;
|
||||
};
|
||||
|
||||
|
||||
@ -19,6 +19,7 @@ limitations under the License.
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/profiler/lib/connected_traceme.h"
|
||||
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
|
||||
#include "tensorflow/core/profiler/utils/tf_xplane_visitor.h"
|
||||
#include "tensorflow/core/profiler/utils/xplane_builder.h"
|
||||
@ -563,6 +564,39 @@ TEST(GroupEventsTest, WorkerTest) {
|
||||
});
|
||||
}
|
||||
|
||||
TEST(GroupEventsTest, BatchingSessionTest) {
|
||||
constexpr absl::string_view kSchedule = "Schedule";
|
||||
constexpr int64 kBatchContextType =
|
||||
static_cast<int64>(ContextType::kSharedBatchScheduler);
|
||||
constexpr int64 kBatchContextId = 123;
|
||||
|
||||
XSpace raw_space;
|
||||
XPlane* raw_plane = raw_space.add_planes();
|
||||
XPlaneBuilder plane(raw_plane);
|
||||
plane.ReserveLines(1);
|
||||
auto request_thread = plane.GetOrCreateLine(0);
|
||||
// First request.
|
||||
CreateXEvent(&plane, &request_thread, HostEventType::kBatchingSession, 0,
|
||||
100);
|
||||
CreateXEvent(&plane, &request_thread, kSchedule, 0, 100,
|
||||
{{StatType::kProducerType, kBatchContextType},
|
||||
{StatType::kProducerId, kBatchContextId}});
|
||||
// Second request.
|
||||
CreateXEvent(&plane, &request_thread, HostEventType::kBatchingSession, 200,
|
||||
100);
|
||||
CreateXEvent(&plane, &request_thread, kSchedule, 200, 100,
|
||||
{{StatType::kProducerType, kBatchContextType},
|
||||
{StatType::kProducerId, kBatchContextId}});
|
||||
auto batch_thread = plane.GetOrCreateLine(0);
|
||||
CreateXEvent(&plane, &batch_thread, HostEventType::kProcessBatch, 200, 100,
|
||||
{{StatType::kConsumerType, kBatchContextType},
|
||||
{StatType::kConsumerId, kBatchContextId}});
|
||||
|
||||
GroupMetadataMap group_metadata_map;
|
||||
GroupTfEvents(&raw_space, &group_metadata_map);
|
||||
EXPECT_EQ(group_metadata_map.size(), 3);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace profiler
|
||||
} // namespace tensorflow
|
||||
|
||||
@ -108,6 +108,9 @@ const HostEventTypeMap& GetHostEventTypeMap() {
|
||||
{"MapAndBatchConsume", kMapAndBatchConsume},
|
||||
{"ParseExampleProduce", kParseExampleProduce},
|
||||
{"ParseExampleConsume", kParseExampleConsume},
|
||||
// Batching related.
|
||||
{"BatchingSession", kBatchingSession},
|
||||
{"ProcessBatch", kProcessBatch},
|
||||
// JAX related.
|
||||
{"LocalExecutable::ExecuteOnLocalDevices", kExecuteOnLocalDevices},
|
||||
// GPU related.
|
||||
|
||||
@ -99,6 +99,9 @@ enum HostEventType {
|
||||
kMapAndBatchConsume,
|
||||
kParseExampleProduce,
|
||||
kParseExampleConsume,
|
||||
// Batching related.
|
||||
kBatchingSession,
|
||||
kProcessBatch,
|
||||
// JAX related.
|
||||
kExecuteOnLocalDevices,
|
||||
// GPU related.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user