Internal change.
PiperOrigin-RevId: 213770000
This commit is contained in:
parent
da3357ecbd
commit
a54310b1fa
@ -434,6 +434,16 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
|
||||
return status;
|
||||
}
|
||||
|
||||
void XlaDevice::SetRequiresSyncOnCompletion(bool sync_on_completion) {
|
||||
mutex_lock lock(mu_);
|
||||
sync_on_completion_ = sync_on_completion;
|
||||
}
|
||||
|
||||
bool XlaDevice::RequiresSyncOnCompletion() const {
|
||||
mutex_lock lock(mu_);
|
||||
return sync_on_completion_;
|
||||
}
|
||||
|
||||
XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
|
||||
const char* jit_device) {
|
||||
// Any op assigned to the device that isn't rewritten by the graph rewriter
|
||||
|
@ -151,6 +151,12 @@ class XlaDevice : public LocalDevice {
|
||||
// information for GPU and TPU devices.
|
||||
Status UseGpuDeviceInfo() LOCKS_EXCLUDED(mu_);
|
||||
|
||||
// Instructs this XlaDevice to return 'sync_on_completion' for
|
||||
// RequiresSyncOnCompletion().
|
||||
void SetRequiresSyncOnCompletion(bool sync_on_completion) LOCKS_EXCLUDED(mu_);
|
||||
|
||||
bool RequiresSyncOnCompletion() const override LOCKS_EXCLUDED(mu_);
|
||||
|
||||
private:
|
||||
xla::LocalClient* client() const;
|
||||
Allocator* GetAllocatorLocked(AllocatorAttributes attr)
|
||||
@ -165,7 +171,7 @@ class XlaDevice : public LocalDevice {
|
||||
static Status GetMetadataFromDevice(DeviceBase* device,
|
||||
const XlaDevice::Metadata** metadata);
|
||||
|
||||
mutex mu_;
|
||||
mutable mutex mu_;
|
||||
// The metadata of this XlaDevice.
|
||||
const Metadata xla_metadata_;
|
||||
// Which hardware device in the client's platform this XlaDevice controls.
|
||||
@ -207,6 +213,10 @@ class XlaDevice : public LocalDevice {
|
||||
|
||||
// Thread pool used for running closures
|
||||
std::unique_ptr<thread::ThreadPool> thread_pool_;
|
||||
|
||||
// True if the device requires XlaDevice::Sync to be called on completion
|
||||
// regardless of status.
|
||||
bool sync_on_completion_ GUARDED_BY(mu_) = false;
|
||||
};
|
||||
|
||||
// Builds OpKernel registrations on 'device' for the JIT operators
|
||||
|
@ -28,8 +28,14 @@ StreamPool::Ptr StreamPool::BorrowStream(se::StreamExecutor* executor) {
|
||||
// Re-use an existing stream from the pool.
|
||||
stream = std::move(streams_.back());
|
||||
streams_.pop_back();
|
||||
VLOG(1) << stream->DebugStreamPointers()
|
||||
<< " StreamPool reusing existing stream";
|
||||
if (stream->ok()) {
|
||||
VLOG(1) << stream->DebugStreamPointers()
|
||||
<< " StreamPool reusing existing stream";
|
||||
} else {
|
||||
VLOG(1) << stream->DebugStreamPointers()
|
||||
<< " stream was not ok, StreamPool deleting";
|
||||
stream = nullptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -132,5 +132,39 @@ TEST_F(StreamPoolTest, BadStreamDiscarded) {
|
||||
EXPECT_EQ(stream2_ptr, stream3_ptr);
|
||||
}
|
||||
|
||||
TEST_F(StreamPoolTest, BadStreamAfterReturnDiscarded) {
|
||||
std::unique_ptr<se::StreamExecutor> executor = NewStreamExecutor();
|
||||
StreamPool pool;
|
||||
|
||||
// Borrow a stream.
|
||||
StreamPool::Ptr stream1 = pool.BorrowStream(executor.get());
|
||||
EXPECT_TRUE(stream1->ok());
|
||||
|
||||
// Return the stream, but hold a handle to it.
|
||||
se::Stream* stream1_ptr = stream1.get();
|
||||
stream1 = nullptr;
|
||||
|
||||
// Now stream1 is back in the pool, force an error on the stream. Here we call
|
||||
// a method that requires DNN support, which we know the Host platform doesn't
|
||||
// support.
|
||||
stream1_ptr->ThenDepthConcatenate({}, {}, nullptr);
|
||||
EXPECT_FALSE(stream1_ptr->ok());
|
||||
|
||||
// Borrow stream2.
|
||||
StreamPool::Ptr stream2 = pool.BorrowStream(executor.get());
|
||||
EXPECT_TRUE(stream2->ok());
|
||||
|
||||
// The underlying streams should be different. They would have been
|
||||
// the same, but since we forced an error on stream1, it cannot be
|
||||
// put back into the pool. Sadly we can't just check:
|
||||
// EXPECT_NE(stream1_ptr, stream2_ptr);
|
||||
//
|
||||
// The above should hold logically, but it may fail if the new
|
||||
// stream instance allocated for stream2 happens to reside in the
|
||||
// same memory address as stream1, which has been deleted.
|
||||
//
|
||||
// The check that stream2->ok() serves as a good-enough check.
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
@ -106,6 +106,10 @@ class Device : public DeviceBase {
|
||||
// at completion.
|
||||
virtual Status Sync() = 0;
|
||||
|
||||
// Override this to return true for devices that require a Sync() call before
|
||||
// session completion.
|
||||
virtual bool RequiresSyncOnCompletion() const { return false; }
|
||||
|
||||
// Optionally modify the device's GraphDef before execution.
|
||||
//
|
||||
// This method should be considered experimental and is supplied to enable
|
||||
|
@ -2301,13 +2301,15 @@ void ExecutorState::Finish() {
|
||||
auto done_cb = std::move(done_cb_);
|
||||
auto runner = std::move(runner_);
|
||||
mu_.unlock();
|
||||
if (sync_on_finish_ && status.ok()) {
|
||||
Device* device = impl_->params_.device;
|
||||
if ((sync_on_finish_ && status.ok()) || device->RequiresSyncOnCompletion()) {
|
||||
// Block until the device has finished all queued operations. For
|
||||
// devices like GPUs that continue to execute Ops after their Compute
|
||||
// methods have completed, this ensures that control is not returned to
|
||||
// the user until the step (and its side-effects) has actually completed.
|
||||
status = impl_->params_.device->Sync();
|
||||
status.Update(device->Sync());
|
||||
}
|
||||
|
||||
delete this;
|
||||
CHECK(done_cb != nullptr);
|
||||
runner([=]() { done_cb(status); });
|
||||
|
@ -89,6 +89,16 @@ bool CancellationManager::DeregisterCallback(CancellationToken token) {
|
||||
}
|
||||
}
|
||||
|
||||
bool CancellationManager::TryDeregisterCallback(CancellationToken token) {
|
||||
mutex_lock lock(mu_);
|
||||
if (is_cancelled_ || is_cancelling_) {
|
||||
return false;
|
||||
} else {
|
||||
callbacks_.erase(token);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
CancellationManager::~CancellationManager() {
|
||||
if (!callbacks_.empty()) {
|
||||
StartCancel();
|
||||
|
@ -122,6 +122,15 @@ class CancellationManager {
|
||||
// cancellation manager.
|
||||
bool DeregisterCallback(CancellationToken token);
|
||||
|
||||
// Deregister the callback that, when registered, was associated
|
||||
// with the given cancellation token. Returns true iff the callback
|
||||
// was deregistered and will not be invoked; otherwise returns false
|
||||
// immediately, with no guarantee that the callback has completed.
|
||||
//
|
||||
// This method is guaranteed to return true if StartCancel has not been
|
||||
// called.
|
||||
bool TryDeregisterCallback(CancellationToken token);
|
||||
|
||||
private:
|
||||
bool is_cancelling_;
|
||||
std::atomic_bool is_cancelled_;
|
||||
|
@ -115,4 +115,56 @@ TEST(Cancellation, IsCancelled) {
|
||||
delete cm;
|
||||
}
|
||||
|
||||
TEST(Cancellation, TryDeregisterWithoutCancel) {
|
||||
bool is_cancelled = false;
|
||||
CancellationManager* manager = new CancellationManager();
|
||||
auto token = manager->get_cancellation_token();
|
||||
bool registered = manager->RegisterCallback(
|
||||
token, [&is_cancelled]() { is_cancelled = true; });
|
||||
EXPECT_TRUE(registered);
|
||||
bool deregistered = manager->TryDeregisterCallback(token);
|
||||
EXPECT_TRUE(deregistered);
|
||||
delete manager;
|
||||
EXPECT_FALSE(is_cancelled);
|
||||
}
|
||||
|
||||
TEST(Cancellation, TryDeregisterAfterCancel) {
|
||||
bool is_cancelled = false;
|
||||
CancellationManager* manager = new CancellationManager();
|
||||
auto token = manager->get_cancellation_token();
|
||||
bool registered = manager->RegisterCallback(
|
||||
token, [&is_cancelled]() { is_cancelled = true; });
|
||||
EXPECT_TRUE(registered);
|
||||
manager->StartCancel();
|
||||
EXPECT_TRUE(is_cancelled);
|
||||
bool deregistered = manager->TryDeregisterCallback(token);
|
||||
EXPECT_FALSE(deregistered);
|
||||
delete manager;
|
||||
}
|
||||
|
||||
TEST(Cancellation, TryDeregisterDuringCancel) {
|
||||
Notification cancel_started, finish_callback, cancel_complete;
|
||||
CancellationManager* manager = new CancellationManager();
|
||||
auto token = manager->get_cancellation_token();
|
||||
bool registered = manager->RegisterCallback(token, [&]() {
|
||||
cancel_started.Notify();
|
||||
finish_callback.WaitForNotification();
|
||||
});
|
||||
EXPECT_TRUE(registered);
|
||||
|
||||
thread::ThreadPool w(Env::Default(), "test", 1);
|
||||
w.Schedule([&]() {
|
||||
manager->StartCancel();
|
||||
cancel_complete.Notify();
|
||||
});
|
||||
cancel_started.WaitForNotification();
|
||||
|
||||
bool deregistered = manager->TryDeregisterCallback(token);
|
||||
EXPECT_FALSE(deregistered);
|
||||
|
||||
finish_callback.Notify();
|
||||
cancel_complete.WaitForNotification();
|
||||
delete manager;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
Loading…
x
Reference in New Issue
Block a user