diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc
index 51797def041..32fce2bf94e 100644
--- a/tensorflow/compiler/jit/xla_device.cc
+++ b/tensorflow/compiler/jit/xla_device.cc
@@ -434,6 +434,16 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
   return status;
 }
 
+void XlaDevice::SetRequiresSyncOnCompletion(bool sync_on_completion) {
+  mutex_lock lock(mu_);
+  sync_on_completion_ = sync_on_completion;
+}
+
+bool XlaDevice::RequiresSyncOnCompletion() const {
+  mutex_lock lock(mu_);
+  return sync_on_completion_;
+}
+
 XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
                                                    const char* jit_device) {
   // Any op assigned to the device that isn't rewritten by the graph rewriter
diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h
index 92891ffa8c6..0f06b3fc80b 100644
--- a/tensorflow/compiler/jit/xla_device.h
+++ b/tensorflow/compiler/jit/xla_device.h
@@ -151,6 +151,12 @@ class XlaDevice : public LocalDevice {
   // information for GPU and TPU devices.
   Status UseGpuDeviceInfo() LOCKS_EXCLUDED(mu_);
 
+  // Instructs this XlaDevice to return 'sync_on_completion' for
+  // RequiresSyncOnCompletion().
+  void SetRequiresSyncOnCompletion(bool sync_on_completion) LOCKS_EXCLUDED(mu_);
+
+  bool RequiresSyncOnCompletion() const override LOCKS_EXCLUDED(mu_);
+
  private:
   xla::LocalClient* client() const;
   Allocator* GetAllocatorLocked(AllocatorAttributes attr)
@@ -165,7 +171,7 @@ class XlaDevice : public LocalDevice {
   static Status GetMetadataFromDevice(DeviceBase* device,
                                       const XlaDevice::Metadata** metadata);
 
-  mutex mu_;
+  mutable mutex mu_;
   // The metadata of this XlaDevice.
   const Metadata xla_metadata_;
   // Which hardware device in the client's platform this XlaDevice controls.
@@ -207,6 +213,10 @@ class XlaDevice : public LocalDevice {
 
   // Thread pool used for running closures
   std::unique_ptr<thread::ThreadPool> thread_pool_;
+
+  // True if the device requires XlaDevice::Sync to be called on completion
+  // regardless of status.
+  bool sync_on_completion_ GUARDED_BY(mu_) = false;
 };
 
 // Builds OpKernel registrations on 'device' for the JIT operators
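[Note, not part of the diff] A minimal sketch of how a caller might drive the
new XlaDevice knob. The helper name ForceSyncOnCompletion and the way the
caller obtains the XlaDevice* are assumptions for illustration only:

  #include "tensorflow/compiler/jit/xla_device.h"
  #include "tensorflow/core/platform/logging.h"

  // Hypothetical helper; the XlaDevice* is assumed to come from elsewhere,
  // e.g. a DeviceMgr lookup.
  void ForceSyncOnCompletion(tensorflow::XlaDevice* device) {
    // Defaults to false: the executor syncs only when the step status is OK.
    CHECK(!device->RequiresSyncOnCompletion());
    // From now on, ExecutorState::Finish() calls device->Sync() even when
    // the step has already failed, draining queued work before the step
    // completes.
    device->SetRequiresSyncOnCompletion(true);
  }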
diff --git a/tensorflow/compiler/xla/service/stream_pool.cc b/tensorflow/compiler/xla/service/stream_pool.cc
index 5d1cd1c4422..ec09dff9244 100644
--- a/tensorflow/compiler/xla/service/stream_pool.cc
+++ b/tensorflow/compiler/xla/service/stream_pool.cc
@@ -28,8 +28,14 @@ StreamPool::Ptr StreamPool::BorrowStream(se::StreamExecutor* executor) {
       // Re-use an existing stream from the pool.
       stream = std::move(streams_.back());
       streams_.pop_back();
-      VLOG(1) << stream->DebugStreamPointers()
-              << " StreamPool reusing existing stream";
+      if (stream->ok()) {
+        VLOG(1) << stream->DebugStreamPointers()
+                << " StreamPool reusing existing stream";
+      } else {
+        VLOG(1) << stream->DebugStreamPointers()
+                << " stream was not ok, StreamPool deleting";
+        stream = nullptr;
+      }
     }
   }
 
diff --git a/tensorflow/compiler/xla/service/stream_pool_test.cc b/tensorflow/compiler/xla/service/stream_pool_test.cc
index aaf5c37b0d2..92f47579d31 100644
--- a/tensorflow/compiler/xla/service/stream_pool_test.cc
+++ b/tensorflow/compiler/xla/service/stream_pool_test.cc
@@ -132,5 +132,39 @@ TEST_F(StreamPoolTest, BadStreamDiscarded) {
   EXPECT_EQ(stream2_ptr, stream3_ptr);
 }
 
+TEST_F(StreamPoolTest, BadStreamAfterReturnDiscarded) {
+  std::unique_ptr<se::StreamExecutor> executor = NewStreamExecutor();
+  StreamPool pool;
+
+  // Borrow a stream.
+  StreamPool::Ptr stream1 = pool.BorrowStream(executor.get());
+  EXPECT_TRUE(stream1->ok());
+
+  // Return the stream, but hold a handle to it.
+  se::Stream* stream1_ptr = stream1.get();
+  stream1 = nullptr;
+
+  // Now that stream1 is back in the pool, force an error on the
+  // stream. Here we call a method that requires DNN support, which we
+  // know the Host platform doesn't support.
+  stream1_ptr->ThenDepthConcatenate({}, {}, nullptr);
+  EXPECT_FALSE(stream1_ptr->ok());
+
+  // Borrow stream2.
+  StreamPool::Ptr stream2 = pool.BorrowStream(executor.get());
+  EXPECT_TRUE(stream2->ok());
+
+  // The underlying streams should be different. They would have been
+  // the same, but since we forced an error on stream1, it cannot be
+  // put back into the pool. Sadly we can't just check:
+  //   EXPECT_NE(stream1_ptr, stream2_ptr);
+  //
+  // The above should hold logically, but it may fail if the new
+  // stream instance allocated for stream2 happens to reside in the
+  // same memory address as stream1, which has been deleted.
+  //
+  // The EXPECT_TRUE(stream2->ok()) above serves as a good-enough check.
+}
+
 }  // namespace
 }  // namespace xla
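[Note, not part of the diff] A sketch of the borrow/return lifecycle the new
ok() check protects. BorrowDemo is a hypothetical helper and `executor` is
assumed to be a live StreamExecutor owned by the caller:

  #include "tensorflow/compiler/xla/service/stream_pool.h"
  #include "tensorflow/core/platform/logging.h"

  void BorrowDemo(stream_executor::StreamExecutor* executor) {
    xla::StreamPool pool;
    {
      xla::StreamPool::Ptr stream = pool.BorrowStream(executor);
      // ... enqueue work on *stream ...
    }  // The Ptr deleter hands the stream back to the pool here.
    // If an error was recorded on the stream while it sat in the pool, the
    // next BorrowStream() now deletes it and allocates a fresh stream
    // instead of handing back a broken one.
    xla::StreamPool::Ptr fresh = pool.BorrowStream(executor);
    CHECK(fresh->ok());
  }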
diff --git a/tensorflow/core/common_runtime/device.h b/tensorflow/core/common_runtime/device.h
index 81d68e3be49..fb76d6ac295 100644
--- a/tensorflow/core/common_runtime/device.h
+++ b/tensorflow/core/common_runtime/device.h
@@ -106,6 +106,10 @@ class Device : public DeviceBase {
   // at completion.
   virtual Status Sync() = 0;
 
+  // Override this to return true for devices that require a Sync() call before
+  // session completion.
+  virtual bool RequiresSyncOnCompletion() const { return false; }
+
   // Optionally modify the device's GraphDef before execution.
   //
   // This method should be considered experimental and is supplied to enable
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index d0a0767d6ba..98719542c00 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -2301,13 +2301,15 @@ void ExecutorState::Finish() {
   auto done_cb = std::move(done_cb_);
   auto runner = std::move(runner_);
   mu_.unlock();
-  if (sync_on_finish_ && status.ok()) {
+  Device* device = impl_->params_.device;
+  if ((sync_on_finish_ && status.ok()) || device->RequiresSyncOnCompletion()) {
     // Block until the device has finished all queued operations. For
     // devices like GPUs that continue to execute Ops after their Compute
     // methods have completed, this ensures that control is not returned to
     // the user until the step (and its side-effects) has actually completed.
-    status = impl_->params_.device->Sync();
+    status.Update(device->Sync());
   }
+
   delete this;
   CHECK(done_cb != nullptr);
   runner([=]() { done_cb(status); });
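[Note, not part of the diff] Status::Update() keeps the first non-OK status,
which is why the forced Sync() above cannot mask an earlier step failure.
A minimal sketch (UpdateDemo is a hypothetical name):

  #include "tensorflow/core/lib/core/errors.h"
  #include "tensorflow/core/lib/core/status.h"
  #include "tensorflow/core/platform/logging.h"

  void UpdateDemo() {
    tensorflow::Status status = tensorflow::errors::Internal("step failed");
    // Update() is a no-op when `status` is already non-OK, so the original
    // step error wins over any error returned by the forced device Sync().
    status.Update(tensorflow::errors::Aborted("sync failed"));
    CHECK_EQ(status.error_message(), "step failed");  // first error preserved
  }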
diff --git a/tensorflow/core/framework/cancellation.cc b/tensorflow/core/framework/cancellation.cc
index 1258e40c934..af59500aee3 100644
--- a/tensorflow/core/framework/cancellation.cc
+++ b/tensorflow/core/framework/cancellation.cc
@@ -89,6 +89,16 @@ bool CancellationManager::DeregisterCallback(CancellationToken token) {
   }
 }
 
+bool CancellationManager::TryDeregisterCallback(CancellationToken token) {
+  mutex_lock lock(mu_);
+  if (is_cancelled_ || is_cancelling_) {
+    return false;
+  } else {
+    callbacks_.erase(token);
+    return true;
+  }
+}
+
 CancellationManager::~CancellationManager() {
   if (!callbacks_.empty()) {
     StartCancel();
diff --git a/tensorflow/core/framework/cancellation.h b/tensorflow/core/framework/cancellation.h
index acdaaf6a901..7a5d9424867 100644
--- a/tensorflow/core/framework/cancellation.h
+++ b/tensorflow/core/framework/cancellation.h
@@ -122,6 +122,15 @@ class CancellationManager {
   // cancellation manager.
   bool DeregisterCallback(CancellationToken token);
 
+  // Deregister the callback that, when registered, was associated
+  // with the given cancellation token. Returns true iff the callback
+  // was deregistered and will not be invoked; otherwise returns false
+  // immediately, with no guarantee that the callback has completed.
+  //
+  // This method is guaranteed to return true if StartCancel has not been
+  // called.
+  bool TryDeregisterCallback(CancellationToken token);
+
  private:
   bool is_cancelling_;
   std::atomic_bool is_cancelled_;
diff --git a/tensorflow/core/framework/cancellation_test.cc b/tensorflow/core/framework/cancellation_test.cc
index e3f18240b58..bf7593bc5f7 100644
--- a/tensorflow/core/framework/cancellation_test.cc
+++ b/tensorflow/core/framework/cancellation_test.cc
@@ -115,4 +115,56 @@ TEST(Cancellation, IsCancelled) {
   delete cm;
 }
 
+TEST(Cancellation, TryDeregisterWithoutCancel) {
+  bool is_cancelled = false;
+  CancellationManager* manager = new CancellationManager();
+  auto token = manager->get_cancellation_token();
+  bool registered = manager->RegisterCallback(
+      token, [&is_cancelled]() { is_cancelled = true; });
+  EXPECT_TRUE(registered);
+  bool deregistered = manager->TryDeregisterCallback(token);
+  EXPECT_TRUE(deregistered);
+  delete manager;
+  EXPECT_FALSE(is_cancelled);
+}
+
+TEST(Cancellation, TryDeregisterAfterCancel) {
+  bool is_cancelled = false;
+  CancellationManager* manager = new CancellationManager();
+  auto token = manager->get_cancellation_token();
+  bool registered = manager->RegisterCallback(
+      token, [&is_cancelled]() { is_cancelled = true; });
+  EXPECT_TRUE(registered);
+  manager->StartCancel();
+  EXPECT_TRUE(is_cancelled);
+  bool deregistered = manager->TryDeregisterCallback(token);
+  EXPECT_FALSE(deregistered);
+  delete manager;
+}
+
+TEST(Cancellation, TryDeregisterDuringCancel) {
+  Notification cancel_started, finish_callback, cancel_complete;
+  CancellationManager* manager = new CancellationManager();
+  auto token = manager->get_cancellation_token();
+  bool registered = manager->RegisterCallback(token, [&]() {
+    cancel_started.Notify();
+    finish_callback.WaitForNotification();
+  });
+  EXPECT_TRUE(registered);
+
+  thread::ThreadPool w(Env::Default(), "test", 1);
+  w.Schedule([&]() {
+    manager->StartCancel();
+    cancel_complete.Notify();
+  });
+  cancel_started.WaitForNotification();
+
+  bool deregistered = manager->TryDeregisterCallback(token);
+  EXPECT_FALSE(deregistered);
+
+  finish_callback.Notify();
+  cancel_complete.WaitForNotification();
+  delete manager;
+}
+
 }  // namespace tensorflow
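[Note, not part of the diff] A sketch of the intended calling pattern for the
new non-blocking deregistration; RunCancellable is a hypothetical helper:

  #include "tensorflow/core/framework/cancellation.h"

  void RunCancellable(tensorflow::CancellationManager* cm) {
    tensorflow::CancellationToken token = cm->get_cancellation_token();
    // RegisterCallback returns false if cancellation has already started.
    if (!cm->RegisterCallback(token, []() { /* abort the work */ })) return;
    // ... perform the cancellable work ...
    if (!cm->TryDeregisterCallback(token)) {
      // Unlike DeregisterCallback, this call never blocks on a concurrently
      // running callback. Cancellation is in progress or finished, and the
      // callback may still be executing, so don't tear down state it can
      // touch until it is known to have completed.
    }
  }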