Change Device::Sync to explicitly sync with all GPU streams

Avoid a CUDA context sync: with virtual GPUs, a context synchronize blocks on work from all virtual GPUs sharing the same physical GPU, which we don't want.

PiperOrigin-RevId: 339053908
Change-Id: Ieae0cb48aa3c4dae3efc57775c30941b6bdcb3db
This commit is contained in:
parent b8f5303051
commit 74e6ab7dc1
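
In plain CUDA runtime terms (a standalone sketch, not TensorFlow code; names are illustrative), the distinction the commit message draws is between a context-wide synchronize and waiting only on the streams this device owns:

#include <cuda_runtime.h>
#include <vector>

// Waits only for work enqueued on the streams we own; streams belonging to
// other virtual GPUs on the same physical GPU are unaffected.
void SyncOwnStreams(const std::vector<cudaStream_t>& own_streams) {
  for (cudaStream_t s : own_streams) {
    cudaStreamSynchronize(s);
  }
}

// Waits for every stream in the current device/context, including streams of
// other virtual GPUs sharing the physical GPU, which the commit message says
// we don't want.
void SyncWholeContext() {
  cudaDeviceSynchronize();
}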
@@ -601,11 +601,7 @@ void BaseGPUDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
// Based on the semantics of Device::Sync this call should wait for
// all streams not just the current one.
Status BaseGPUDevice::Sync() {
  return tensorflow_gpu_device_info()
      ->stream->parent()
      ->BlockHostUntilAllStreamsAreDone();
}
Status BaseGPUDevice::Sync() { return GPUUtil::SyncAll(this); }

void BaseGPUDevice::ComputeAsync(AsyncOpKernel* op_kernel,
                                 OpKernelContext* context,
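
The comment above states that, per the semantics of Device::Sync, the call should wait for all streams, not just the current one. A minimal standalone CUDA sketch (illustrative names and sizes, not TensorFlow code) of how waiting on a single stream can miss pending work:

#include <cuda_runtime.h>

int main() {
  cudaStream_t compute = nullptr, copy = nullptr;
  cudaStreamCreate(&compute);
  cudaStreamCreate(&copy);

  void* buf = nullptr;
  cudaMalloc(&buf, 1 << 20);

  // Work is pending on the copy stream only.
  cudaMemsetAsync(buf, 0, 1 << 20, copy);

  cudaStreamSynchronize(compute);  // does NOT cover the memset above
  cudaStreamSynchronize(copy);     // a complete Sync() must also wait here

  cudaFree(buf);
  cudaStreamDestroy(compute);
  cudaStreamDestroy(copy);
  return 0;
}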
@@ -768,7 +768,6 @@ bool StreamExecutor::AllocateStream(Stream *stream) {
    return false;
  }

  RegisterStream(stream);
  return true;
}
@@ -776,7 +775,6 @@ void StreamExecutor::DeallocateStream(Stream *stream) {
  implementation_->DeallocateStream(stream);
  CHECK_GE(live_stream_count_.fetch_sub(1), 0)
      << "live stream count should not dip below zero";
  UnregisterStream(stream);
}

bool StreamExecutor::CreateStreamDependency(Stream *dependent, Stream *other) {
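
The AllocateStream/DeallocateStream hunks keep a registry of live streams in step with stream allocation: register on successful allocation, unregister on deallocation, with the set guarded by a mutex. A minimal standalone model of that bookkeeping (assumed names, not the StreamExecutor API):

#include <mutex>
#include <set>

// Tracks every live stream so a later "sync all streams" call knows exactly
// which streams belong to this executor.
class StreamRegistry {
 public:
  void Register(void* stream) {
    std::lock_guard<std::mutex> lock(mu_);
    streams_.insert(stream);
  }
  void Unregister(void* stream) {
    std::lock_guard<std::mutex> lock(mu_);
    streams_.erase(stream);
  }

 private:
  std::mutex mu_;
  std::set<void*> streams_;
};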
@@ -528,24 +528,6 @@ class StreamExecutor {
  // allocation.
  StreamExecutorMemoryAllocator *GetAllocator() { return &allocator_; }

  // Block host until all streams associated with this stream executor have
  // finished all of enqueued work.
  port::Status BlockHostUntilAllStreamsAreDone() {
    std::vector<Stream *> streams;
    {
      absl::MutexLock lock(&mu_);
      for (Stream *stream : streams_) {
        streams.push_back(stream);
      }
    }

    for (Stream *stream : streams) {
      TF_RETURN_IF_ERROR(BlockHostUntilDone(stream));
    }

    return port::Status::OK();
  }

 private:
  template <typename BeginCallT, typename CompleteCallT, typename ReturnT,
            typename... BeginArgsT>
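
Note the shape of BlockHostUntilAllStreamsAreDone: it copies the stream set while holding mu_ and only then blocks on each stream, so the mutex is never held across a potentially long host-side wait. A standalone sketch of the same snapshot-then-wait pattern (assumed CUDA types, not the StreamExecutor API), usable with a registry like the one sketched above, here with cudaStream_t handles:

#include <cuda_runtime.h>
#include <mutex>
#include <set>
#include <vector>

// Snapshot the registered streams under the lock, then block on each stream
// with the lock released, so registration of new streams is never stalled
// behind the wait.
cudaError_t BlockUntilAllStreamsAreDone(std::mutex& mu,
                                        const std::set<cudaStream_t>& streams) {
  std::vector<cudaStream_t> snapshot;
  {
    std::lock_guard<std::mutex> lock(mu);
    snapshot.assign(streams.begin(), streams.end());
  }
  for (cudaStream_t s : snapshot) {
    cudaError_t err = cudaStreamSynchronize(s);
    if (err != cudaSuccess) return err;
  }
  return cudaSuccess;
}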
@@ -675,16 +657,6 @@ class StreamExecutor {
  template <typename TraceCallT, typename... ArgsT>
  void SubmitTrace(TraceCallT trace_call, ArgsT &&...args);

  void RegisterStream(Stream *stream) {
    absl::MutexLock lock(&mu_);
    streams_.insert(stream);
  }

  void UnregisterStream(Stream *stream) {
    absl::MutexLock lock(&mu_);
    streams_.erase(stream);
  }

  // Reader/writer lock for class-static StreamExecutor members.
  static absl::Mutex static_mu_;
@@ -775,9 +747,6 @@ class StreamExecutor {

  StreamExecutorMemoryAllocator allocator_;

  // Set of streams associated with this stream executor.
  std::set<Stream *> streams_ TF_GUARDED_BY(mu_);

  SE_DISALLOW_COPY_AND_ASSIGN(StreamExecutor);
};