[xprof:gpu] Optimize XPlane to KernelStatsDb converter by switching intermediate data structure from a vector to flat_hash_map.
xplane_to_kernel_stats_db.h - Do not provide a direct conversion from XPlane to KernelStatsDb, which is a many-to-one conversion that can be parallelized while being aggregated into a faster data structure (a hash map in this case). PiperOrigin-RevId: 322860348 Change-Id: Ibde4fc7ae4c5222d059f0eb1f77f57aa2878a58a
This commit is contained in:
parent
a6e66d50a4
commit
2587c2a1f2
@ -421,6 +421,7 @@ cc_library(
|
||||
"//tensorflow/core/profiler/utils:trace_utils",
|
||||
"//tensorflow/core/profiler/utils:xplane_schema",
|
||||
"//tensorflow/core/profiler/utils:xplane_visitor",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
@ -437,15 +438,12 @@ tf_cc_test(
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"//tensorflow/core/profiler/protobuf:kernel_stats_proto_cc",
|
||||
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
|
||||
"//tensorflow/core/profiler/utils:tf_xplane_visitor",
|
||||
"//tensorflow/core/profiler/utils:kernel_stats_utils",
|
||||
"//tensorflow/core/profiler/utils:xplane_builder",
|
||||
"//tensorflow/core/profiler/utils:xplane_schema",
|
||||
"//tensorflow/core/profiler/utils:xplane_test_utils",
|
||||
"//tensorflow/core/profiler/utils:xplane_utils",
|
||||
"//tensorflow/core/profiler/utils:xplane_visitor",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
|
||||
#include <functional>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
@ -33,11 +34,11 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace profiler {
|
||||
|
||||
KernelStatsDb ConvertDeviceTraceXPlaneToKernelStatsDb(
|
||||
void ConvertDeviceTraceXPlaneToKernelReports(
|
||||
const XPlane& device_trace,
|
||||
const std::function<void(const XEventVisitor&, KernelReport*)>&
|
||||
on_kernel_fn) {
|
||||
KernelStatsDb result;
|
||||
on_kernel_fn,
|
||||
KernelReportMap* reports) {
|
||||
XPlaneVisitor plane = CreateTfXPlaneVisitor(&device_trace);
|
||||
plane.ForEachLine([&](const XLineVisitor& line) {
|
||||
if (IsDerivedThreadId(line.Id())) {
|
||||
@ -92,12 +93,15 @@ KernelStatsDb ConvertDeviceTraceXPlaneToKernelStatsDb(
|
||||
}
|
||||
|
||||
if (kernel.total_duration_ns()) {
|
||||
*result.add_reports() = kernel;
|
||||
KernelReportValue value;
|
||||
value.total_duration_ns = event.DurationNs();
|
||||
value.min_duration_ns = event.DurationNs();
|
||||
value.max_duration_ns = event.DurationNs();
|
||||
value.occurrences = 1;
|
||||
InsertOrUpdateKernelReport(kernel, value, reports);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace profiler
|
||||
|
@ -18,17 +18,20 @@ limitations under the License.
|
||||
|
||||
#include <functional>
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "tensorflow/core/profiler/protobuf/kernel_stats.pb.h"
|
||||
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
|
||||
#include "tensorflow/core/profiler/utils/kernel_stats_utils.h"
|
||||
#include "tensorflow/core/profiler/utils/xplane_visitor.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace profiler {
|
||||
|
||||
KernelStatsDb ConvertDeviceTraceXPlaneToKernelStatsDb(
|
||||
void ConvertDeviceTraceXPlaneToKernelReports(
|
||||
const XPlane& device_trace,
|
||||
const std::function<void(const XEventVisitor&, KernelReport*)>&
|
||||
on_kernel_fn);
|
||||
on_kernel_fn,
|
||||
KernelReportMap* reports);
|
||||
|
||||
} // namespace profiler
|
||||
} // namespace tensorflow
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/profiler/protobuf/kernel_stats.pb.h"
|
||||
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
|
||||
#include "tensorflow/core/profiler/utils/kernel_stats_utils.h"
|
||||
#include "tensorflow/core/profiler/utils/xplane_builder.h"
|
||||
#include "tensorflow/core/profiler/utils/xplane_schema.h"
|
||||
#include "tensorflow/core/profiler/utils/xplane_test_utils.h"
|
||||
@ -37,7 +38,7 @@ TEST(ConvertXplaneToKernelStats, MultiKernels) {
|
||||
device_trace_builder.GetOrCreateLine(0);
|
||||
|
||||
XLineBuilder line_builder = device_trace_builder.GetOrCreateLine(0);
|
||||
CreateXEvent(&device_trace_builder, &line_builder, "kernel_name_0",
|
||||
CreateXEvent(&device_trace_builder, &line_builder, "kernel_name_shortest",
|
||||
/*offset_ps=*/10000, /*duration_ps=*/1000,
|
||||
{{StatType::kLevel0, "mul_786"},
|
||||
{StatType::kKernelDetails, R"MULTI(registers_per_thread:16
|
||||
@ -51,7 +52,7 @@ block_y:1
|
||||
block_z:1)MULTI"},
|
||||
{StatType::kEquation, ""}});
|
||||
|
||||
CreateXEvent(&device_trace_builder, &line_builder, "kernel_name_1",
|
||||
CreateXEvent(&device_trace_builder, &line_builder, "kernel_name_middle",
|
||||
/*offset_ps=*/20000, /*duration_ps=*/2000,
|
||||
{{StatType::kLevel0, "Conv2D"},
|
||||
{StatType::kKernelDetails, R"MULTI(registers_per_thread:32
|
||||
@ -79,58 +80,68 @@ block_x:64
|
||||
block_y:1
|
||||
block_z:1)MULTI"},
|
||||
{StatType::kEquation, ""}});
|
||||
KernelStatsDb kernel_stats =
|
||||
ConvertDeviceTraceXPlaneToKernelStatsDb(*device_trace, {});
|
||||
|
||||
KernelReportMap reports;
|
||||
ConvertDeviceTraceXPlaneToKernelReports(*device_trace, {}, &reports);
|
||||
KernelStatsDb kernel_stats;
|
||||
CopyKernelReportsToDb(reports, &kernel_stats);
|
||||
SortKernelsByTotalDurationDesc(&kernel_stats);
|
||||
|
||||
EXPECT_EQ(kernel_stats.reports_size(), 3);
|
||||
|
||||
const auto& kernel0 = kernel_stats.reports().at(0);
|
||||
EXPECT_EQ(kernel0.name(), "kernel_name_0");
|
||||
EXPECT_EQ(kernel0.registers_per_thread(), 16);
|
||||
EXPECT_EQ(kernel0.static_shmem_bytes(), 0);
|
||||
EXPECT_EQ(kernel0.dynamic_shmem_bytes(), 0);
|
||||
EXPECT_EQ(kernel0.grid_dim().at(0), 1);
|
||||
EXPECT_EQ(kernel0.grid_dim().at(1), 1);
|
||||
EXPECT_EQ(kernel0.grid_dim().at(2), 1);
|
||||
EXPECT_EQ(kernel0.block_dim().at(0), 1);
|
||||
EXPECT_EQ(kernel0.block_dim().at(1), 1);
|
||||
EXPECT_EQ(kernel0.block_dim().at(2), 1);
|
||||
EXPECT_EQ(kernel0.total_duration_ns(), 1);
|
||||
EXPECT_FALSE(kernel0.is_kernel_using_tensor_core());
|
||||
EXPECT_FALSE(kernel0.is_op_tensor_core_eligible());
|
||||
EXPECT_EQ(kernel0.op_name(), "mul_786");
|
||||
{
|
||||
const auto& kernel = kernel_stats.reports().at(2);
|
||||
EXPECT_EQ(kernel.name(), "kernel_name_shortest");
|
||||
EXPECT_EQ(kernel.registers_per_thread(), 16);
|
||||
EXPECT_EQ(kernel.static_shmem_bytes(), 0);
|
||||
EXPECT_EQ(kernel.dynamic_shmem_bytes(), 0);
|
||||
EXPECT_EQ(kernel.grid_dim().at(0), 1);
|
||||
EXPECT_EQ(kernel.grid_dim().at(1), 1);
|
||||
EXPECT_EQ(kernel.grid_dim().at(2), 1);
|
||||
EXPECT_EQ(kernel.block_dim().at(0), 1);
|
||||
EXPECT_EQ(kernel.block_dim().at(1), 1);
|
||||
EXPECT_EQ(kernel.block_dim().at(2), 1);
|
||||
EXPECT_EQ(kernel.total_duration_ns(), 1);
|
||||
EXPECT_FALSE(kernel.is_kernel_using_tensor_core());
|
||||
EXPECT_FALSE(kernel.is_op_tensor_core_eligible());
|
||||
EXPECT_EQ(kernel.op_name(), "mul_786");
|
||||
}
|
||||
|
||||
const auto& kernel1 = kernel_stats.reports().at(1);
|
||||
EXPECT_EQ(kernel1.name(), "kernel_name_1");
|
||||
EXPECT_EQ(kernel1.registers_per_thread(), 32);
|
||||
EXPECT_EQ(kernel1.static_shmem_bytes(), 0);
|
||||
EXPECT_EQ(kernel1.dynamic_shmem_bytes(), 16384);
|
||||
EXPECT_EQ(kernel1.grid_dim().at(0), 2);
|
||||
EXPECT_EQ(kernel1.grid_dim().at(1), 1);
|
||||
EXPECT_EQ(kernel1.grid_dim().at(2), 1);
|
||||
EXPECT_EQ(kernel1.block_dim().at(0), 32);
|
||||
EXPECT_EQ(kernel1.block_dim().at(1), 1);
|
||||
EXPECT_EQ(kernel1.block_dim().at(2), 1);
|
||||
EXPECT_EQ(kernel1.total_duration_ns(), 2);
|
||||
EXPECT_FALSE(kernel1.is_kernel_using_tensor_core());
|
||||
EXPECT_TRUE(kernel1.is_op_tensor_core_eligible());
|
||||
EXPECT_EQ(kernel1.op_name(), "Conv2D");
|
||||
{
|
||||
const auto& kernel = kernel_stats.reports().at(1);
|
||||
EXPECT_EQ(kernel.name(), "kernel_name_middle");
|
||||
EXPECT_EQ(kernel.registers_per_thread(), 32);
|
||||
EXPECT_EQ(kernel.static_shmem_bytes(), 0);
|
||||
EXPECT_EQ(kernel.dynamic_shmem_bytes(), 16384);
|
||||
EXPECT_EQ(kernel.grid_dim().at(0), 2);
|
||||
EXPECT_EQ(kernel.grid_dim().at(1), 1);
|
||||
EXPECT_EQ(kernel.grid_dim().at(2), 1);
|
||||
EXPECT_EQ(kernel.block_dim().at(0), 32);
|
||||
EXPECT_EQ(kernel.block_dim().at(1), 1);
|
||||
EXPECT_EQ(kernel.block_dim().at(2), 1);
|
||||
EXPECT_EQ(kernel.total_duration_ns(), 2);
|
||||
EXPECT_FALSE(kernel.is_kernel_using_tensor_core());
|
||||
EXPECT_TRUE(kernel.is_op_tensor_core_eligible());
|
||||
EXPECT_EQ(kernel.op_name(), "Conv2D");
|
||||
}
|
||||
|
||||
const auto& kernel2 = kernel_stats.reports().at(2);
|
||||
EXPECT_EQ(kernel2.name(), "volta_fp16_s884gemm_fp16_128x128_ldg8_f2f_tn");
|
||||
EXPECT_EQ(kernel2.registers_per_thread(), 32);
|
||||
EXPECT_EQ(kernel2.static_shmem_bytes(), 0);
|
||||
EXPECT_EQ(kernel2.dynamic_shmem_bytes(), 16384);
|
||||
EXPECT_EQ(kernel2.grid_dim().at(0), 3);
|
||||
EXPECT_EQ(kernel2.grid_dim().at(1), 1);
|
||||
EXPECT_EQ(kernel2.grid_dim().at(2), 1);
|
||||
EXPECT_EQ(kernel2.block_dim().at(0), 64);
|
||||
EXPECT_EQ(kernel2.block_dim().at(1), 1);
|
||||
EXPECT_EQ(kernel2.block_dim().at(2), 1);
|
||||
EXPECT_EQ(kernel2.total_duration_ns(), 3);
|
||||
EXPECT_TRUE(kernel2.is_kernel_using_tensor_core());
|
||||
EXPECT_TRUE(kernel2.is_op_tensor_core_eligible());
|
||||
EXPECT_EQ(kernel2.op_name(), "Einsum_80");
|
||||
{
|
||||
const auto& kernel = kernel_stats.reports().at(0);
|
||||
EXPECT_EQ(kernel.name(), "volta_fp16_s884gemm_fp16_128x128_ldg8_f2f_tn");
|
||||
EXPECT_EQ(kernel.registers_per_thread(), 32);
|
||||
EXPECT_EQ(kernel.static_shmem_bytes(), 0);
|
||||
EXPECT_EQ(kernel.dynamic_shmem_bytes(), 16384);
|
||||
EXPECT_EQ(kernel.grid_dim().at(0), 3);
|
||||
EXPECT_EQ(kernel.grid_dim().at(1), 1);
|
||||
EXPECT_EQ(kernel.grid_dim().at(2), 1);
|
||||
EXPECT_EQ(kernel.block_dim().at(0), 64);
|
||||
EXPECT_EQ(kernel.block_dim().at(1), 1);
|
||||
EXPECT_EQ(kernel.block_dim().at(2), 1);
|
||||
EXPECT_EQ(kernel.total_duration_ns(), 3);
|
||||
EXPECT_TRUE(kernel.is_kernel_using_tensor_core());
|
||||
EXPECT_TRUE(kernel.is_op_tensor_core_eligible());
|
||||
EXPECT_EQ(kernel.op_name(), "Einsum_80");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -19,6 +19,7 @@ limitations under the License.
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "tensorflow/core/platform/env.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/profiler/convert/op_metrics_db_combiner.h"
|
||||
#include "tensorflow/core/profiler/convert/step_events_to_steps_db.h"
|
||||
@ -154,7 +155,8 @@ OpStats ConvertXSpaceToOpStats(const XSpace& space,
|
||||
op_stats.mutable_device_op_metrics_db());
|
||||
SetRunEnvironment(device_planes.size(), op_stats.mutable_run_environment());
|
||||
|
||||
std::vector<KernelReport> reports;
|
||||
KernelReportMap reports;
|
||||
// TODO(b/161942993) parallelize XPlane processing per thread.
|
||||
for (const XPlane* device_trace : device_planes) {
|
||||
if (config.contains(OP_METRICS_DB)) {
|
||||
if (!op_stats.has_perf_env()) {
|
||||
@ -171,16 +173,18 @@ OpStats ConvertXSpaceToOpStats(const XSpace& space,
|
||||
&step_events);
|
||||
}
|
||||
if (config.contains(KERNEL_STATS_DB)) {
|
||||
KernelStatsDb kernel_stats_db = ConvertDeviceTraceXPlaneToKernelStatsDb(
|
||||
*device_trace, /*on_kernel_fn=*/{});
|
||||
reports.insert(reports.begin(), kernel_stats_db.reports().begin(),
|
||||
kernel_stats_db.reports().end());
|
||||
ConvertDeviceTraceXPlaneToKernelReports(*device_trace,
|
||||
/*on_kernel_fn=*/{}, &reports);
|
||||
}
|
||||
}
|
||||
|
||||
// Combine into reports.
|
||||
if (config.contains(KERNEL_STATS_DB)) {
|
||||
GroupKernelReports(&reports, op_stats.mutable_kernel_stats_db());
|
||||
CopyKernelReportsToDb(reports, op_stats.mutable_kernel_stats_db());
|
||||
// TODO(b/161943499) Replace sort with a TopK algorithm.
|
||||
SortKernelsByTotalDurationDesc(op_stats.mutable_kernel_stats_db());
|
||||
}
|
||||
|
||||
bool has_device = !device_planes.empty();
|
||||
// Convert a host plane.
|
||||
if (host_plane && config.contains(OP_METRICS_DB)) {
|
||||
|
@ -401,6 +401,7 @@ cc_library(
|
||||
deps = [
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/profiler/protobuf:kernel_stats_proto_cc",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/strings/match.h"
|
||||
#include "absl/strings/numbers.h"
|
||||
#include "absl/strings/str_split.h"
|
||||
@ -142,7 +143,7 @@ bool IsEinsumTensorCoreEligible(absl::string_view equation) {
|
||||
}
|
||||
|
||||
bool KernelReportLessThanComparator::operator()(const KernelReport& lhs,
|
||||
const KernelReport& rhs) {
|
||||
const KernelReport& rhs) const {
|
||||
// Disable formatting to keep vertical alignment for better readability,
|
||||
// and make it easier to reorder columns.
|
||||
// clang-format off
|
||||
@ -180,7 +181,7 @@ bool KernelReportLessThanComparator::operator()(const KernelReport& lhs,
|
||||
}
|
||||
|
||||
bool KernelReportEqualToComparator::operator()(const KernelReport& lhs,
|
||||
const KernelReport& rhs) {
|
||||
const KernelReport& rhs) const {
|
||||
// Disable formatting to keep vertical alignment for better readability,
|
||||
// and make it easier to reorder columns.
|
||||
// clang-format off
|
||||
@ -213,32 +214,37 @@ void SortKernelsByTotalDurationDesc(KernelStatsDb* kernel_stats_db) {
|
||||
});
|
||||
}
|
||||
|
||||
void GroupKernelReports(std::vector<KernelReport>* reports,
|
||||
KernelStatsDb* dst) {
|
||||
// Sort reports by grouping criteria.
|
||||
std::sort(reports->begin(), reports->end(), KernelReportLessThanComparator());
|
||||
void CopyKernelReportsToDb(const KernelReportMap& reports, KernelStatsDb* dst) {
|
||||
for (const auto& report_value : reports) {
|
||||
KernelReport* report = dst->add_reports();
|
||||
*report = report_value.first;
|
||||
// Set value using KernelReportValue.
|
||||
report->set_occurrences(report_value.second.occurrences);
|
||||
report->set_min_duration_ns(report_value.second.min_duration_ns);
|
||||
report->set_max_duration_ns(report_value.second.max_duration_ns);
|
||||
report->set_total_duration_ns(report_value.second.total_duration_ns);
|
||||
}
|
||||
}
|
||||
|
||||
// Group reports together.
|
||||
KernelReport* prev = nullptr;
|
||||
for (const KernelReport& report : *reports) {
|
||||
DCHECK_EQ(3, report.grid_dim_size());
|
||||
DCHECK_EQ(3, report.block_dim_size());
|
||||
if (prev != nullptr && KernelReportEqualToComparator()(*prev, report)) {
|
||||
// Previous element is identical to the one that we are adding, so
|
||||
// aggregate them.
|
||||
prev->set_occurrences(prev->occurrences() + 1);
|
||||
prev->set_max_duration_ns(
|
||||
std::max(prev->max_duration_ns(), report.max_duration_ns()));
|
||||
prev->set_min_duration_ns(
|
||||
std::min(prev->min_duration_ns(), report.min_duration_ns()));
|
||||
prev->set_total_duration_ns(prev->total_duration_ns() +
|
||||
report.total_duration_ns());
|
||||
} else {
|
||||
// Current element does not exist yet.
|
||||
prev = dst->add_reports();
|
||||
*prev = report;
|
||||
prev->set_occurrences(1);
|
||||
}
|
||||
void InsertOrUpdateKernelReport(const KernelReport& kernel,
|
||||
const KernelReportValue& value,
|
||||
KernelReportMap* dst) {
|
||||
KernelReportValue& element = (*dst)[kernel];
|
||||
if (element.occurrences == 0) {
|
||||
element = value;
|
||||
} else {
|
||||
element.total_duration_ns += value.total_duration_ns;
|
||||
element.min_duration_ns =
|
||||
std::min(element.min_duration_ns, value.min_duration_ns);
|
||||
element.max_duration_ns =
|
||||
std::max(element.max_duration_ns, value.max_duration_ns);
|
||||
element.occurrences += 1;
|
||||
}
|
||||
}
|
||||
|
||||
void MergeKernelReports(const KernelReportMap& reports, KernelReportMap* dst) {
|
||||
for (auto& kernel_value : reports) {
|
||||
InsertOrUpdateKernelReport(kernel_value.first, kernel_value.second, dst);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/profiler/protobuf/kernel_stats.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -40,19 +41,70 @@ bool IsEinsumTensorCoreEligible(absl::string_view equation);
|
||||
|
||||
// Less than comparator for Kernel Reports.
|
||||
struct KernelReportLessThanComparator {
|
||||
bool operator()(const KernelReport& lhs, const KernelReport& rhs);
|
||||
bool operator()(const KernelReport& lhs, const KernelReport& rhs) const;
|
||||
};
|
||||
|
||||
// Equal to comparator for Kernel Reports.
|
||||
struct KernelReportEqualToComparator {
|
||||
bool operator()(const KernelReport& lhs, const KernelReport& rhs);
|
||||
bool operator()(const KernelReport& lhs, const KernelReport& rhs) const;
|
||||
};
|
||||
|
||||
// Sorts kernel reports by total duration in descending order.
|
||||
void SortKernelsByTotalDurationDesc(KernelStatsDb* kernel_stats_db);
|
||||
|
||||
// Groups and aggregate common reports into destination KernelStatsDb.
|
||||
void GroupKernelReports(std::vector<KernelReport>* reports, KernelStatsDb* dst);
|
||||
// Aggregated statistics stored as the mapped value of a KernelReportMap
// entry: total/min/max duration (nanoseconds, per the `_ns` suffix) and the
// number of occurrences. All fields default to 0, and
// InsertOrUpdateKernelReport treats `occurrences == 0` as "not yet populated".
struct KernelReportValue {
|
||||
uint64 total_duration_ns = 0;
|
||||
uint64 min_duration_ns = 0;
|
||||
uint64 max_duration_ns = 0;
|
||||
uint64 occurrences = 0;
|
||||
};
|
||||
|
||||
struct KernelKeyWrap {
|
||||
const KernelReport* key;
|
||||
template <typename H>
|
||||
friend H AbslHashValue(H h, KernelKeyWrap wrap) {
|
||||
// Kernel reports are grouped by these fields, hence they are used as
|
||||
// hashing criteria.
|
||||
// clang-format off
|
||||
return H::combine(
|
||||
std::move(h),
|
||||
wrap.key->is_kernel_using_tensor_core(),
|
||||
wrap.key->is_op_tensor_core_eligible(),
|
||||
wrap.key->block_dim(0),
|
||||
wrap.key->block_dim(1),
|
||||
wrap.key->block_dim(2),
|
||||
wrap.key->grid_dim(0),
|
||||
wrap.key->grid_dim(1),
|
||||
wrap.key->grid_dim(2),
|
||||
wrap.key->registers_per_thread(),
|
||||
wrap.key->static_shmem_bytes(),
|
||||
wrap.key->dynamic_shmem_bytes(),
|
||||
wrap.key->name(),
|
||||
wrap.key->op_name());
|
||||
// clang-format on
|
||||
}
|
||||
};
|
||||
|
||||
struct KernelHash {
|
||||
size_t operator()(const KernelReport& key) const {
|
||||
return absl::Hash<KernelKeyWrap>()(KernelKeyWrap{&key});
|
||||
}
|
||||
};
|
||||
|
||||
using KernelReportMap =
|
||||
absl::flat_hash_map<KernelReport, KernelReportValue, KernelHash,
|
||||
KernelReportEqualToComparator>;
|
||||
|
||||
// Copies reports into the given KernelStatsDb.
|
||||
void CopyKernelReportsToDb(const KernelReportMap& reports, KernelStatsDb* dst);
|
||||
|
||||
// Inserts or aggregates KernelReports into the given KernelReportMap.
|
||||
void InsertOrUpdateKernelReport(const KernelReport& kernel,
|
||||
const KernelReportValue& value,
|
||||
KernelReportMap* dst);
|
||||
|
||||
// Aggregates values from one KernelReportMap into another.
|
||||
void MergeKernelReports(const KernelReportMap& reports, KernelReportMap* dst);
|
||||
|
||||
// Groups KernelReport in <kernel_stats_db> by tensorflow operation name.
|
||||
absl::flat_hash_map<absl::string_view, std::vector<const KernelReport*>>
|
||||
|
Loading…
x
Reference in New Issue
Block a user