Internal change.

PiperOrigin-RevId: 213770000
A. Unique TensorFlower 2018-09-20 01:43:05 -07:00 committed by TensorFlower Gardener
parent da3357ecbd
commit a54310b1fa
9 changed files with 142 additions and 5 deletions

tensorflow/compiler/jit/xla_device.cc

@@ -434,6 +434,16 @@ Status XlaDevice::MakeTensorFromProto(const TensorProto& tensor_proto,
   return status;
 }
 
+void XlaDevice::SetRequiresSyncOnCompletion(bool sync_on_completion) {
+  mutex_lock lock(mu_);
+  sync_on_completion_ = sync_on_completion;
+}
+
+bool XlaDevice::RequiresSyncOnCompletion() const {
+  mutex_lock lock(mu_);
+  return sync_on_completion_;
+}
+
 XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device,
                                                    const char* jit_device) {
   // Any op assigned to the device that isn't rewritten by the graph rewriter

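A minimal usage sketch of the new setter (hedged; xla_device is a hypothetical XlaDevice* obtained elsewhere, not part of this change):

// Ask the device to sync at step completion even when the step failed,
// so queued device work cannot outlive an error.
xla_device->SetRequiresSyncOnCompletion(true);
CHECK(xla_device->RequiresSyncOnCompletion());
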
tensorflow/compiler/jit/xla_device.h

@@ -151,6 +151,12 @@ class XlaDevice : public LocalDevice {
   // information for GPU and TPU devices.
   Status UseGpuDeviceInfo() LOCKS_EXCLUDED(mu_);
 
+  // Instructs this XlaDevice to return 'sync_on_completion' for
+  // RequiresSyncOnCompletion().
+  void SetRequiresSyncOnCompletion(bool sync_on_completion) LOCKS_EXCLUDED(mu_);
+
+  bool RequiresSyncOnCompletion() const override LOCKS_EXCLUDED(mu_);
+
  private:
   xla::LocalClient* client() const;
   Allocator* GetAllocatorLocked(AllocatorAttributes attr)
@@ -165,7 +171,7 @@ class XlaDevice : public LocalDevice {
   static Status GetMetadataFromDevice(DeviceBase* device,
                                       const XlaDevice::Metadata** metadata);
 
-  mutex mu_;
+  mutable mutex mu_;
   // The metadata of this XlaDevice.
   const Metadata xla_metadata_;
   // Which hardware device in the client's platform this XlaDevice controls.
@@ -207,6 +213,10 @@ class XlaDevice : public LocalDevice {
   // Thread pool used for running closures
   std::unique_ptr<thread::ThreadPool> thread_pool_;
 
+  // True if the device requires XlaDevice::Sync to be called on completion
+  // regardless of status.
+  bool sync_on_completion_ GUARDED_BY(mu_) = false;
+
 };
 
 // Builds OpKernel registrations on 'device' for the JIT operators

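mu_ becomes mutable because RequiresSyncOnCompletion() is const yet must take the lock. A standalone sketch of the same pattern, using plain standard C++ rather than TensorFlow's mutex wrapper:

#include <mutex>

class Flag {
 public:
  void Set(bool v) {
    std::lock_guard<std::mutex> lock(mu_);
    value_ = v;
  }
  bool Get() const {
    // Locking mutates the mutex, so mu_ must be declared mutable
    // for this const method to compile.
    std::lock_guard<std::mutex> lock(mu_);
    return value_;
  }

 private:
  mutable std::mutex mu_;  // mutable: lockable from const members
  bool value_ = false;
};
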
tensorflow/compiler/xla/service/stream_pool.cc

@@ -28,8 +28,14 @@ StreamPool::Ptr StreamPool::BorrowStream(se::StreamExecutor* executor) {
       // Re-use an existing stream from the pool.
       stream = std::move(streams_.back());
       streams_.pop_back();
-      VLOG(1) << stream->DebugStreamPointers()
-              << " StreamPool reusing existing stream";
+      if (stream->ok()) {
+        VLOG(1) << stream->DebugStreamPointers()
+                << " StreamPool reusing existing stream";
+      } else {
+        VLOG(1) << stream->DebugStreamPointers()
+                << " stream was not ok, StreamPool deleting";
+        stream = nullptr;
+      }
     }
   }

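The effect, sketched as hypothetical caller code (executor stands in for a real se::StreamExecutor*; behavior after the discard assumes BorrowStream's fall-through to fresh-stream creation):

StreamPool pool;
{
  StreamPool::Ptr s = pool.BorrowStream(executor);
  // ... enqueue work on *s; an error flips s->ok() to false ...
}  // Ptr's deleter returns s to the pool.
// If s went bad only after it was returned, this borrow now deletes the
// stale stream (the new branch above) instead of handing it out, and a
// fresh stream is created in its place.
StreamPool::Ptr s2 = pool.BorrowStream(executor);
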
tensorflow/compiler/xla/service/stream_pool_test.cc

@@ -132,5 +132,39 @@ TEST_F(StreamPoolTest, BadStreamDiscarded) {
   EXPECT_EQ(stream2_ptr, stream3_ptr);
 }
 
+TEST_F(StreamPoolTest, BadStreamAfterReturnDiscarded) {
+  std::unique_ptr<se::StreamExecutor> executor = NewStreamExecutor();
+  StreamPool pool;
+
+  // Borrow a stream.
+  StreamPool::Ptr stream1 = pool.BorrowStream(executor.get());
+  EXPECT_TRUE(stream1->ok());
+
+  // Return the stream, but hold a handle to it.
+  se::Stream* stream1_ptr = stream1.get();
+  stream1 = nullptr;
+
+  // Now that stream1 is back in the pool, force an error on the stream.
+  // Here we call a method that requires DNN support, which we know the
+  // Host platform doesn't support.
+  stream1_ptr->ThenDepthConcatenate({}, {}, nullptr);
+  EXPECT_FALSE(stream1_ptr->ok());
+
+  // Borrow stream2.
+  StreamPool::Ptr stream2 = pool.BorrowStream(executor.get());
+  EXPECT_TRUE(stream2->ok());
+
+  // The underlying streams should be different. They would have been
+  // the same, but since we forced an error on stream1, it cannot be
+  // put back into the pool. Sadly we can't just check:
+  //
+  //   EXPECT_NE(stream1_ptr, stream2.get());
+  //
+  // The above should hold logically, but it may fail if the new
+  // stream instance allocated for stream2 happens to reside at the
+  // same memory address as stream1, which has been deleted.
+  //
+  // The EXPECT_TRUE(stream2->ok()) above is a good-enough check.
+}
+
 }  // namespace
 }  // namespace xla

tensorflow/core/common_runtime/device.h

@@ -106,6 +106,10 @@ class Device : public DeviceBase {
   // at completion.
   virtual Status Sync() = 0;
 
+  // Override this to return true for devices that require a Sync() call before
+  // session completion.
+  virtual bool RequiresSyncOnCompletion() const { return false; }
+
   // Optionally modify the device's GraphDef before execution.
   //
   // This method should be considered experimental and is supplied to enable

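A hedged sketch of a subclass opting in (MyAsyncDevice and its Sync body are illustrative, not part of this change):

class MyAsyncDevice : public Device {
 public:
  using Device::Device;
  Status Sync() override {
    // Wait here for all queued device work to drain.
    return Status::OK();
  }
  // Force the executor to Sync() at step completion even if the step
  // already carries a non-OK status.
  bool RequiresSyncOnCompletion() const override { return true; }
};
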
tensorflow/core/common_runtime/executor.cc

@@ -2301,13 +2301,15 @@ void ExecutorState::Finish() {
   auto done_cb = std::move(done_cb_);
   auto runner = std::move(runner_);
   mu_.unlock();
-  if (sync_on_finish_ && status.ok()) {
+  Device* device = impl_->params_.device;
+  if ((sync_on_finish_ && status.ok()) || device->RequiresSyncOnCompletion()) {
     // Block until the device has finished all queued operations. For
     // devices like GPUs that continue to execute Ops after their Compute
     // methods have completed, this ensures that control is not returned to
     // the user until the step (and its side-effects) has actually completed.
-    status = impl_->params_.device->Sync();
+    status.Update(device->Sync());
   }
   delete this;
   CHECK(done_cb != nullptr);
   runner([=]() { done_cb(status); });

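status.Update(device->Sync()) replaces the plain assignment so a pre-existing step error is not clobbered by a successful sync: Status::Update adopts the argument only when the current status is still OK. A small illustration (error messages hypothetical):

Status status = errors::Internal("step failed");  // pre-existing error
status.Update(Status::OK());       // no-op: the first error is preserved
Status s2 = Status::OK();
s2.Update(errors::Cancelled("sync failed"));  // s2 adopts the new error
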
tensorflow/core/framework/cancellation.cc

@@ -89,6 +89,16 @@ bool CancellationManager::DeregisterCallback(CancellationToken token) {
   }
 }
 
+bool CancellationManager::TryDeregisterCallback(CancellationToken token) {
+  mutex_lock lock(mu_);
+  if (is_cancelled_ || is_cancelling_) {
+    return false;
+  } else {
+    callbacks_.erase(token);
+    return true;
+  }
+}
+
 CancellationManager::~CancellationManager() {
   if (!callbacks_.empty()) {
     StartCancel();

tensorflow/core/framework/cancellation.h

@@ -122,6 +122,15 @@ class CancellationManager {
   // cancellation manager.
   bool DeregisterCallback(CancellationToken token);
 
+  // Deregister the callback that, when registered, was associated
+  // with the given cancellation token. Returns true iff the callback
+  // was deregistered and will not be invoked; otherwise returns false
+  // immediately, with no guarantee that the callback has completed.
+  //
+  // This method is guaranteed to return true if StartCancel has not been
+  // called.
+  bool TryDeregisterCallback(CancellationToken token);
+
  private:
   bool is_cancelling_;
   std::atomic_bool is_cancelled_;

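A hedged sketch of the non-blocking pattern this enables (MyKernel::Done, cm, and token are illustrative names, not part of this change):

void MyKernel::Done(CancellationManager* cm, CancellationToken token) {
  if (!cm->TryDeregisterCallback(token)) {
    // Cancellation won the race; the callback may still be running, so
    // any state it touches must not be freed here. The blocking
    // DeregisterCallback() would instead wait for it to finish.
  }
}
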
tensorflow/core/framework/cancellation_test.cc

@@ -115,4 +115,56 @@ TEST(Cancellation, IsCancelled) {
   delete cm;
 }
 
+TEST(Cancellation, TryDeregisterWithoutCancel) {
+  bool is_cancelled = false;
+  CancellationManager* manager = new CancellationManager();
+  auto token = manager->get_cancellation_token();
+  bool registered = manager->RegisterCallback(
+      token, [&is_cancelled]() { is_cancelled = true; });
+  EXPECT_TRUE(registered);
+  bool deregistered = manager->TryDeregisterCallback(token);
+  EXPECT_TRUE(deregistered);
+  delete manager;
+  EXPECT_FALSE(is_cancelled);
+}
+
+TEST(Cancellation, TryDeregisterAfterCancel) {
+  bool is_cancelled = false;
+  CancellationManager* manager = new CancellationManager();
+  auto token = manager->get_cancellation_token();
+  bool registered = manager->RegisterCallback(
+      token, [&is_cancelled]() { is_cancelled = true; });
+  EXPECT_TRUE(registered);
+  manager->StartCancel();
+  EXPECT_TRUE(is_cancelled);
+  bool deregistered = manager->TryDeregisterCallback(token);
+  EXPECT_FALSE(deregistered);
+  delete manager;
+}
+
+TEST(Cancellation, TryDeregisterDuringCancel) {
+  Notification cancel_started, finish_callback, cancel_complete;
+  CancellationManager* manager = new CancellationManager();
+  auto token = manager->get_cancellation_token();
+  bool registered = manager->RegisterCallback(token, [&]() {
+    cancel_started.Notify();
+    finish_callback.WaitForNotification();
+  });
+  EXPECT_TRUE(registered);
+
+  thread::ThreadPool w(Env::Default(), "test", 1);
+  w.Schedule([&]() {
+    manager->StartCancel();
+    cancel_complete.Notify();
+  });
+  cancel_started.WaitForNotification();
+
+  bool deregistered = manager->TryDeregisterCallback(token);
+  EXPECT_FALSE(deregistered);
+
+  finish_callback.Notify();
+  cancel_complete.WaitForNotification();
+  delete manager;
+}
+
 }  // namespace tensorflow