diff --git a/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc b/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc index 0996f8cee0b..44b2106c7a6 100644 --- a/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc +++ b/tensorflow/core/profiler/internal/gpu/cupti_tracer.cc @@ -676,9 +676,9 @@ class CudaEventRecorder { // Registers the start of a kernel launch. The returned index should be passed // to StopKernel() after the kernel launch has completed. + template <typename T> size_t StartKernel(const char *kernel_name, CUcontext context, - uint32 correlation_id, - const cuLaunchKernel_params *params) { + uint32 correlation_id, const T *params) { CUstream stream = params->hStream; KernelRecord record = {kernel_name, context, stream, correlation_id}; record.details.registers_per_thread = 0; // unknown. @@ -968,10 +968,20 @@ class CuptiDriverApiHookWithCudaEvent : public CuptiDriverApiHook { DCHECK_NE(cbdata->symbolName, nullptr); auto params = static_cast<const cuLaunchKernel_params *>(cbdata->functionParams); - *cbdata->correlationData = recorder->StartKernel( + *cbdata->correlationData = recorder->StartKernel<cuLaunchKernel_params>( cbdata->symbolName, cbdata->context, cbdata->correlationId, params); break; } + case CUPTI_DRIVER_TRACE_CBID_cuLaunchCooperativeKernel: { + DCHECK_NE(cbdata->symbolName, nullptr); + auto params = static_cast<const cuLaunchCooperativeKernel_params_st *>( + cbdata->functionParams); + *cbdata->correlationData = + recorder->StartKernel<cuLaunchCooperativeKernel_params_st>( + cbdata->symbolName, cbdata->context, cbdata->correlationId, + params); + break; + } case CUPTI_DRIVER_TRACE_CBID_cuMemcpy: { auto params = static_cast<const cuMemcpy_params *>(cbdata->functionParams); @@ -1010,6 +1020,10 @@ class CuptiDriverApiHookWithCudaEvent : public CuptiDriverApiHook { StartMemcpyAsync( CuptiTracerEventType::MemcpyD2D, cbdata, recorder); break; + case CUPTI_DRIVER_TRACE_CBID_cuLaunchCooperativeKernelMultiDevice: + // TODO: track these kind of events. + VLOG(1) << "untracked cuLaunchCooperativeKernelMultiDevice"; + break; default: VLOG(1) << "Unexpected callback id: " << cbid; break;