add GetContextId ,GetStreamIdEx to CuptiInterface to avoid direct cupti function call.
PiperOrigin-RevId: 272511603
This commit is contained in:
parent
9444e3a464
commit
71242dbfb6
@ -173,6 +173,12 @@ class CuptiInterface {
|
||||
|
||||
virtual CUptiResult GetResultString(CUptiResult result, const char** str) = 0;
|
||||
|
||||
virtual CUptiResult GetContextId(CUcontext context, uint32_t* context_id) = 0;
|
||||
|
||||
virtual CUptiResult GetStreamIdEx(CUcontext context, CUstream stream,
|
||||
uint8_t per_thread_stream,
|
||||
uint32_t* stream_id) = 0;
|
||||
|
||||
// Interface maintenance functions. Not directly related to CUPTI, but
|
||||
// required for implementing an error resilient layer over CUPTI API.
|
||||
|
||||
|
@ -811,7 +811,8 @@ class CudaEventRecorder {
|
||||
|
||||
if (it == context_infos_.end()) {
|
||||
uint32 context_id = 0;
|
||||
RETURN_IF_CUPTI_ERROR(cuptiGetContextId(context, &context_id));
|
||||
RETURN_IF_CUPTI_ERROR(
|
||||
cupti_interface_->GetContextId(context, &context_id));
|
||||
ContextInfo ctx_info = {context_id};
|
||||
it = context_infos_.emplace(context, ctx_info).first;
|
||||
}
|
||||
@ -838,9 +839,11 @@ class CudaEventRecorder {
|
||||
int index = stream ? ++ctx_info->num_streams : 0;
|
||||
uint32 stream_id = 0;
|
||||
#if defined(CUDA_API_PER_THREAD_DEFAULT_STREAM)
|
||||
RETURN_IF_CUPTI_ERROR(cuptiGetStreamIdEx(context, stream, 1, &stream_id));
|
||||
RETURN_IF_CUPTI_ERROR(
|
||||
cupti_interface_->GetStreamIdEx(context, stream, 1, &stream_id));
|
||||
#else
|
||||
RETURN_IF_CUPTI_ERROR(cuptiGetStreamIdEx(context, stream, 0, &stream_id));
|
||||
RETURN_IF_CUPTI_ERROR(
|
||||
cupti_interface_->GetStreamIdEx(context, stream, 0, &stream_id));
|
||||
#endif
|
||||
|
||||
StreamInfo stream_info = {stream_id, static_cast<std::string>(name), index,
|
||||
|
@ -233,5 +233,16 @@ CUptiResult CuptiWrapper::GetResultString(CUptiResult result,
|
||||
return cuptiGetResultString(result, str);
|
||||
}
|
||||
|
||||
CUptiResult CuptiWrapper::GetContextId(CUcontext context,
|
||||
uint32_t* context_id) {
|
||||
return cuptiGetContextId(context, context_id);
|
||||
}
|
||||
|
||||
CUptiResult CuptiWrapper::GetStreamIdEx(CUcontext context, CUstream stream,
|
||||
uint8_t per_thread_stream,
|
||||
uint32_t* stream_id) {
|
||||
return cuptiGetStreamIdEx(context, stream, per_thread_stream, stream_id);
|
||||
}
|
||||
|
||||
} // namespace profiler
|
||||
} // namespace tensorflow
|
||||
|
@ -166,6 +166,12 @@ class CuptiWrapper : public tensorflow::profiler::CuptiInterface {
|
||||
|
||||
CUptiResult GetResultString(CUptiResult result, const char** str) override;
|
||||
|
||||
CUptiResult GetContextId(CUcontext context, uint32_t* context_id) override;
|
||||
|
||||
CUptiResult GetStreamIdEx(CUcontext context, CUstream stream,
|
||||
uint8_t per_thread_stream,
|
||||
uint32_t* stream_id) override;
|
||||
|
||||
void CleanUp() override {}
|
||||
bool Disabled() const override { return false; }
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user