[tf.data service] Support cancellation in DataServiceDatasetOp.
PiperOrigin-RevId: 324078337 Change-Id: Ie4bde096e2e43ba812c7ab8485270fd8e4d47126
This commit is contained in:
parent
a647794f9f
commit
a9740221ea
@ -185,20 +185,28 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
|
||||
}
|
||||
|
||||
~Iterator() override {
|
||||
mutex_lock l(mu_);
|
||||
VLOG(1) << "Destroying data service dataset iterator for job id "
|
||||
<< job_id_;
|
||||
CancelThreads();
|
||||
if (deregister_fn_) deregister_fn_();
|
||||
// Thread destructors will block until the threads finish, no need to wait
|
||||
// here.
|
||||
}
|
||||
|
||||
void CancelThreads() TF_LOCKS_EXCLUDED(mu_) {
|
||||
mutex_lock l(mu_);
|
||||
cancelled_ = true;
|
||||
worker_thread_cv_.notify_all();
|
||||
manager_thread_cv_.notify_all();
|
||||
get_next_cv_.notify_all();
|
||||
// Thread destructors will block until the threads finish, no need to wait
|
||||
// here.
|
||||
}
|
||||
|
||||
Status Initialize(IteratorContext* ctx) override {
|
||||
VLOG(3) << "Connecting to " << dataset()->address_
|
||||
<< " in data service dataset op";
|
||||
TF_RETURN_IF_ERROR(RegisterCancellationCallback(
|
||||
ctx->cancellation_manager(), [this]() { CancelThreads(); },
|
||||
&deregister_fn_));
|
||||
DataServiceDispatcherClient dispatcher(dataset()->address_,
|
||||
dataset()->protocol_);
|
||||
if (dataset()->job_name_.empty()) {
|
||||
@ -531,6 +539,8 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
|
||||
condition_variable worker_thread_cv_ TF_GUARDED_BY(mu_);
|
||||
condition_variable manager_thread_cv_ TF_GUARDED_BY(mu_);
|
||||
bool cancelled_ TF_GUARDED_BY(mu_) = false;
|
||||
// Method for deregistering the cancellation callback.
|
||||
std::function<void()> deregister_fn_;
|
||||
|
||||
int64 outstanding_requests_ TF_GUARDED_BY(mu_) = 0;
|
||||
// max_outstanding_requests controls how many elements may be held in memory
|
||||
|
||||
@ -526,6 +526,28 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
with self.assertRaisesRegex(errors.NotFoundError, "Dataset id"):
|
||||
self.evaluate(self.getNext(from_dataset_id_ds)())
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCancellation(self):
|
||||
self.skipTest("b/162521601")
|
||||
sleep_microseconds = int(1e6) * 1000
|
||||
|
||||
self._dispatcher = server_lib.DispatchServer(port=0, protocol=PROTOCOL)
|
||||
self._worker = server_lib.WorkerServer(
|
||||
port=0, dispatcher_address=self._dispatcher._address, protocol=PROTOCOL)
|
||||
# Create a dataset which produces the first element quickly, and the second
|
||||
# element slowly. Fetching the first element triggers prefetching of the
|
||||
# second element, which we should be able to cancel.
|
||||
slow = dataset_ops.Dataset.range(1)
|
||||
slow = slow.apply(testing.sleep(sleep_microseconds))
|
||||
ds = dataset_ops.Dataset.range(1).concatenate(slow)
|
||||
ds = _make_distributed_dataset(
|
||||
ds, "{}://{}".format(PROTOCOL, self._dispatcher._address))
|
||||
ds = ds.prefetch(1)
|
||||
get_next = self.getNext(ds, requires_initialization=True)
|
||||
self.assertEqual(0, self.evaluate(get_next()))
|
||||
# Without properly implemented cancellation, we will hang here while trying
|
||||
# to garbage collect the dataset iterator.
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testRegisterEquivalentDatasets(self):
|
||||
ds_1 = dataset_ops.Dataset.range(10)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user