[tf.data service] Support cancellation in DataServiceDatasetOp.

PiperOrigin-RevId: 324078337
Change-Id: Ie4bde096e2e43ba812c7ab8485270fd8e4d47126
This commit is contained in:
Andrew Audibert 2020-07-30 14:04:25 -07:00 committed by TensorFlower Gardener
parent a647794f9f
commit a9740221ea
2 changed files with 35 additions and 3 deletions

View File

@ -185,20 +185,28 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
}
// Cancels background threads and deregisters the cancellation callback.
// NOTE: mu_ must NOT be held here — CancelThreads() is annotated
// TF_LOCKS_EXCLUDED(mu_) and acquires the (non-recursive) mutex itself;
// taking mu_ first would self-deadlock.
~Iterator() override {
VLOG(1) << "Destroying data service dataset iterator for job id "
<< job_id_;
CancelThreads();
// Undo RegisterCancellationCallback() from Initialize(); guard against an
// iterator destroyed before Initialize() ever ran.
if (deregister_fn_) deregister_fn_();
// Thread destructors will block until the threads finish, no need to wait
// here.
}
// Marks the iterator cancelled and wakes every waiter (worker threads, the
// manager thread, and callers blocked in GetNext) so each can observe
// cancelled_ and exit. Must be called WITHOUT mu_ held — it acquires the
// lock itself, per the TF_LOCKS_EXCLUDED annotation.
void CancelThreads() TF_LOCKS_EXCLUDED(mu_) {
mutex_lock l(mu_);
cancelled_ = true;
// Notify all three condition variables; each guards a different waiter set.
worker_thread_cv_.notify_all();
manager_thread_cv_.notify_all();
get_next_cv_.notify_all();
// Thread destructors will block until the threads finish, no need to wait
// here.
}
Status Initialize(IteratorContext* ctx) override {
VLOG(3) << "Connecting to " << dataset()->address_
<< " in data service dataset op";
TF_RETURN_IF_ERROR(RegisterCancellationCallback(
ctx->cancellation_manager(), [this]() { CancelThreads(); },
&deregister_fn_));
DataServiceDispatcherClient dispatcher(dataset()->address_,
dataset()->protocol_);
if (dataset()->job_name_.empty()) {
@ -531,6 +539,8 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
condition_variable worker_thread_cv_ TF_GUARDED_BY(mu_);
condition_variable manager_thread_cv_ TF_GUARDED_BY(mu_);
bool cancelled_ TF_GUARDED_BY(mu_) = false;
// Method for deregistering the cancellation callback.
std::function<void()> deregister_fn_;
int64 outstanding_requests_ TF_GUARDED_BY(mu_) = 0;
// max_outstanding_requests controls how many elements may be held in memory

View File

@ -526,6 +526,28 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.assertRaisesRegex(errors.NotFoundError, "Dataset id"):
self.evaluate(self.getNext(from_dataset_id_ds)())
@combinations.generate(test_base.default_test_combinations())
def testCancellation(self):
  """Destroying the iterator must cancel an in-flight prefetch request."""
  self.skipTest("b/162521601")
  delay_us = 1000 * int(1e6)
  self._dispatcher = server_lib.DispatchServer(port=0, protocol=PROTOCOL)
  self._worker = server_lib.WorkerServer(
      port=0, dispatcher_address=self._dispatcher._address, protocol=PROTOCOL)
  # Build a dataset whose first element arrives quickly and whose second is
  # very slow. Reading the first element starts a prefetch of the slow one;
  # that in-flight request is what cancellation must be able to interrupt.
  delayed = dataset_ops.Dataset.range(1)
  delayed = delayed.apply(testing.sleep(delay_us))
  dataset = dataset_ops.Dataset.range(1).concatenate(delayed)
  dataset = _make_distributed_dataset(
      dataset, "{}://{}".format(PROTOCOL, self._dispatcher._address))
  dataset = dataset.prefetch(1)
  get_next = self.getNext(dataset, requires_initialization=True)
  self.assertEqual(0, self.evaluate(get_next()))
  # Without properly implemented cancellation, we will hang here while trying
  # to garbage collect the dataset iterator.
@combinations.generate(test_base.eager_only_combinations())
def testRegisterEquivalentDatasets(self):
ds_1 = dataset_ops.Dataset.range(10)