Propagate the eagerness of a GPU kernel from host to device.
PiperOrigin-RevId: 306333703
Change-Id: I95cc15e3d6bc2fde45bc9f07e23ba37be1dbc91e
This commit is contained in:
parent
04a426a832
commit
e1d83cf241
@ -50,6 +50,7 @@ void CreateStatMetadata(XPlane* plane) {
|
||||
XPlaneBuilder builder(plane);
|
||||
builder.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kGroupId));
|
||||
builder.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kStepName));
|
||||
builder.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kIsEager));
|
||||
}
|
||||
|
||||
// Returns event type if it is a KernelLaunch or KernelExecute event.
|
||||
@ -177,10 +178,26 @@ void EventNode::AddStepName(absl::string_view step_name) {
|
||||
step_name, event_);
|
||||
}
|
||||
|
||||
// Writes the kIsEager integer stat on this node's event: 1 when `is_eager` is
// true, 0 otherwise.
void EventNode::SetIsEager(bool is_eager) {
  // The metadata id lookup result is dereferenced without a check; presumably
  // kIsEager metadata was registered on the plane beforehand -- confirm.
  const auto is_eager_stat_id =
      *visitor_->GetStatMetadataId(StatType::kIsEager);
  const int stat_value = is_eager ? 1 : 0;
  AddOrUpdateIntStat(is_eager_stat_id, stat_value, event_);
}
|
||||
|
||||
// Returns true when `parent` is non-null and IsNested() reports that this
// node's event is nested in `parent`'s event. A null `parent` yields false.
bool EventNode::IsNestedIn(EventNode* parent) {
  if (parent == nullptr) {
    return false;
  }
  return IsNested(GetEvent(), parent->GetEvent());
}
|
||||
|
||||
// Walks up the parent chain starting at this node's direct parent and returns
// the first ancestor whose event type equals `event_type`. Returns nullptr
// when the chain is exhausted without a match (including when this node has
// no parent).
EventNode* EventNode::FindParent(int64 event_type) {
  for (EventNode* ancestor = parent_; ancestor != nullptr;
       ancestor = ancestor->parent_) {
    const bool type_matches =
        GetEventType(ancestor->GetPlaneVisitor(), ancestor->GetEvent()) ==
        event_type;
    if (type_matches) {
      return ancestor;
    }
  }
  return nullptr;
}
|
||||
|
||||
void EventForest::ConnectIntraThread(const XPlaneVisitor& visitor,
|
||||
XPlane* plane) {
|
||||
for (auto& line : *plane->mutable_lines()) {
|
||||
@ -273,6 +290,19 @@ void EventForest::CreateEventGroup(
|
||||
}
|
||||
}
|
||||
|
||||
void EventForest::MarkEagerlyExecutedKernels() {
|
||||
auto kernel_execute_event_node_list =
|
||||
gtl::FindOrNull(event_node_map_, HostEventType::kKernelExecute);
|
||||
if (!kernel_execute_event_node_list) return;
|
||||
for (auto& kernel_execute_event_node : *kernel_execute_event_node_list) {
|
||||
// A kernel is eagerly executed if its trace context does not include the
|
||||
// TF executor.
|
||||
bool is_eager = kernel_execute_event_node->FindParent(
|
||||
HostEventType::kExecutorStateProcess) == nullptr;
|
||||
kernel_execute_event_node->SetIsEager(is_eager);
|
||||
}
|
||||
}
|
||||
|
||||
void EventForest::CreateVirtualEventsForHostTrainingLoop() {
|
||||
VirtualEventNodeMap virtual_event_node_map;
|
||||
auto executor_event_node_list =
|
||||
@ -351,6 +381,7 @@ EventForest::EventForest(
|
||||
CreateVirtualEventsForAsyncExecutor();
|
||||
}
|
||||
CreateEventGroup(root_event_types);
|
||||
MarkEagerlyExecutedKernels();
|
||||
}
|
||||
|
||||
std::vector<InterThreadConnectInfo> CreateInterThreadConnectInfoList() {
|
||||
|
@ -71,8 +71,13 @@ class EventNode {
|
||||
|
||||
void AddStepName(absl::string_view step_name);
|
||||
|
||||
void SetIsEager(bool is_eager);
|
||||
|
||||
bool IsNestedIn(EventNode* parent);
|
||||
|
||||
// Returns the closest ancestor of the given event type, or nullptr if none.
|
||||
EventNode* FindParent(int64 event_type);
|
||||
|
||||
private:
|
||||
const XPlaneVisitor* visitor_;
|
||||
XEvent* event_;
|
||||
@ -120,6 +125,9 @@ class EventForest {
|
||||
void CreateEventGroup(
|
||||
const std::vector<int64 /*EventType*/>& root_event_types);
|
||||
|
||||
// Sets the is_eager stat on every kernel execute event: 1 if the kernel was
// eagerly executed (no TF executor among its ancestors), 0 otherwise.
|
||||
void MarkEagerlyExecutedKernels();
|
||||
|
||||
// Create virtual events of HostEventType::kHostTrainingLoopIteration and
|
||||
// event nodes for them. A virtual event is created for each iteration of the
|
||||
// host training loop and connected to the
|
||||
|
@ -57,7 +57,7 @@ TEST(GroupEventsTest, GroupGpuTraceTest) {
|
||||
EventGroupNameMap event_group_name_map;
|
||||
GroupTfEvents(&space, &event_group_name_map);
|
||||
XPlaneVisitor device_plane_visitor = CreateTfXPlaneVisitor(device_plane);
|
||||
EXPECT_EQ(device_plane->lines(0).events(0).stats_size(), 2);
|
||||
EXPECT_EQ(device_plane->lines(0).events(0).stats_size(), 3);
|
||||
EXPECT_EQ(device_plane_visitor.GetStatType(
|
||||
device_plane->lines(0).events(0).stats(1)),
|
||||
StatType::kGroupId);
|
||||
@ -89,7 +89,7 @@ TEST(GroupEventsTest, GroupHostTrainingLoopTest) {
|
||||
EventGroupNameMap event_group_name_map;
|
||||
GroupTfEvents(&space, &event_group_name_map);
|
||||
XPlaneVisitor device_plane_visitor = CreateTfXPlaneVisitor(device_plane);
|
||||
EXPECT_EQ(device_plane->lines(0).events(0).stats_size(), 2);
|
||||
EXPECT_EQ(device_plane->lines(0).events(0).stats_size(), 3);
|
||||
EXPECT_EQ(device_plane_visitor.GetStatType(
|
||||
device_plane->lines(0).events(0).stats(1)),
|
||||
StatType::kGroupId);
|
||||
@ -143,6 +143,75 @@ TEST(GroupEventsTest, GroupFunctionalOp) {
|
||||
});
|
||||
}
|
||||
|
||||
// Verifies that a device kernel whose host-side event has no TF executor
// event enclosing it ends up with the kIsEager stat set to 1.
TEST(GroupEventsTest, EagerOpTest) {
  XSpace space;
  XPlaneBuilder host_plane_builder(space.add_planes());
  host_plane_builder.SetName(kHostThreads);
  host_plane_builder.ReserveLines(1);

  auto main_thread = host_plane_builder.GetOrCreateLine(0);
  // Eagerly scheduled GPU kernel.
  CreateXEvent(&host_plane_builder, &main_thread, "matmul", 10, 100,
               {{StatType::kCorrelationId, 100}});
  // This FunctionRun event starts after the "matmul" event has ended, so it
  // does not enclose it.
  CreateXEvent(&host_plane_builder, &main_thread, HostEventType::kFunctionRun,
               110, 200, {{StatType::kStepId, 0}});

  XPlane* device_plane = space.add_planes();
  XPlaneBuilder device_plane_builder(device_plane);
  device_plane_builder.ReserveLines(1);

  auto stream = device_plane_builder.GetOrCreateLine(0);
  // Eagerly executed GPU kernel; shares correlation id 100 with the host
  // event above.
  CreateXEvent(&device_plane_builder, &stream, "matmul", 200, 300,
               {{StatType::kCorrelationId, 100}});

  GroupTfEvents(&space, /*event_group_name_map=*/nullptr);
  XPlaneVisitor device_plane_visitor = CreateTfXPlaneVisitor(device_plane);
  const auto& kernel_event = device_plane->lines(0).events(0);
  // Fatal assertion: the stats(1) accesses below would be out of range if the
  // stat count were smaller than expected.
  ASSERT_EQ(kernel_event.stats_size(), 2);
  EXPECT_EQ(device_plane_visitor.GetStatType(kernel_event.stats(1)),
            StatType::kIsEager);
  EXPECT_EQ(kernel_event.stats(1).int64_value(), 1);
}
|
||||
|
||||
// Verifies that a device kernel whose host-side event is enclosed by a
// kExecutorStateProcess event ends up with the kIsEager stat set to 0.
TEST(GroupEventsTest, FunctionOpTest) {
  XSpace space;
  XPlaneBuilder host_plane_builder(space.add_planes());
  host_plane_builder.SetName(kHostThreads);
  host_plane_builder.ReserveLines(2);

  auto main_thread = host_plane_builder.GetOrCreateLine(0);
  CreateXEvent(&host_plane_builder, &main_thread, HostEventType::kTraceContext,
               0, 100, {{StatType::kStepNum, 123}});
  CreateXEvent(&host_plane_builder, &main_thread, HostEventType::kFunctionRun,
               10, 90, {{StatType::kStepId, 0}});

  auto tf_executor_thread = host_plane_builder.GetOrCreateLine(1);
  CreateXEvent(&host_plane_builder, &tf_executor_thread,
               HostEventType::kExecutorStateProcess, 20, 80,
               {{StatType::kStepId, 0}});
  // GPU kernel scheduled inside tf.function: its [30, 70] span lies within
  // the executor event's [20, 80] span on the same line.
  CreateXEvent(&host_plane_builder, &tf_executor_thread, "matmul", 30, 70,
               {{StatType::kCorrelationId, 100}});

  XPlane* device_plane = space.add_planes();
  XPlaneBuilder device_plane_builder(device_plane);
  device_plane_builder.ReserveLines(1);

  auto stream = device_plane_builder.GetOrCreateLine(0);
  // GPU kernel executed as part of tf.function; shares correlation id 100
  // with the host event above.
  CreateXEvent(&device_plane_builder, &stream, "matmul", 200, 300,
               {{StatType::kCorrelationId, 100}});

  GroupTfEvents(&space, /*event_group_name_map=*/nullptr);
  XPlaneVisitor device_plane_visitor = CreateTfXPlaneVisitor(device_plane);
  const auto& kernel_event = device_plane->lines(0).events(0);
  // Fatal assertion: the stats(2) accesses below would be out of range if the
  // stat count were smaller than expected.
  ASSERT_EQ(kernel_event.stats_size(), 3);
  EXPECT_EQ(device_plane_visitor.GetStatType(kernel_event.stats(2)),
            StatType::kIsEager);
  EXPECT_EQ(kernel_event.stats(2).int64_value(), 0);
}
|
||||
|
||||
} // namespace
|
||||
} // namespace profiler
|
||||
} // namespace tensorflow
|
||||
|
@ -146,6 +146,7 @@ const StatTypeMap& GetStatTypeMap() {
|
||||
{"hlo_op", kHloOp},
|
||||
{"hlo_module", kHloModule},
|
||||
{"equation", kEquation},
|
||||
{"is_eager", kIsEager},
|
||||
// Performance counter related.
|
||||
{"Raw Value", kRawValue},
|
||||
{"Scaled Value", kScaledValue},
|
||||
|
@ -139,6 +139,7 @@ enum StatType {
|
||||
kHloOp,
|
||||
kHloModule,
|
||||
kEquation,
|
||||
kIsEager,
|
||||
// Performance counter related.
|
||||
kRawValue,
|
||||
kScaledValue,
|
||||
|
Loading…
Reference in New Issue
Block a user