Propagate the eagerness of a GPU kernel from host to device.
PiperOrigin-RevId: 306333703 Change-Id: I95cc15e3d6bc2fde45bc9f07e23ba37be1dbc91e
This commit is contained in:
parent
04a426a832
commit
e1d83cf241
@ -50,6 +50,7 @@ void CreateStatMetadata(XPlane* plane) {
|
|||||||
XPlaneBuilder builder(plane);
|
XPlaneBuilder builder(plane);
|
||||||
builder.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kGroupId));
|
builder.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kGroupId));
|
||||||
builder.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kStepName));
|
builder.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kStepName));
|
||||||
|
builder.GetOrCreateStatMetadata(GetStatTypeStr(StatType::kIsEager));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns event type if it is a KernelLaunch or KernelExecute event.
|
// Returns event type if it is a KernelLaunch or KernelExecute event.
|
||||||
@ -177,10 +178,26 @@ void EventNode::AddStepName(absl::string_view step_name) {
|
|||||||
step_name, event_);
|
step_name, event_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Records the eagerness of this node's event as an integer stat.
//
// The stat is keyed by the metadata id that the plane visitor reports for
// StatType::kIsEager and holds 1 when `is_eager` is true, 0 otherwise. The
// stat is added to or updated on `event_` through AddOrUpdateIntStat().
//
// NOTE(review): the metadata id lookup result is dereferenced without a
// presence check — assumes the kIsEager stat metadata has been registered on
// the plane beforehand; confirm against the callers.
void EventNode::SetIsEager(bool is_eager) {
  const auto is_eager_stat_id =
      *visitor_->GetStatMetadataId(StatType::kIsEager);
  const int eager_flag = is_eager ? 1 : 0;
  AddOrUpdateIntStat(is_eager_stat_id, eager_flag, event_);
}
|
||||||
|
|
||||||
// Reports whether this node's event is nested in the event of `parent`.
//
// A null `parent` always yields false; otherwise the answer is whatever
// IsNested() returns for this node's event and the parent's event.
bool EventNode::IsNestedIn(EventNode* parent) {
  if (parent == nullptr) return false;
  return IsNested(GetEvent(), parent->GetEvent());
}
|
||||||
|
|
||||||
|
// Returns the closest ancestor of this node whose event type equals
// `event_type`, or nullptr when no ancestor in the parent chain matches.
//
// The chain is walked upward starting from the direct parent; the node itself
// is never considered a match.
EventNode* EventNode::FindParent(int64 event_type) {
  for (EventNode* ancestor = parent_; ancestor != nullptr;
       ancestor = ancestor->parent_) {
    // Each ancestor's type is resolved with that ancestor's own plane visitor.
    const auto ancestor_type =
        GetEventType(ancestor->GetPlaneVisitor(), ancestor->GetEvent());
    if (ancestor_type == event_type) return ancestor;
  }
  return nullptr;
}
|
||||||
|
|
||||||
void EventForest::ConnectIntraThread(const XPlaneVisitor& visitor,
|
void EventForest::ConnectIntraThread(const XPlaneVisitor& visitor,
|
||||||
XPlane* plane) {
|
XPlane* plane) {
|
||||||
for (auto& line : *plane->mutable_lines()) {
|
for (auto& line : *plane->mutable_lines()) {
|
||||||
@ -273,6 +290,19 @@ void EventForest::CreateEventGroup(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void EventForest::MarkEagerlyExecutedKernels() {
|
||||||
|
auto kernel_execute_event_node_list =
|
||||||
|
gtl::FindOrNull(event_node_map_, HostEventType::kKernelExecute);
|
||||||
|
if (!kernel_execute_event_node_list) return;
|
||||||
|
for (auto& kernel_execute_event_node : *kernel_execute_event_node_list) {
|
||||||
|
// A kernel is eagerly executed if its trace context does not include the
|
||||||
|
// TF executor.
|
||||||
|
bool is_eager = kernel_execute_event_node->FindParent(
|
||||||
|
HostEventType::kExecutorStateProcess) == nullptr;
|
||||||
|
kernel_execute_event_node->SetIsEager(is_eager);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void EventForest::CreateVirtualEventsForHostTrainingLoop() {
|
void EventForest::CreateVirtualEventsForHostTrainingLoop() {
|
||||||
VirtualEventNodeMap virtual_event_node_map;
|
VirtualEventNodeMap virtual_event_node_map;
|
||||||
auto executor_event_node_list =
|
auto executor_event_node_list =
|
||||||
@ -351,6 +381,7 @@ EventForest::EventForest(
|
|||||||
CreateVirtualEventsForAsyncExecutor();
|
CreateVirtualEventsForAsyncExecutor();
|
||||||
}
|
}
|
||||||
CreateEventGroup(root_event_types);
|
CreateEventGroup(root_event_types);
|
||||||
|
MarkEagerlyExecutedKernels();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<InterThreadConnectInfo> CreateInterThreadConnectInfoList() {
|
std::vector<InterThreadConnectInfo> CreateInterThreadConnectInfoList() {
|
||||||
|
@ -71,8 +71,13 @@ class EventNode {
|
|||||||
|
|
||||||
void AddStepName(absl::string_view step_name);
|
void AddStepName(absl::string_view step_name);
|
||||||
|
|
||||||
|
void SetIsEager(bool is_eager);
|
||||||
|
|
||||||
bool IsNestedIn(EventNode* parent);
|
bool IsNestedIn(EventNode* parent);
|
||||||
|
|
||||||
|
// Returns the closest parent of the given event type.
|
||||||
|
EventNode* FindParent(int64 event_type);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const XPlaneVisitor* visitor_;
|
const XPlaneVisitor* visitor_;
|
||||||
XEvent* event_;
|
XEvent* event_;
|
||||||
@ -120,6 +125,9 @@ class EventForest {
|
|||||||
void CreateEventGroup(
|
void CreateEventGroup(
|
||||||
const std::vector<int64 /*EventType*/>& root_event_types);
|
const std::vector<int64 /*EventType*/>& root_event_types);
|
||||||
|
|
||||||
|
// Sets the is_eager stat to true for the eagerly executed kernel events.
|
||||||
|
void MarkEagerlyExecutedKernels();
|
||||||
|
|
||||||
// Create virtual events of HostEventType::kHostTrainingLoopIteration and
|
// Create virtual events of HostEventType::kHostTrainingLoopIteration and
|
||||||
// event nodes for them. A virtual event is created for each iteration of the
|
// event nodes for them. A virtual event is created for each iteration of the
|
||||||
// host training loop and connected to the
|
// host training loop and connected to the
|
||||||
|
@ -57,7 +57,7 @@ TEST(GroupEventsTest, GroupGpuTraceTest) {
|
|||||||
EventGroupNameMap event_group_name_map;
|
EventGroupNameMap event_group_name_map;
|
||||||
GroupTfEvents(&space, &event_group_name_map);
|
GroupTfEvents(&space, &event_group_name_map);
|
||||||
XPlaneVisitor device_plane_visitor = CreateTfXPlaneVisitor(device_plane);
|
XPlaneVisitor device_plane_visitor = CreateTfXPlaneVisitor(device_plane);
|
||||||
EXPECT_EQ(device_plane->lines(0).events(0).stats_size(), 2);
|
EXPECT_EQ(device_plane->lines(0).events(0).stats_size(), 3);
|
||||||
EXPECT_EQ(device_plane_visitor.GetStatType(
|
EXPECT_EQ(device_plane_visitor.GetStatType(
|
||||||
device_plane->lines(0).events(0).stats(1)),
|
device_plane->lines(0).events(0).stats(1)),
|
||||||
StatType::kGroupId);
|
StatType::kGroupId);
|
||||||
@ -89,7 +89,7 @@ TEST(GroupEventsTest, GroupHostTrainingLoopTest) {
|
|||||||
EventGroupNameMap event_group_name_map;
|
EventGroupNameMap event_group_name_map;
|
||||||
GroupTfEvents(&space, &event_group_name_map);
|
GroupTfEvents(&space, &event_group_name_map);
|
||||||
XPlaneVisitor device_plane_visitor = CreateTfXPlaneVisitor(device_plane);
|
XPlaneVisitor device_plane_visitor = CreateTfXPlaneVisitor(device_plane);
|
||||||
EXPECT_EQ(device_plane->lines(0).events(0).stats_size(), 2);
|
EXPECT_EQ(device_plane->lines(0).events(0).stats_size(), 3);
|
||||||
EXPECT_EQ(device_plane_visitor.GetStatType(
|
EXPECT_EQ(device_plane_visitor.GetStatType(
|
||||||
device_plane->lines(0).events(0).stats(1)),
|
device_plane->lines(0).events(0).stats(1)),
|
||||||
StatType::kGroupId);
|
StatType::kGroupId);
|
||||||
@ -143,6 +143,75 @@ TEST(GroupEventsTest, GroupFunctionalOp) {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Checks that a device kernel whose host-side launch is not under a TF
// executor event ends up with is_eager == 1 after grouping.
TEST(GroupEventsTest, EagerOpTest) {
  XSpace space;

  // Host plane: one line with the kernel launch followed by a FunctionRun.
  XPlaneBuilder host_builder(space.add_planes());
  host_builder.SetName(kHostThreads);
  host_builder.ReserveLines(1);

  auto host_line = host_builder.GetOrCreateLine(0);
  // Eagerly scheduled GPU kernel.
  CreateXEvent(&host_builder, &host_line, "matmul", 10, 100,
               {{StatType::kCorrelationId, 100}});
  CreateXEvent(&host_builder, &host_line, HostEventType::kFunctionRun, 110,
               200, {{StatType::kStepId, 0}});

  // Device plane: one line with the kernel that shares the correlation id.
  XPlane* device_plane = space.add_planes();
  XPlaneBuilder device_builder(device_plane);
  device_builder.ReserveLines(1);

  auto gpu_stream = device_builder.GetOrCreateLine(0);
  // Eagerly executed GPU kernel.
  CreateXEvent(&device_builder, &gpu_stream, "matmul", 200, 300,
               {{StatType::kCorrelationId, 100}});

  GroupTfEvents(&space, /*event_group_name_map=*/nullptr);

  // The device kernel is expected to carry two stats, with is_eager (set to
  // 1) at index 1.
  XPlaneVisitor device_visitor = CreateTfXPlaneVisitor(device_plane);
  const auto& kernel_event = device_plane->lines(0).events(0);
  EXPECT_EQ(kernel_event.stats_size(), 2);
  const auto& is_eager_stat = kernel_event.stats(1);
  EXPECT_EQ(device_visitor.GetStatType(is_eager_stat), StatType::kIsEager);
  EXPECT_EQ(is_eager_stat.int64_value(), 1);
}
|
||||||
|
|
||||||
|
// Checks that a device kernel whose host-side launch sits on the same line as
// an ExecutorStateProcess event ends up with is_eager == 0 after grouping.
TEST(GroupEventsTest, FunctionOpTest) {
  XSpace space;

  // Host plane with two lines: the main thread and a TF executor thread.
  XPlaneBuilder host_builder(space.add_planes());
  host_builder.SetName(kHostThreads);
  host_builder.ReserveLines(2);

  auto main_line = host_builder.GetOrCreateLine(0);
  CreateXEvent(&host_builder, &main_line, HostEventType::kTraceContext, 0, 100,
               {{StatType::kStepNum, 123}});
  CreateXEvent(&host_builder, &main_line, HostEventType::kFunctionRun, 10, 90,
               {{StatType::kStepId, 0}});

  auto executor_line = host_builder.GetOrCreateLine(1);
  CreateXEvent(&host_builder, &executor_line,
               HostEventType::kExecutorStateProcess, 20, 80,
               {{StatType::kStepId, 0}});
  // GPU kernel scheduled inside tf.function.
  CreateXEvent(&host_builder, &executor_line, "matmul", 30, 70,
               {{StatType::kCorrelationId, 100}});

  // Device plane: one line with the kernel that shares the correlation id.
  XPlane* device_plane = space.add_planes();
  XPlaneBuilder device_builder(device_plane);
  device_builder.ReserveLines(1);

  auto gpu_stream = device_builder.GetOrCreateLine(0);
  // GPU kernel executed as part of tf.function.
  CreateXEvent(&device_builder, &gpu_stream, "matmul", 200, 300,
               {{StatType::kCorrelationId, 100}});

  GroupTfEvents(&space, /*event_group_name_map=*/nullptr);

  // The device kernel is expected to carry three stats, with is_eager (set to
  // 0) at index 2.
  XPlaneVisitor device_visitor = CreateTfXPlaneVisitor(device_plane);
  const auto& kernel_event = device_plane->lines(0).events(0);
  EXPECT_EQ(kernel_event.stats_size(), 3);
  const auto& is_eager_stat = kernel_event.stats(2);
  EXPECT_EQ(device_visitor.GetStatType(is_eager_stat), StatType::kIsEager);
  EXPECT_EQ(is_eager_stat.int64_value(), 0);
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace profiler
|
} // namespace profiler
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -146,6 +146,7 @@ const StatTypeMap& GetStatTypeMap() {
|
|||||||
{"hlo_op", kHloOp},
|
{"hlo_op", kHloOp},
|
||||||
{"hlo_module", kHloModule},
|
{"hlo_module", kHloModule},
|
||||||
{"equation", kEquation},
|
{"equation", kEquation},
|
||||||
|
{"is_eager", kIsEager},
|
||||||
// Performance counter related.
|
// Performance counter related.
|
||||||
{"Raw Value", kRawValue},
|
{"Raw Value", kRawValue},
|
||||||
{"Scaled Value", kScaledValue},
|
{"Scaled Value", kScaledValue},
|
||||||
|
@ -139,6 +139,7 @@ enum StatType {
|
|||||||
kHloOp,
|
kHloOp,
|
||||||
kHloModule,
|
kHloModule,
|
||||||
kEquation,
|
kEquation,
|
||||||
|
kIsEager,
|
||||||
// Performance counter related.
|
// Performance counter related.
|
||||||
kRawValue,
|
kRawValue,
|
||||||
kScaledValue,
|
kScaledValue,
|
||||||
|
Loading…
Reference in New Issue
Block a user