add GetContextId ,GetStreamIdEx to CuptiInterface to avoid direct cupti function call.

PiperOrigin-RevId: 272511603
This commit is contained in:
A. Unique TensorFlower 2019-10-02 13:51:44 -07:00 committed by TensorFlower Gardener
parent 9444e3a464
commit 71242dbfb6
4 changed files with 29 additions and 3 deletions

View File

@ -173,6 +173,12 @@ class CuptiInterface {
virtual CUptiResult GetResultString(CUptiResult result, const char** str) = 0;
virtual CUptiResult GetContextId(CUcontext context, uint32_t* context_id) = 0;
virtual CUptiResult GetStreamIdEx(CUcontext context, CUstream stream,
uint8_t per_thread_stream,
uint32_t* stream_id) = 0;
// Interface maintenance functions. Not directly related to CUPTI, but
// required for implementing an error resilient layer over CUPTI API.

View File

@ -811,7 +811,8 @@ class CudaEventRecorder {
if (it == context_infos_.end()) {
uint32 context_id = 0;
RETURN_IF_CUPTI_ERROR(cuptiGetContextId(context, &context_id));
RETURN_IF_CUPTI_ERROR(
cupti_interface_->GetContextId(context, &context_id));
ContextInfo ctx_info = {context_id};
it = context_infos_.emplace(context, ctx_info).first;
}
@ -838,9 +839,11 @@ class CudaEventRecorder {
int index = stream ? ++ctx_info->num_streams : 0;
uint32 stream_id = 0;
#if defined(CUDA_API_PER_THREAD_DEFAULT_STREAM)
RETURN_IF_CUPTI_ERROR(cuptiGetStreamIdEx(context, stream, 1, &stream_id));
RETURN_IF_CUPTI_ERROR(
cupti_interface_->GetStreamIdEx(context, stream, 1, &stream_id));
#else
RETURN_IF_CUPTI_ERROR(cuptiGetStreamIdEx(context, stream, 0, &stream_id));
RETURN_IF_CUPTI_ERROR(
cupti_interface_->GetStreamIdEx(context, stream, 0, &stream_id));
#endif
StreamInfo stream_info = {stream_id, static_cast<std::string>(name), index,

View File

@ -233,5 +233,16 @@ CUptiResult CuptiWrapper::GetResultString(CUptiResult result,
return cuptiGetResultString(result, str);
}
CUptiResult CuptiWrapper::GetContextId(CUcontext context,
uint32_t* context_id) {
return cuptiGetContextId(context, context_id);
}
CUptiResult CuptiWrapper::GetStreamIdEx(CUcontext context, CUstream stream,
uint8_t per_thread_stream,
uint32_t* stream_id) {
return cuptiGetStreamIdEx(context, stream, per_thread_stream, stream_id);
}
} // namespace profiler
} // namespace tensorflow

View File

@ -166,6 +166,12 @@ class CuptiWrapper : public tensorflow::profiler::CuptiInterface {
CUptiResult GetResultString(CUptiResult result, const char** str) override;
CUptiResult GetContextId(CUcontext context, uint32_t* context_id) override;
CUptiResult GetStreamIdEx(CUcontext context, CUstream stream,
uint8_t per_thread_stream,
uint32_t* stream_id) override;
void CleanUp() override {}
bool Disabled() const override { return false; }