diff --git a/tensorflow/core/profiler/convert/op_stats_to_tf_stats_test.cc b/tensorflow/core/profiler/convert/op_stats_to_tf_stats_test.cc
index 0a1a1e19048..7f9111d663e 100644
--- a/tensorflow/core/profiler/convert/op_stats_to_tf_stats_test.cc
+++ b/tensorflow/core/profiler/convert/op_stats_to_tf_stats_test.cc
@@ -87,15 +87,12 @@ TEST(OpStatsToTfStats, GpuTfStats) {
   constexpr int64 kKernel5DurationNs = 10000;
 
   // Mock kernel details for both kernel4 and kernel5.
-  const std::string kKernelDetails = R"MULTI(registers_per_thread:32
-static_shared_memory_usage:0
-dynamic_shared_memory_usage:16384
-grid_x:2
-grid_y:1
-grid_z:1
-block_x:32
-block_y:1
-block_z:1)MULTI";
+  const std::string kKernelDetails = R"MULTI(regs:32
+static_shared:0
+dynamic_shared:16384
+grid:2,1,1
+block:32,1,1
+occ_pct:1.0)MULTI";
 
   XSpace space;
   XPlaneBuilder device_plane(
diff --git a/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db_test.cc b/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db_test.cc
index 8500c3bddd6..700f057e516 100644
--- a/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db_test.cc
+++ b/tensorflow/core/profiler/convert/xplane_to_kernel_stats_db_test.cc
@@ -41,44 +41,35 @@ TEST(ConvertXplaneToKernelStats, MultiKernels) {
   CreateXEvent(&device_trace_builder, &line_builder, "kernel_name_shortest",
                /*offset_ps=*/10000, /*duration_ps=*/1000,
                {{StatType::kTfOp, "mul_786"},
-                {StatType::kKernelDetails, R"MULTI(registers_per_thread:16
-static_shared_memory_usage:0
-dynamic_shared_memory_usage:0
-grid_x:1
-grid_y:1
-grid_z:1
-block_x:1
-block_y:1
-block_z:1)MULTI"},
+                {StatType::kKernelDetails, R"MULTI(regs:16
+static_shared:0
+dynamic_shared:0
+grid:1,1,1
+block:1,1,1
+occ_pct:0.5)MULTI"},
                 {StatType::kEquation, ""}});
 
   CreateXEvent(&device_trace_builder, &line_builder, "kernel_name_middle",
                /*offset_ps=*/20000, /*duration_ps=*/2000,
                {{StatType::kTfOp, "Conv2D"},
-                {StatType::kKernelDetails, R"MULTI(registers_per_thread:32
-static_shared_memory_usage:0
-dynamic_shared_memory_usage:16384
-grid_x:2
-grid_y:1
-grid_z:1
-block_x:32
-block_y:1
-block_z:1)MULTI"},
+                {StatType::kKernelDetails, R"MULTI(regs:32
+static_shared:0
+dynamic_shared:16384
+grid:2,1,1
+block:32,1,1
+occ_pct:0.13)MULTI"},
                 {StatType::kEquation, ""}});
 
   CreateXEvent(&device_trace_builder, &line_builder,
                "volta_fp16_s884gemm_fp16_128x128_ldg8_f2f_tn",
                /*offset_ps=*/30000, /*duration_ps=*/3000,
                {{StatType::kTfOp, "Einsum_80"},
-                {StatType::kKernelDetails, R"MULTI(registers_per_thread:32
-static_shared_memory_usage:0
-dynamic_shared_memory_usage:16384
-grid_x:3
-grid_y:1
-grid_z:1
-block_x:64
-block_y:1
-block_z:1)MULTI"},
+                {StatType::kKernelDetails, R"MULTI(regs:32
+static_shared:0
+dynamic_shared:16384
+grid:3,1,1
+block:64,1,1
+occ_pct:0.25)MULTI"},
                 {StatType::kEquation, ""}});
 
   KernelReportMap reports;
diff --git a/tensorflow/core/profiler/internal/gpu/BUILD b/tensorflow/core/profiler/internal/gpu/BUILD
index 17da1e3756c..40773c6cb98 100644
--- a/tensorflow/core/profiler/internal/gpu/BUILD
+++ b/tensorflow/core/profiler/internal/gpu/BUILD
@@ -158,6 +158,26 @@ tf_cuda_library(
     ],
 )
 
+cc_library(
+    name = "cupti_collector_header",
+    hdrs = ["cupti_collector.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/profiler/protobuf:xplane_proto_cc",
+        "//tensorflow/core/profiler/utils:parse_annotation",
+        "//tensorflow/core/profiler/utils:xplane_builder",
+        "//tensorflow/core/profiler/utils:xplane_schema",
"//tensorflow/core/profiler/utils:xplane_utils", + "@com_google_absl//absl/container:fixed_array", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/container:node_hash_set", + "@com_google_absl//absl/strings", + ], +) + tf_cuda_library( name = "cupti_utils", srcs = if_cuda_is_configured_compat(["cupti_utils.cc"]), diff --git a/tensorflow/core/profiler/internal/gpu/cupti_collector.cc b/tensorflow/core/profiler/internal/gpu/cupti_collector.cc index f3b132c4040..482c8077a2f 100644 --- a/tensorflow/core/profiler/internal/gpu/cupti_collector.cc +++ b/tensorflow/core/profiler/internal/gpu/cupti_collector.cc @@ -15,10 +15,13 @@ limitations under the License. #include "tensorflow/core/profiler/internal/gpu/cupti_collector.h" +#include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/hash/hash.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" #include "third_party/gpus/cuda/include/cuda.h" +#include "third_party/gpus/cuda/include/cuda_occupancy.h" #include "tensorflow/core/platform/abi.h" #include "tensorflow/core/platform/host_info.h" #include "tensorflow/core/platform/mutex.h" @@ -59,6 +62,34 @@ bool IsHostEvent(const CuptiTracerEvent& event, int64* line_id) { } } +struct DeviceOccupancyParams { + cudaOccFuncAttributes attributes = {}; + int block_size = 0; + size_t dynamic_smem_size = 0; + + friend bool operator==(const DeviceOccupancyParams& lhs, + const DeviceOccupancyParams& rhs) { + return 0 == memcmp(&lhs, &rhs, sizeof(lhs)); + } + + template + friend H AbslHashValue(H hash_state, const DeviceOccupancyParams& params) { + return H::combine( + std::move(hash_state), params.attributes.maxThreadsPerBlock, + params.attributes.numRegs, params.attributes.sharedSizeBytes, + static_cast(params.attributes.partitionedGCConfig), + static_cast(params.attributes.shmemLimitConfig), + params.attributes.maxDynamicSharedSizeBytes, params.block_size, + params.dynamic_smem_size); + } +}; + +struct OccupancyStats { + double occupancy_pct = 0.0; + int min_grid_size = 0; + int suggested_block_size = 0; +}; + struct CorrelationInfo { CorrelationInfo(uint32 t, uint32 e) : thread_id(t), enqueue_time_ns(e) {} uint32 thread_id; @@ -66,6 +97,35 @@ struct CorrelationInfo { }; struct PerDeviceCollector { + OccupancyStats GetOccupancy(const DeviceOccupancyParams& params) const { + OccupancyStats stats; + if (device_properties.computeMajor == 0) { + return {}; + } + + const cudaOccDeviceState state = {}; + cudaOccResult occ_result; + cudaOccError status = cudaOccMaxActiveBlocksPerMultiprocessor( + &occ_result, &device_properties, ¶ms.attributes, &state, + params.block_size, params.dynamic_smem_size); + if (status != CUDA_OCC_SUCCESS) { + return {}; + } + + stats.occupancy_pct = + occ_result.activeBlocksPerMultiprocessor * params.block_size; + stats.occupancy_pct /= device_properties.maxThreadsPerMultiprocessor; + + status = cudaOccMaxPotentialOccupancyBlockSize( + &stats.min_grid_size, &stats.suggested_block_size, &device_properties, + ¶ms.attributes, &state, NULL, params.dynamic_smem_size); + if (status != CUDA_OCC_SUCCESS) { + return {}; + } + + return stats; + } + void CreateXEvent(const CuptiTracerEvent& event, XPlaneBuilder* plane, uint64 start_gpu_ns, uint64 end_gpu_ns, XLineBuilder* line) { @@ -105,16 +165,41 @@ struct PerDeviceCollector { *plane->GetOrCreateStatMetadata(GetStatTypeStr(StatType::kContextId)), absl::StrCat("$$", static_cast(event.context_id))); } - if 
-      std::string kernel_details = absl::StrCat(
-          "regs:", event.kernel_info.registers_per_thread,
-          " shm:", event.kernel_info.static_shared_memory_usage,
-          " grid:", event.kernel_info.grid_x, ",", event.kernel_info.grid_y,
-          ",", event.kernel_info.grid_z, " block:", event.kernel_info.block_x,
-          ",", event.kernel_info.block_y, ",", event.kernel_info.block_z);
+
+    if (event.type == CuptiTracerEventType::Kernel &&
+        event.source == CuptiTracerEventSource::Activity) {
+      DeviceOccupancyParams params{};
+      params.attributes.maxThreadsPerBlock = INT_MAX;
+      params.attributes.numRegs =
+          static_cast<int>(event.kernel_info.registers_per_thread);
+      params.attributes.sharedSizeBytes =
+          event.kernel_info.static_shared_memory_usage;
+      params.attributes.partitionedGCConfig = PARTITIONED_GC_OFF;
+      params.attributes.shmemLimitConfig = FUNC_SHMEM_LIMIT_DEFAULT;
+      params.attributes.maxDynamicSharedSizeBytes = 0;
+      params.block_size = static_cast<int>(event.kernel_info.block_x *
                                           event.kernel_info.block_y *
                                           event.kernel_info.block_z);
+
+      params.dynamic_smem_size = event.kernel_info.dynamic_shared_memory_usage;
+
+      OccupancyStats& occ_stats = occupancy_cache[params];
+      if (occ_stats.occupancy_pct == 0) {
+        occ_stats = GetOccupancy(params);
+      }
+      xevent.AddStatValue(*plane->GetOrCreateStatMetadata(GetStatTypeStr(
+                              StatType::kTheoreticalOccupancyPct)),
+                          occ_stats.occupancy_pct);
+      xevent.AddStatValue(*plane->GetOrCreateStatMetadata(
+                              GetStatTypeStr(StatType::kOccupancyMinGridSize)),
+                          static_cast<int32>(occ_stats.min_grid_size));
+      xevent.AddStatValue(*plane->GetOrCreateStatMetadata(GetStatTypeStr(
+                              StatType::kOccupancySuggestedBlockSize)),
+                          static_cast<int32>(occ_stats.suggested_block_size));
       xevent.AddStatValue(*plane->GetOrCreateStatMetadata(
                               GetStatTypeStr(StatType::kKernelDetails)),
-                          *plane->GetOrCreateStatMetadata(kernel_details));
+                          *plane->GetOrCreateStatMetadata(ToXStat(
+                              event.kernel_info, occ_stats.occupancy_pct)));
     } else if (event.type == CuptiTracerEventType::MemcpyH2D ||
                event.type == CuptiTracerEventType::MemcpyD2H ||
                event.type == CuptiTracerEventType::MemcpyD2D ||
@@ -416,12 +501,49 @@ struct PerDeviceCollector {
                             GetStatTypeStr(StatType::kDevCapComputeCapMinor)),
                         *compute_capability_minor);
     }
+
+    auto max_threads_per_block =
+        GetDeviceAttribute(device, CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK);
+    auto max_threads_per_sm = GetDeviceAttribute(
+        device, CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_MULTIPROCESSOR);
+    auto regs_per_block =
+        GetDeviceAttribute(device, CU_DEVICE_ATTRIBUTE_REGISTERS_PER_BLOCK);
+    auto regs_per_sm = GetDeviceAttribute(
+        device, CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_MULTIPROCESSOR);
+    auto warp_size = GetDeviceAttribute(device, CU_DEVICE_ATTRIBUTE_WARP_SIZE);
+    auto shared_mem_per_block = GetDeviceAttribute(
+        device, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK);
+    auto shared_mem_per_sm = GetDeviceAttribute(
+        device, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR);
+    auto shared_mem_per_block_optin = GetDeviceAttribute(
+        device, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN);
+
+    // Precondition for calculating GPU occupancy is to have all of these
+    // inputs. Otherwise, GPU occupancy will be left unset as 0%.
+    if (core_count && compute_capability_major && compute_capability_minor &&
+        max_threads_per_block && max_threads_per_sm && regs_per_block &&
+        regs_per_sm && warp_size && shared_mem_per_block &&
+        shared_mem_per_sm && shared_mem_per_block_optin) {
+      device_properties.computeMajor = *compute_capability_major;
+      device_properties.computeMinor = *compute_capability_minor;
+      device_properties.numSms = *core_count;
+      device_properties.maxThreadsPerBlock = *max_threads_per_block;
+      device_properties.maxThreadsPerMultiprocessor = *max_threads_per_sm;
+      device_properties.regsPerBlock = *regs_per_block;
+      device_properties.regsPerMultiprocessor = *regs_per_sm;
+      device_properties.warpSize = *warp_size;
+      device_properties.sharedMemPerBlock = *shared_mem_per_block;
+      device_properties.sharedMemPerMultiprocessor = *shared_mem_per_sm;
+      device_properties.sharedMemPerBlockOptin = *shared_mem_per_block_optin;
+    }
   }
 
   mutex m;
   std::vector<CuptiTracerEvent> events TF_GUARDED_BY(m);
   absl::flat_hash_map<uint32, CorrelationInfo> correlation_info
       TF_GUARDED_BY(m);
