add GetContextId ,GetStreamIdEx to CuptiInterface to avoid direct cupti function call.
PiperOrigin-RevId: 272511603
This commit is contained in:
parent
9444e3a464
commit
71242dbfb6
@ -173,6 +173,12 @@ class CuptiInterface {
|
|||||||
|
|
||||||
virtual CUptiResult GetResultString(CUptiResult result, const char** str) = 0;
|
virtual CUptiResult GetResultString(CUptiResult result, const char** str) = 0;
|
||||||
|
|
||||||
|
virtual CUptiResult GetContextId(CUcontext context, uint32_t* context_id) = 0;
|
||||||
|
|
||||||
|
virtual CUptiResult GetStreamIdEx(CUcontext context, CUstream stream,
|
||||||
|
uint8_t per_thread_stream,
|
||||||
|
uint32_t* stream_id) = 0;
|
||||||
|
|
||||||
// Interface maintenance functions. Not directly related to CUPTI, but
|
// Interface maintenance functions. Not directly related to CUPTI, but
|
||||||
// required for implementing an error resilient layer over CUPTI API.
|
// required for implementing an error resilient layer over CUPTI API.
|
||||||
|
|
||||||
|
@ -811,7 +811,8 @@ class CudaEventRecorder {
|
|||||||
|
|
||||||
if (it == context_infos_.end()) {
|
if (it == context_infos_.end()) {
|
||||||
uint32 context_id = 0;
|
uint32 context_id = 0;
|
||||||
RETURN_IF_CUPTI_ERROR(cuptiGetContextId(context, &context_id));
|
RETURN_IF_CUPTI_ERROR(
|
||||||
|
cupti_interface_->GetContextId(context, &context_id));
|
||||||
ContextInfo ctx_info = {context_id};
|
ContextInfo ctx_info = {context_id};
|
||||||
it = context_infos_.emplace(context, ctx_info).first;
|
it = context_infos_.emplace(context, ctx_info).first;
|
||||||
}
|
}
|
||||||
@ -838,9 +839,11 @@ class CudaEventRecorder {
|
|||||||
int index = stream ? ++ctx_info->num_streams : 0;
|
int index = stream ? ++ctx_info->num_streams : 0;
|
||||||
uint32 stream_id = 0;
|
uint32 stream_id = 0;
|
||||||
#if defined(CUDA_API_PER_THREAD_DEFAULT_STREAM)
|
#if defined(CUDA_API_PER_THREAD_DEFAULT_STREAM)
|
||||||
RETURN_IF_CUPTI_ERROR(cuptiGetStreamIdEx(context, stream, 1, &stream_id));
|
RETURN_IF_CUPTI_ERROR(
|
||||||
|
cupti_interface_->GetStreamIdEx(context, stream, 1, &stream_id));
|
||||||
#else
|
#else
|
||||||
RETURN_IF_CUPTI_ERROR(cuptiGetStreamIdEx(context, stream, 0, &stream_id));
|
RETURN_IF_CUPTI_ERROR(
|
||||||
|
cupti_interface_->GetStreamIdEx(context, stream, 0, &stream_id));
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
StreamInfo stream_info = {stream_id, static_cast<std::string>(name), index,
|
StreamInfo stream_info = {stream_id, static_cast<std::string>(name), index,
|
||||||
|
@ -233,5 +233,16 @@ CUptiResult CuptiWrapper::GetResultString(CUptiResult result,
|
|||||||
return cuptiGetResultString(result, str);
|
return cuptiGetResultString(result, str);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
CUptiResult CuptiWrapper::GetContextId(CUcontext context,
|
||||||
|
uint32_t* context_id) {
|
||||||
|
return cuptiGetContextId(context, context_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
CUptiResult CuptiWrapper::GetStreamIdEx(CUcontext context, CUstream stream,
|
||||||
|
uint8_t per_thread_stream,
|
||||||
|
uint32_t* stream_id) {
|
||||||
|
return cuptiGetStreamIdEx(context, stream, per_thread_stream, stream_id);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace profiler
|
} // namespace profiler
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -166,6 +166,12 @@ class CuptiWrapper : public tensorflow::profiler::CuptiInterface {
|
|||||||
|
|
||||||
CUptiResult GetResultString(CUptiResult result, const char** str) override;
|
CUptiResult GetResultString(CUptiResult result, const char** str) override;
|
||||||
|
|
||||||
|
CUptiResult GetContextId(CUcontext context, uint32_t* context_id) override;
|
||||||
|
|
||||||
|
CUptiResult GetStreamIdEx(CUcontext context, CUstream stream,
|
||||||
|
uint8_t per_thread_stream,
|
||||||
|
uint32_t* stream_id) override;
|
||||||
|
|
||||||
void CleanUp() override {}
|
void CleanUp() override {}
|
||||||
bool Disabled() const override { return false; }
|
bool Disabled() const override { return false; }
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user