[tf.data] This CL makes several cancellation improvements:
- it triggers cancellation upon iterator resource deletion - it adds cancellation support for the sleep transformation - it fixes cancellation support for the prefetch transformation PiperOrigin-RevId: 277825820 Change-Id: I9a81cff872388209bbf9e4a1ae86ccfa14976799
This commit is contained in:
parent
e8d9db593e
commit
6b66d924e8
@ -693,7 +693,12 @@ class DatasetBase : public core::RefCounted {
|
||||
(*iterator)->AddCleanupFunction(
|
||||
[model, prefix]() { model->RemoveNode(prefix); });
|
||||
}
|
||||
return (*iterator)->Initialize(ctx);
|
||||
Status s = (*iterator)->Initialize(ctx);
|
||||
if (!s.ok()) {
|
||||
// Reset the iterator to avoid returning an uninitialized iterator.
|
||||
iterator->reset();
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
||||
Status MakeIterator(IteratorContext&& ctx, const string& output_prefix,
|
||||
|
||||
@ -653,8 +653,9 @@ Status InstantiatedCapturedFunction::Run(IteratorContext* ctx,
|
||||
CancellationManager cancellation_manager;
|
||||
f_opts.cancellation_manager = &cancellation_manager;
|
||||
std::function<void()> deregister_fn;
|
||||
TF_RETURN_IF_ERROR(ConnectCancellationManagers(
|
||||
cancellation_manager_, &cancellation_manager, &deregister_fn));
|
||||
TF_RETURN_IF_ERROR(RegisterCancellationCallback(
|
||||
cancellation_manager_,
|
||||
[cm = &cancellation_manager]() { cm->StartCancel(); }, &deregister_fn));
|
||||
auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
|
||||
|
||||
OwnedArgsCallFrame frame(std::move(args), &captured_func_->captured_inputs(),
|
||||
@ -689,8 +690,9 @@ Status InstantiatedCapturedFunction::RunWithBorrowedArgs(
|
||||
CancellationManager cancellation_manager;
|
||||
f_opts.cancellation_manager = &cancellation_manager;
|
||||
std::function<void()> deregister_fn;
|
||||
TF_RETURN_IF_ERROR(ConnectCancellationManagers(
|
||||
cancellation_manager_, &cancellation_manager, &deregister_fn));
|
||||
TF_RETURN_IF_ERROR(RegisterCancellationCallback(
|
||||
cancellation_manager_,
|
||||
[cm = &cancellation_manager]() { cm->StartCancel(); }, &deregister_fn));
|
||||
auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
|
||||
|
||||
BorrowedArgsCallFrame frame(args, &captured_func_->captured_inputs(),
|
||||
@ -725,8 +727,9 @@ Status InstantiatedCapturedFunction::RunInstantiated(
|
||||
CancellationManager cancellation_manager;
|
||||
f_opts.cancellation_manager = &cancellation_manager;
|
||||
std::function<void()> deregister_fn;
|
||||
TF_RETURN_IF_ERROR(ConnectCancellationManagers(
|
||||
cancellation_manager_, &cancellation_manager, &deregister_fn));
|
||||
TF_RETURN_IF_ERROR(RegisterCancellationCallback(
|
||||
cancellation_manager_,
|
||||
[cm = &cancellation_manager]() { cm->StartCancel(); }, &deregister_fn));
|
||||
auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
|
||||
|
||||
BorrowedArgsCallFrame frame(args, &captured_func_->captured_inputs(),
|
||||
@ -776,8 +779,10 @@ void InstantiatedCapturedFunction::RunAsync(
|
||||
auto cancellation_manager = absl::make_unique<CancellationManager>();
|
||||
f_opts.cancellation_manager = cancellation_manager.get();
|
||||
std::function<void()> deregister_fn;
|
||||
Status s = ConnectCancellationManagers(
|
||||
ctx->cancellation_manager(), cancellation_manager.get(), &deregister_fn);
|
||||
Status s = RegisterCancellationCallback(
|
||||
cancellation_manager_,
|
||||
[cm = cancellation_manager.get()]() { cm->StartCancel(); },
|
||||
&deregister_fn);
|
||||
if (!s.ok()) {
|
||||
done(s);
|
||||
return;
|
||||
|
||||
@ -319,18 +319,21 @@ Status AsGraphDef(OpKernelContext* ctx, const DatasetBase* dataset,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ConnectCancellationManagers(CancellationManager* parent,
|
||||
CancellationManager* child,
|
||||
std::function<void()>* deregister_fn) {
|
||||
if (parent) {
|
||||
CancellationToken token = parent->get_cancellation_token();
|
||||
if (!parent->RegisterCallback(token, [child]() { child->StartCancel(); })) {
|
||||
Status RegisterCancellationCallback(CancellationManager* cancellation_manager,
|
||||
std::function<void()> register_fn,
|
||||
std::function<void()>* deregister_fn) {
|
||||
if (cancellation_manager) {
|
||||
CancellationToken token = cancellation_manager->get_cancellation_token();
|
||||
if (!cancellation_manager->RegisterCallback(token,
|
||||
std::move(register_fn))) {
|
||||
return errors::Cancelled("Operation was cancelled");
|
||||
}
|
||||
*deregister_fn = [parent, token]() { parent->DeregisterCallback(token); };
|
||||
*deregister_fn = [cancellation_manager, token]() {
|
||||
cancellation_manager->DeregisterCallback(token);
|
||||
};
|
||||
} else {
|
||||
VLOG(1) << "Parent cancellation manager is not set. Cancellation will "
|
||||
"not be propagated to the child cancellation manager.";
|
||||
VLOG(1) << "Cancellation manager is not set. Cancellation callback will "
|
||||
"not be registered.";
|
||||
*deregister_fn = []() {};
|
||||
}
|
||||
return Status::OK();
|
||||
|
||||
@ -119,12 +119,11 @@ Status AsGraphDef(OpKernelContext* ctx, const DatasetBase* dataset,
|
||||
SerializationContext&& serialization_ctx,
|
||||
GraphDef* graph_def);
|
||||
|
||||
// Creates a connection between "child" and "parent" cancellation managers so
|
||||
// that parent cancellations are propagated to the child, returning a function
|
||||
// that can be used to remove the connection.
|
||||
Status ConnectCancellationManagers(CancellationManager* parent,
|
||||
CancellationManager* child,
|
||||
std::function<void()>* deregister_fn);
|
||||
// Registers the given cancellation callback, returning a function that can be
|
||||
// used to deregister the callback.
|
||||
Status RegisterCancellationCallback(CancellationManager* cancellation_manager,
|
||||
std::function<void()> register_fn,
|
||||
std::function<void()>* deregister_fn);
|
||||
|
||||
// Returns Status::OK() if `expected` and `received` types match,
|
||||
// errors::InvalidArgument otherwise.
|
||||
|
||||
@ -386,6 +386,7 @@ tf_kernel_library(
|
||||
deps = [
|
||||
"//tensorflow/core:experimental_dataset_ops_op_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core/kernels/data:dataset_utils",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@ -333,7 +333,7 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
|
||||
LOCKS_EXCLUDED(*mu_) {
|
||||
// Get the next input element.
|
||||
std::vector<Tensor> input_element;
|
||||
bool end_of_input;
|
||||
bool end_of_input = false;
|
||||
Status status =
|
||||
input_impl_->GetNext(ctx.get(), &input_element, &end_of_input);
|
||||
bool return_early;
|
||||
|
||||
@ -15,6 +15,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/kernels/data/dataset_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
@ -96,16 +97,40 @@ class SleepDatasetOp : public UnaryDatasetOpKernel {
|
||||
explicit Iterator(const Params& params)
|
||||
: DatasetIterator<Dataset>(params) {}
|
||||
|
||||
~Iterator() override {
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
cancelled_ = true;
|
||||
}
|
||||
if (deregister_fn_) {
|
||||
deregister_fn_();
|
||||
}
|
||||
}
|
||||
|
||||
Status Initialize(IteratorContext* ctx) override {
|
||||
TF_RETURN_IF_ERROR(RegisterCancellationCallback(
|
||||
ctx->cancellation_manager(),
|
||||
[this]() {
|
||||
mutex_lock l(mu_);
|
||||
cancelled_ = true;
|
||||
},
|
||||
&deregister_fn_));
|
||||
return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
|
||||
}
|
||||
|
||||
Status GetNextInternal(IteratorContext* ctx,
|
||||
std::vector<Tensor>* out_tensors,
|
||||
bool* end_of_sequence) override {
|
||||
mutex_lock l(mu_);
|
||||
RecordStop(ctx);
|
||||
ctx->env()->SleepForMicroseconds(dataset()->sleep_microseconds_);
|
||||
bool cancelled = mu_.AwaitWithDeadline(
|
||||
Condition(&cancelled_),
|
||||
ctx->env()->NowNanos() +
|
||||
dataset()->sleep_microseconds_ * EnvTime::kMicrosToNanos);
|
||||
RecordStart(ctx);
|
||||
if (cancelled) {
|
||||
return errors::Cancelled("Operation was cancelled");
|
||||
}
|
||||
return input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
|
||||
}
|
||||
|
||||
@ -125,8 +150,10 @@ class SleepDatasetOp : public UnaryDatasetOpKernel {
|
||||
return RestoreInput(ctx, reader, input_impl_);
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<IteratorBase> input_impl_;
|
||||
mutex mu_;
|
||||
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
|
||||
bool cancelled_ GUARDED_BY(mu_) = false;
|
||||
std::function<void()> deregister_fn_;
|
||||
};
|
||||
|
||||
const DatasetBase* const input_;
|
||||
|
||||
@ -78,11 +78,13 @@ class ToTFRecordOp : public AsyncOpKernel {
|
||||
CancellationManager cancellation_manager;
|
||||
params.cancellation_manager = &cancellation_manager;
|
||||
std::function<void()> deregister_fn;
|
||||
OP_REQUIRES_OK_ASYNC(ctx,
|
||||
ConnectCancellationManagers(
|
||||
ctx->cancellation_manager(),
|
||||
params.cancellation_manager, &deregister_fn),
|
||||
done);
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
ctx,
|
||||
RegisterCancellationCallback(
|
||||
ctx->cancellation_manager(),
|
||||
[cm = params.cancellation_manager]() { cm->StartCancel(); },
|
||||
&deregister_fn),
|
||||
done);
|
||||
|
||||
// Update the `done` callback to deregister the cancellation callback.
|
||||
done = std::bind(
|
||||
|
||||
@ -77,9 +77,10 @@ Status IteratorResource::GetNext(OpKernelContext* ctx,
|
||||
params.thread_pool = &unbounded_thread_pool_;
|
||||
params.cancellation_manager = &captured_state->cancellation_manager;
|
||||
std::function<void()> deregister_fn;
|
||||
TF_RETURN_IF_ERROR(ConnectCancellationManagers(ctx->cancellation_manager(),
|
||||
params.cancellation_manager,
|
||||
&deregister_fn));
|
||||
TF_RETURN_IF_ERROR(RegisterCancellationCallback(
|
||||
ctx->cancellation_manager(),
|
||||
[cm = params.cancellation_manager]() { cm->StartCancel(); },
|
||||
&deregister_fn));
|
||||
auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
|
||||
return captured_state->iterator->GetNext(IteratorContext(std::move(params)),
|
||||
out_tensors, end_of_sequence);
|
||||
@ -122,9 +123,10 @@ Status IteratorResource::Restore(OpKernelContext* ctx,
|
||||
params.thread_pool = &unbounded_thread_pool_;
|
||||
params.cancellation_manager = &captured_state->cancellation_manager;
|
||||
std::function<void()> deregister_fn;
|
||||
TF_RETURN_IF_ERROR(ConnectCancellationManagers(ctx->cancellation_manager(),
|
||||
params.cancellation_manager,
|
||||
&deregister_fn));
|
||||
TF_RETURN_IF_ERROR(RegisterCancellationCallback(
|
||||
ctx->cancellation_manager(),
|
||||
[cm = params.cancellation_manager]() { cm->StartCancel(); },
|
||||
&deregister_fn));
|
||||
auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
|
||||
IteratorContext iter_ctx(std::move(params));
|
||||
return captured_state->iterator->Restore(&iter_ctx, reader);
|
||||
@ -154,9 +156,10 @@ Status IteratorResource::SetIteratorFromDataset(OpKernelContext* ctx,
|
||||
params.thread_pool = &unbounded_thread_pool_;
|
||||
params.cancellation_manager = &new_state->cancellation_manager;
|
||||
std::function<void()> deregister_fn;
|
||||
TF_RETURN_IF_ERROR(ConnectCancellationManagers(ctx->cancellation_manager(),
|
||||
params.cancellation_manager,
|
||||
&deregister_fn));
|
||||
TF_RETURN_IF_ERROR(RegisterCancellationCallback(
|
||||
ctx->cancellation_manager(),
|
||||
[cm = params.cancellation_manager]() { cm->StartCancel(); },
|
||||
&deregister_fn));
|
||||
{
|
||||
auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
|
||||
TF_RETURN_IF_ERROR(dataset->MakeIterator(IteratorContext(std::move(params)),
|
||||
@ -456,11 +459,13 @@ class ToSingleElementOp : public AsyncOpKernel {
|
||||
CancellationManager cancellation_manager;
|
||||
params.cancellation_manager = &cancellation_manager;
|
||||
std::function<void()> deregister_fn;
|
||||
OP_REQUIRES_OK_ASYNC(ctx,
|
||||
ConnectCancellationManagers(
|
||||
ctx->cancellation_manager(),
|
||||
params.cancellation_manager, &deregister_fn),
|
||||
done);
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
ctx,
|
||||
RegisterCancellationCallback(
|
||||
ctx->cancellation_manager(),
|
||||
[cm = params.cancellation_manager]() { cm->StartCancel(); },
|
||||
&deregister_fn),
|
||||
done);
|
||||
|
||||
// Update the `done` callback to deregister the cancellation callback.
|
||||
done = std::bind(
|
||||
@ -578,11 +583,13 @@ class ReduceDatasetOp : public AsyncOpKernel {
|
||||
CancellationManager cancellation_manager;
|
||||
params.cancellation_manager = &cancellation_manager;
|
||||
std::function<void()> deregister_fn;
|
||||
OP_REQUIRES_OK_ASYNC(ctx,
|
||||
ConnectCancellationManagers(
|
||||
ctx->cancellation_manager(),
|
||||
params.cancellation_manager, &deregister_fn),
|
||||
done);
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
ctx,
|
||||
RegisterCancellationCallback(
|
||||
ctx->cancellation_manager(),
|
||||
[cm = params.cancellation_manager]() { cm->StartCancel(); },
|
||||
&deregister_fn),
|
||||
done);
|
||||
|
||||
// Update the `done` callback to deregister the cancellation callback.
|
||||
done = std::bind(
|
||||
|
||||
@ -73,6 +73,8 @@ class IteratorResource : public ResourceBase {
|
||||
function_handle_cache(absl::make_unique<FunctionHandleCache>(flr)),
|
||||
iterator(std::move(iterator)) {}
|
||||
|
||||
~State() { cancellation_manager.StartCancel(); }
|
||||
|
||||
std::shared_ptr<FunctionLibraryDefinition> flib_def;
|
||||
FunctionLibraryRuntime* flr = nullptr; // not owned.
|
||||
std::shared_ptr<ProcessFunctionLibraryRuntime> pflr;
|
||||
|
||||
@ -110,9 +110,10 @@ class MultiDeviceIterator : public ResourceBase {
|
||||
params.thread_pool = &unbounded_thread_pool_;
|
||||
params.cancellation_manager = &cancellation_manager_;
|
||||
std::function<void()> deregister_fn;
|
||||
TF_RETURN_IF_ERROR(ConnectCancellationManagers(ctx->cancellation_manager(),
|
||||
params.cancellation_manager,
|
||||
&deregister_fn));
|
||||
TF_RETURN_IF_ERROR(RegisterCancellationCallback(
|
||||
ctx->cancellation_manager(),
|
||||
[cm = params.cancellation_manager]() { cm->StartCancel(); },
|
||||
&deregister_fn));
|
||||
IteratorContext iter_ctx(std::move(params));
|
||||
MultiDeviceIteratorCallback callback_new = std::bind(
|
||||
[](const HostBufferElement& elem, MultiDeviceIteratorCallback callback,
|
||||
@ -563,9 +564,11 @@ class MultiDeviceIteratorInitOp : public OpKernel {
|
||||
params.resource_mgr = resource->resource_mgr();
|
||||
params.cancellation_manager = resource->cancellation_manager();
|
||||
std::function<void()> deregister_fn;
|
||||
OP_REQUIRES_OK(ctx, ConnectCancellationManagers(ctx->cancellation_manager(),
|
||||
params.cancellation_manager,
|
||||
&deregister_fn));
|
||||
OP_REQUIRES_OK(
|
||||
ctx, RegisterCancellationCallback(
|
||||
ctx->cancellation_manager(),
|
||||
[cm = params.cancellation_manager]() { cm->StartCancel(); },
|
||||
&deregister_fn));
|
||||
auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
|
||||
|
||||
IteratorContext iter_ctx(std::move(params));
|
||||
|
||||
@ -121,10 +121,14 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
|
||||
}
|
||||
|
||||
~Iterator() override {
|
||||
mutex_lock l(*mu_);
|
||||
cancellation_manager_.StartCancel();
|
||||
cond_var_->notify_all();
|
||||
deregister_fn_();
|
||||
{
|
||||
mutex_lock l(*mu_);
|
||||
cancelled_ = true;
|
||||
cond_var_->notify_all();
|
||||
}
|
||||
if (deregister_fn_) {
|
||||
deregister_fn_();
|
||||
}
|
||||
}
|
||||
|
||||
string BuildTraceMeName() override {
|
||||
@ -148,9 +152,14 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
|
||||
if (buffer_size_->value == model::kAutotune) {
|
||||
buffer_size_->value = 0;
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
ConnectCancellationManagers(ctx->cancellation_manager(),
|
||||
&cancellation_manager_, &deregister_fn_));
|
||||
TF_RETURN_IF_ERROR(RegisterCancellationCallback(
|
||||
ctx->cancellation_manager(),
|
||||
[this]() {
|
||||
mutex_lock l(*mu_);
|
||||
cancelled_ = true;
|
||||
cond_var_->notify_all();
|
||||
},
|
||||
&deregister_fn_));
|
||||
return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
|
||||
}
|
||||
|
||||
@ -164,8 +173,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
|
||||
// Wait until the next element in the buffer has been
|
||||
// produced, or we are shutting down.
|
||||
if (legacy_autotune_) {
|
||||
while (!cancellation_manager_.IsCancelled() && buffer_.empty() &&
|
||||
!prefetch_thread_finished_ &&
|
||||
while (!cancelled_ && buffer_.empty() && !prefetch_thread_finished_ &&
|
||||
auto_tuner_.buffer_limit() != 0) {
|
||||
auto_tuner_.RecordEmpty();
|
||||
buffer_size_->value = auto_tuner_.buffer_limit();
|
||||
@ -174,15 +182,15 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
|
||||
RecordStart(ctx);
|
||||
}
|
||||
} else {
|
||||
while (!cancellation_manager_.IsCancelled() && buffer_.empty() &&
|
||||
!prefetch_thread_finished_ && buffer_size_->value != 0) {
|
||||
while (!cancelled_ && buffer_.empty() && !prefetch_thread_finished_ &&
|
||||
buffer_size_->value != 0) {
|
||||
RecordStop(ctx);
|
||||
cond_var_->wait(l);
|
||||
RecordStart(ctx);
|
||||
}
|
||||
}
|
||||
|
||||
if (cancellation_manager_.IsCancelled()) {
|
||||
if (cancelled_) {
|
||||
return errors::Cancelled(
|
||||
"PrefetchDatasetOp::Dataset::Iterator::GetNext");
|
||||
}
|
||||
@ -377,14 +385,13 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
|
||||
// 1. Wait for a slot in the buffer.
|
||||
{
|
||||
mutex_lock l(*mu_);
|
||||
while (!cancellation_manager_.IsCancelled() &&
|
||||
buffer_.size() >= buffer_limit()) {
|
||||
while (!cancelled_ && buffer_.size() >= buffer_limit()) {
|
||||
RecordStop(ctx.get());
|
||||
cond_var_->wait(l);
|
||||
RecordStart(ctx.get());
|
||||
}
|
||||
|
||||
if (cancellation_manager_.IsCancelled()) {
|
||||
if (cancelled_) {
|
||||
prefetch_thread_finished_ = true;
|
||||
cond_var_->notify_all();
|
||||
return;
|
||||
@ -486,8 +493,6 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
|
||||
|
||||
// If legacy_autotune_ is false, identifies the maximum size of the buffer.
|
||||
const std::shared_ptr<model::SharedState> buffer_size_;
|
||||
|
||||
CancellationManager cancellation_manager_;
|
||||
std::function<void()> deregister_fn_;
|
||||
};
|
||||
const DatasetBase* const input_;
|
||||
|
||||
@ -19,17 +19,19 @@ from __future__ import print_function
|
||||
|
||||
import time
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.data.experimental.ops import sleep
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.platform import test
|
||||
|
||||
|
||||
@test_util.run_all_in_graph_and_eager_modes
|
||||
class SleepTest(test_base.DatasetTestBase):
|
||||
class SleepTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testSleep(self):
|
||||
self.skipTest("b/123597912")
|
||||
sleep_microseconds = 100
|
||||
@ -44,6 +46,37 @@ class SleepTest(test_base.DatasetTestBase):
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
self.evaluate(next_element())
|
||||
|
||||
@combinations.generate(combinations.combine(tf_api_version=1, mode="graph"))
|
||||
def testSleepCancellation(self):
|
||||
sleep_microseconds = int(1e6) * 1000
|
||||
ds = dataset_ops.Dataset.range(1)
|
||||
ds = ds.apply(sleep.sleep(sleep_microseconds))
|
||||
ds = ds.prefetch(1)
|
||||
get_next = self.getNext(ds, requires_initialization=True)
|
||||
|
||||
with self.cached_session() as sess:
|
||||
thread = self.checkedThread(self.assert_op_cancelled, args=(get_next(),))
|
||||
thread.start()
|
||||
time.sleep(0.2)
|
||||
sess.close()
|
||||
thread.join()
|
||||
|
||||
@combinations.generate(combinations.combine(tf_api_version=1, mode="graph"))
|
||||
def testSleepBackgroundCancellation(self):
|
||||
ds = dataset_ops.Dataset.range(1)
|
||||
|
||||
sleep_microseconds = int(1e6) * 1000
|
||||
ds_sleep = dataset_ops.Dataset.range(1)
|
||||
ds_sleep = ds.apply(sleep.sleep(sleep_microseconds))
|
||||
|
||||
ds = ds.concatenate(ds_sleep)
|
||||
ds = ds.prefetch(1)
|
||||
|
||||
get_next = self.getNext(ds, requires_initialization=True)
|
||||
|
||||
with self.cached_session():
|
||||
self.assertEqual(self.evaluate(get_next()), 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user