+  cudaOccDeviceProp device_properties;
+  absl::flat_hash_map<DeviceOccupancyParams, OccupancyStats> occupancy_cache;
 };
 
 }  // namespace
@@ -509,10 +631,13 @@ class CuptiTraceCollectorImpl : public CuptiTraceCollector {
       std::string name = GpuPlaneName(device_ordinal);
       XPlaneBuilder device_plane(FindOrAddMutablePlaneWithName(space, name));
       device_plane.SetId(device_ordinal);
-      num_events += per_device_collector_[device_ordinal].Flush(
-          start_gpu_ns_, end_gpu_ns, &device_plane, &host_plane);
+
+      // Calculate device capabilities before flushing, so that device
+      // properties are available to the occupancy calculator in Flush().
       per_device_collector_[device_ordinal].GetDeviceCapabilities(
           device_ordinal, &device_plane);
+      num_events += per_device_collector_[device_ordinal].Flush(
+          start_gpu_ns_, end_gpu_ns, &device_plane, &host_plane);
       NormalizeTimeStamps(&device_plane, start_walltime_ns_);
     }
     NormalizeTimeStamps(&host_plane, start_walltime_ns_);
diff --git a/tensorflow/core/profiler/internal/gpu/cupti_collector.h b/tensorflow/core/profiler/internal/gpu/cupti_collector.h
index bbc169364ea..ada6cec1d2d 100644
--- a/tensorflow/core/profiler/internal/gpu/cupti_collector.h
+++ b/tensorflow/core/profiler/internal/gpu/cupti_collector.h
@@ -55,25 +55,37 @@ struct MemAllocDetails {
 
 struct KernelDetails {
   // The number of registers used in this kernel.
-  uint64 registers_per_thread;
+  uint32 registers_per_thread;
   // The amount of shared memory space used by a thread block.
-  uint64 static_shared_memory_usage;
+  uint32 static_shared_memory_usage;
   // The amount of dynamic memory space used by a thread block.
-  uint64 dynamic_shared_memory_usage;
+  uint32 dynamic_shared_memory_usage;
   // X-dimension of a thread block.
-  uint64 block_x;
+  uint32 block_x;
   // Y-dimension of a thread block.
-  uint64 block_y;
+  uint32 block_y;
   // Z-dimension of a thread block.
-  uint64 block_z;
+  uint32 block_z;
   // X-dimension of a grid.
-  uint64 grid_x;
+  uint32 grid_x;
   // Y-dimension of a grid.
-  uint64 grid_y;
+  uint32 grid_y;
   // Z-dimension of a grid.
-  uint64 grid_z;
+  uint32 grid_z;
 };
 
+inline std::string ToXStat(const KernelDetails& kernel_info,
+                           double occupancy_pct) {
+  return absl::StrCat(
+      "regs:", kernel_info.registers_per_thread,
+      " static_shared:", kernel_info.static_shared_memory_usage,
+      " dynamic_shared:", kernel_info.dynamic_shared_memory_usage,
+      " grid:", kernel_info.grid_x, ",", kernel_info.grid_y, ",",
+      kernel_info.grid_z, " block:", kernel_info.block_x, ",",
+      kernel_info.block_y, ",", kernel_info.block_z,
+      " occ_pct:", occupancy_pct);
+}
+
 enum class CuptiTracerEventType {
   Unsupported = 0,
   Kernel = 1,
@@ -91,8 +103,9 @@ enum class CuptiTracerEventType {
 const char* GetTraceEventTypeName(const CuptiTracerEventType& type);
 
 enum class CuptiTracerEventSource {
-  DriverCallback = 0,
-  Activity = 1,
+  Invalid = 0,
+  DriverCallback = 1,
+  Activity = 2,
   // Maybe consider adding runtime callback and metric api in the future.
 };
 
@@ -105,8 +118,8 @@ struct CuptiTracerEvent {
       std::numeric_limits<uint64>::max();
   static constexpr uint64 kInvalidStreamId =
      std::numeric_limits<uint64>::max();
-  CuptiTracerEventType type;
-  CuptiTracerEventSource source;
+  CuptiTracerEventType type = CuptiTracerEventType::Unsupported;
+  CuptiTracerEventSource source = CuptiTracerEventSource::Invalid;
   // Although CUpti_CallbackData::functionName is persistent, however
   // CUpti_ActivityKernel4::name is not persistent, therefore we need a copy of
   // it.
@@ -114,9 +127,9 @@ struct CuptiTracerEvent {
   // This points to strings in AnnotationMap, which should outlive the point
   // where serialization happens.
   absl::string_view annotation;
-  uint64 start_time_ns;
-  uint64 end_time_ns;
-  uint32 device_id;
+  uint64 start_time_ns = 0;
+  uint64 end_time_ns = 0;
+  uint32 device_id = 0;
   uint32 correlation_id = kInvalidCorrelationId;
   uint32 thread_id = kInvalidThreadId;
   int64 context_id = kInvalidContextId;
diff --git a/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc b/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc
index c1f75c90f72..a7580e252c7 100644
--- a/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc
+++ b/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc
@@ -291,7 +291,7 @@ void CUPTIAPI FreeCuptiActivityBuffer(CUcontext context, uint32_t stream_id,
 void AddKernelEventUponApiExit(CuptiTraceCollector *collector, uint32 device_id,
                                const CUpti_CallbackData *cbdata,
                                uint64 start_time, uint64 end_time) {
-  CuptiTracerEvent event;
+  CuptiTracerEvent event{};
   event.type = CuptiTracerEventType::Kernel;
   event.source = CuptiTracerEventSource::DriverCallback;
   event.name = cbdata->symbolName ? cbdata->symbolName : cbdata->functionName;
@@ -310,7 +310,7 @@ CuptiTracerEvent PopulateMemcpyCallbackEvent(
     CuptiTracerEventType type, const CUpti_CallbackData *cbdata,
     size_t num_bytes, uint32 src_device, uint32 dst_device, bool async,
     uint64 start_time, uint64 end_time) {
-  CuptiTracerEvent event;
+  CuptiTracerEvent event{};
   event.type = type;
   event.source = CuptiTracerEventSource::DriverCallback;
   event.start_time_ns = start_time;
@@ -373,7 +373,7 @@ void AddCudaMallocEventUponApiExit(CuptiTraceCollector *collector,
                                    uint64 start_time, uint64 end_time) {
   const cuMemAlloc_v2_params_st *params =
      reinterpret_cast<const cuMemAlloc_v2_params_st *>(cbdata->functionParams);
-  CuptiTracerEvent event;
+  CuptiTracerEvent event{};
   event.type = CuptiTracerEventType::MemoryAlloc;
   event.source = CuptiTracerEventSource::DriverCallback;
   event.name = cbdata->functionName;
@@ -392,7 +392,7 @@ void AddGenericEventUponApiExit(CuptiTraceCollector *collector,
                                 uint32 device_id, CUpti_CallbackId cbid,
                                 const CUpti_CallbackData *cbdata,
                                 uint64 start_time, uint64 end_time) {
-  CuptiTracerEvent event;
+  CuptiTracerEvent event{};
   event.type = CuptiTracerEventType::Generic;
   event.source = CuptiTracerEventSource::DriverCallback;
   event.name = cbdata->functionName;
@@ -407,7 +407,7 @@ void AddGenericEventUponApiExit(CuptiTraceCollector *collector,
 
 void AddKernelActivityEvent(CuptiTraceCollector *collector,
                             const CUpti_ActivityKernel4 *kernel) {
-  CuptiTracerEvent event;
+  CuptiTracerEvent event{};
   event.type = CuptiTracerEventType::Kernel;
   event.source = CuptiTracerEventSource::Activity;
   event.name = kernel->name;
@@ -433,7 +433,7 @@ void AddKernelActivityEvent(CuptiTraceCollector *collector,
 
 void AddMemcpyActivityEvent(CuptiTraceCollector *collector,
                             const CUpti_ActivityMemcpy *memcpy) {
-  CuptiTracerEvent event;
+  CuptiTracerEvent event{};
   switch (memcpy->copyKind) {
     case CUPTI_ACTIVITY_MEMCPY_KIND_HTOD:
       event.type = CuptiTracerEventType::MemcpyH2D;
@@ -477,7 +477,7 @@ void AddMemcpyActivityEvent(CuptiTraceCollector *collector,
 // Invokes callback upon peer-2-peer memcpy between different GPU devices.
 void AddMemcpy2ActivityEvent(CuptiTraceCollector *collector,
                              const CUpti_ActivityMemcpy2 *memcpy2) {
-  CuptiTracerEvent event;
+  CuptiTracerEvent event{};
   event.type = CuptiTracerEventType::MemcpyP2P;
   event.name = "MemcpyP2P";
   event.source = CuptiTracerEventSource::Activity;
@@ -500,7 +500,7 @@ void AddMemcpy2ActivityEvent(CuptiTraceCollector *collector,
 
 void AddCuptiOverheadActivityEvent(CuptiTraceCollector *collector,
                                    const CUpti_ActivityOverhead *overhead) {
-  CuptiTracerEvent event;
+  CuptiTracerEvent event{};
   event.type = CuptiTracerEventType::Overhead;
   event.name = getActivityOverheadKindString(overhead->overheadKind);
   event.source = CuptiTracerEventSource::Activity;
@@ -538,7 +538,7 @@ void AddUnifiedMemoryActivityEvent(
     const CUpti_ActivityUnifiedMemoryCounter2 *record) {
   VLOG(3) << "Cuda Unified Memory Activity, kind: " << record->counterKind
           << " src: " << record->srcId << " dst: " << record->dstId;
-  CuptiTracerEvent event;
+  CuptiTracerEvent event{};
   event.type = CuptiTracerEventType::UnifiedMemory;
   event.name = getActivityUnifiedMemoryKindString(record->counterKind);
   event.source = CuptiTracerEventSource::Activity;
@@ -935,7 +935,7 @@ class CudaEventRecorder {
 
     std::string annotation;
 
-    CuptiTracerEvent event;
+    CuptiTracerEvent event{};
     event.type = CuptiTracerEventType::Kernel;
     event.source = CuptiTracerEventSource::Activity;  // on gpu device.
     event.name = record.kernel_name;
@@ -963,7 +963,7 @@ class CudaEventRecorder {
         GetElapsedTimeUs(record.start_event, stream_info.ctx_info->end_event);
     auto elapsed_us = GetElapsedTimeUs(record.start_event, record.stop_event);
 
-    CuptiTracerEvent event;
+    CuptiTracerEvent event{};
     event.type = record.type;
     event.name = GetTraceEventTypeName(event.type);
     event.source = CuptiTracerEventSource::Activity;
diff --git a/tensorflow/core/profiler/internal/gpu/device_tracer_test.cc b/tensorflow/core/profiler/internal/gpu/device_tracer_test.cc
index 55ccdbed977..a8110d52b25 100644
--- a/tensorflow/core/profiler/internal/gpu/device_tracer_test.cc
+++ b/tensorflow/core/profiler/internal/gpu/device_tracer_test.cc
@@ -288,7 +288,7 @@ TEST_F(DeviceTracerTest, TraceToXSpace) {
       ++total_events;
     });
   });
-  EXPECT_EQ(total_events, 5);
+  EXPECT_GE(total_events, 5);
 }
 
 }  // namespace
diff --git a/tensorflow/core/profiler/protobuf/kernel_stats.proto b/tensorflow/core/profiler/protobuf/kernel_stats.proto
index 144ec9acb8a..2c1f1b9ee91 100644
--- a/tensorflow/core/profiler/protobuf/kernel_stats.proto
+++ b/tensorflow/core/profiler/protobuf/kernel_stats.proto
@@ -2,6 +2,7 @@ syntax = "proto3";
 
 package tensorflow.profiler;
 
+// Next ID: 15
 message KernelReport {
   // Name of the kernel.
   string name = 1;
@@ -29,6 +30,8 @@ message KernelReport {
   string op_name = 12;
   // Number of occurrences.
   uint32 occurrences = 13;
+  // Occupancy percentage.
+  float occupancy_pct = 14;
 }
 
 message KernelStatsDb {
diff --git a/tensorflow/core/profiler/utils/BUILD b/tensorflow/core/profiler/utils/BUILD
index 0694571bef2..38f209a9c7f 100644
--- a/tensorflow/core/profiler/utils/BUILD
+++ b/tensorflow/core/profiler/utils/BUILD
@@ -466,6 +466,7 @@ tf_cc_test(
         ":kernel_stats_utils",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
+        "//tensorflow/core/profiler/internal/gpu:cupti_collector_header",
        "//tensorflow/core/profiler/protobuf:kernel_stats_proto_cc",
     ],
 )
diff --git a/tensorflow/core/profiler/utils/kernel_stats_utils.cc b/tensorflow/core/profiler/utils/kernel_stats_utils.cc
index 982434369ea..601b0da0239 100644
--- a/tensorflow/core/profiler/utils/kernel_stats_utils.cc
+++ b/tensorflow/core/profiler/utils/kernel_stats_utils.cc
@@ -42,7 +42,7 @@ const int kMaxNumOfKernels = 1000;
 void ParseKernelLaunchParams(absl::string_view xstat_kernel_details,
                              KernelReport* kernel) {
   const std::vector<absl::string_view> params =
-      absl::StrSplit(xstat_kernel_details, absl::ByAnyChar(":\n"));
+      absl::StrSplit(xstat_kernel_details, absl::ByAnyChar(" :\n"));
 
   constexpr uint32 kNumDimensions = 3;
   for (uint32 dim = 0; dim < kNumDimensions; ++dim) {
@@ -53,33 +53,36 @@ void ParseKernelLaunchParams(absl::string_view xstat_kernel_details,
   // Process value pairs.
   for (uint32 ii = 0; ii < params.size(); ii += 2) {
     uint32 value = 0;
-    if (params[ii] == "registers_per_thread" &&
-        absl::SimpleAtoi(params[ii + 1], &value)) {
+    double pct = 0.0;
+    if (params[ii] == "regs" && absl::SimpleAtoi(params[ii + 1], &value)) {
       kernel->set_registers_per_thread(value);
-    } else if (params[ii] == "static_shared_memory_usage" &&
+    } else if (params[ii] == "static_shared" &&
               absl::SimpleAtoi(params[ii + 1], &value)) {
       kernel->set_static_shmem_bytes(value);
-    } else if (params[ii] == "dynamic_shared_memory_usage" &&
+    } else if (params[ii] == "dynamic_shared" &&
               absl::SimpleAtoi(params[ii + 1], &value)) {
       kernel->set_dynamic_shmem_bytes(value);
-    } else if (params[ii] == "block_x" &&
-               absl::SimpleAtoi(params[ii + 1], &value)) {
-      kernel->mutable_block_dim()->Set(0, value);
-    } else if (params[ii] == "block_y" &&
-               absl::SimpleAtoi(params[ii + 1], &value)) {
-      kernel->mutable_block_dim()->Set(1, value);
-    } else if (params[ii] == "block_z" &&
-               absl::SimpleAtoi(params[ii + 1], &value)) {
-      kernel->mutable_block_dim()->Set(2, value);
-    } else if (params[ii] == "grid_x" &&
-               absl::SimpleAtoi(params[ii + 1], &value)) {
-      kernel->mutable_grid_dim()->Set(0, value);
-    } else if (params[ii] == "grid_y" &&
-               absl::SimpleAtoi(params[ii + 1], &value)) {
-      kernel->mutable_grid_dim()->Set(1, value);
-    } else if (params[ii] == "grid_z" &&
-               absl::SimpleAtoi(params[ii + 1], &value)) {
-      kernel->mutable_grid_dim()->Set(2, value);
+    } else if (params[ii] == "block") {
+      const std::vector<absl::string_view>& block =
+          absl::StrSplit(params[ii + 1], ',');
+      uint32 tmp[3];
+      if (block.size() == 3 && absl::SimpleAtoi(block[0], &tmp[0]) &&
+          absl::SimpleAtoi(block[1], &tmp[1]) &&
+          absl::SimpleAtoi(block[2], &tmp[2])) {
+        std::copy_n(tmp, 3, kernel->mutable_block_dim()->begin());
+      }
+    } else if (params[ii] == "grid") {
+      const std::vector<absl::string_view>& grid =
+          absl::StrSplit(params[ii + 1], ',');
+      uint32 tmp[3];
+      if (grid.size() == 3 && absl::SimpleAtoi(grid[0], &tmp[0]) &&
+          absl::SimpleAtoi(grid[1], &tmp[1]) &&
+          absl::SimpleAtoi(grid[2], &tmp[2])) {
+        std::copy_n(tmp, 3, kernel->mutable_grid_dim()->begin());
+      }
+    } else if (params[ii] == "occ_pct" &&
+               absl::SimpleAtod(params[ii + 1], &pct)) {
+      kernel->set_occupancy_pct(pct);
     }
   }
 }
diff --git a/tensorflow/core/profiler/utils/kernel_stats_utils.h b/tensorflow/core/profiler/utils/kernel_stats_utils.h
index 1b965376297..ee6f56d8454 100644
--- a/tensorflow/core/profiler/utils/kernel_stats_utils.h
+++ b/tensorflow/core/profiler/utils/kernel_stats_utils.h
@@ -26,7 +26,7 @@ limitations under the License.
 namespace tensorflow {
 namespace profiler {
 
-// Populates kernel launch information from a KernelDetails XStat.
+// Populates kernel launch information from a kKernelDetails XStat.
 void ParseKernelLaunchParams(absl::string_view xstat_kernel_details,
                              KernelReport* kernel);
 
diff --git a/tensorflow/core/profiler/utils/kernel_stats_utils_test.cc b/tensorflow/core/profiler/utils/kernel_stats_utils_test.cc
index 4f3d5a1f641..a6ba09017f1 100644
--- a/tensorflow/core/profiler/utils/kernel_stats_utils_test.cc
+++ b/tensorflow/core/profiler/utils/kernel_stats_utils_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/profiler/utils/kernel_stats_utils.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/profiler/internal/gpu/cupti_collector.h" #include "tensorflow/core/profiler/protobuf/kernel_stats.pb.h" namespace tensorflow { @@ -66,6 +67,33 @@ TEST(KernelStatsUtilsTest, TestGroupKernelReportsByOpName) { EXPECT_EQ(op2_stats.tensor_core_duration_ns, 0); } +TEST(KernelStatsUtilsTest, KernelDetailsXStatParser) { + KernelDetails kernel_info; + kernel_info.registers_per_thread = 10; + kernel_info.static_shared_memory_usage = 128; + kernel_info.dynamic_shared_memory_usage = 256; + kernel_info.block_x = 32; + kernel_info.block_y = 8; + kernel_info.block_z = 4; + kernel_info.grid_x = 3; + kernel_info.grid_y = 2; + kernel_info.grid_z = 1; + const double occupancy_pct = 50.0; + std::string xstat_kernel_details = ToXStat(kernel_info, occupancy_pct); + KernelReport kernel; + ParseKernelLaunchParams(xstat_kernel_details, &kernel); + // Verifies that the parser can parse kKernelDetails XStat. + EXPECT_EQ(kernel.registers_per_thread(), 10); + EXPECT_EQ(kernel.static_shmem_bytes(), 128); + EXPECT_EQ(kernel.dynamic_shmem_bytes(), 256); + EXPECT_EQ(kernel.block_dim()[0], 32); + EXPECT_EQ(kernel.block_dim()[1], 8); + EXPECT_EQ(kernel.block_dim()[2], 4); + EXPECT_EQ(kernel.grid_dim()[0], 3); + EXPECT_EQ(kernel.grid_dim()[1], 2); + EXPECT_EQ(kernel.grid_dim()[2], 1); +} + } // namespace } // namespace profiler } // namespace tensorflow diff --git a/tensorflow/core/profiler/utils/xplane_schema.cc b/tensorflow/core/profiler/utils/xplane_schema.cc index 858dd7a99ba..c433271749c 100644 --- a/tensorflow/core/profiler/utils/xplane_schema.cc +++ b/tensorflow/core/profiler/utils/xplane_schema.cc @@ -210,6 +210,10 @@ const StatTypeMap& GetStatTypeMap() { {"batch_size_after_padding", kBatchSizeAfterPadding}, {"padding_amount", kPaddingAmount}, {"batching_input_task_size", kBatchingInputTaskSize}, + // GPU related metrics. + {"theoretical_occupancy_pct", kTheoreticalOccupancyPct}, + {"occupancy_min_grid_size", kOccupancyMinGridSize}, + {"occupancy_suggested_block_size", kOccupancySuggestedBlockSize}, }); DCHECK_EQ(stat_type_map->size(), kNumStatTypes); return *stat_type_map; diff --git a/tensorflow/core/profiler/utils/xplane_schema.h b/tensorflow/core/profiler/utils/xplane_schema.h index dd8b4fe5140..4b9edf39189 100644 --- a/tensorflow/core/profiler/utils/xplane_schema.h +++ b/tensorflow/core/profiler/utils/xplane_schema.h @@ -199,7 +199,11 @@ enum StatType { kBatchSizeAfterPadding, kPaddingAmount, kBatchingInputTaskSize, - kLastStatType = kBatchingInputTaskSize, + // GPU occupancy metrics + kTheoreticalOccupancyPct, + kOccupancyMinGridSize, + kOccupancySuggestedBlockSize, + kLastStatType = kOccupancySuggestedBlockSize, }; inline std::string GpuPlaneName(int32 device_ordinal) {