[tf.data] This CL makes several cancellation improvements:

- it triggers cancellation upon iterator resource deletion
- it adds cancellation support for the sleep transformation
- it fixes cancellation support for the prefetch transformation

PiperOrigin-RevId: 277825820
Change-Id: I9a81cff872388209bbf9e4a1ae86ccfa14976799
This commit is contained in:
Jiri Simsa 2019-10-31 16:32:14 -07:00 committed by TensorFlower Gardener
parent e8d9db593e
commit 6b66d924e8
13 changed files with 170 additions and 78 deletions

View File

@ -693,7 +693,12 @@ class DatasetBase : public core::RefCounted {
(*iterator)->AddCleanupFunction(
[model, prefix]() { model->RemoveNode(prefix); });
}
return (*iterator)->Initialize(ctx);
Status s = (*iterator)->Initialize(ctx);
if (!s.ok()) {
// Reset the iterator to avoid returning an uninitialized iterator.
iterator->reset();
}
return s;
}
Status MakeIterator(IteratorContext&& ctx, const string& output_prefix,

View File

@ -653,8 +653,9 @@ Status InstantiatedCapturedFunction::Run(IteratorContext* ctx,
CancellationManager cancellation_manager;
f_opts.cancellation_manager = &cancellation_manager;
std::function<void()> deregister_fn;
TF_RETURN_IF_ERROR(ConnectCancellationManagers(
cancellation_manager_, &cancellation_manager, &deregister_fn));
TF_RETURN_IF_ERROR(RegisterCancellationCallback(
cancellation_manager_,
[cm = &cancellation_manager]() { cm->StartCancel(); }, &deregister_fn));
auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
OwnedArgsCallFrame frame(std::move(args), &captured_func_->captured_inputs(),
@ -689,8 +690,9 @@ Status InstantiatedCapturedFunction::RunWithBorrowedArgs(
CancellationManager cancellation_manager;
f_opts.cancellation_manager = &cancellation_manager;
std::function<void()> deregister_fn;
TF_RETURN_IF_ERROR(ConnectCancellationManagers(
cancellation_manager_, &cancellation_manager, &deregister_fn));
TF_RETURN_IF_ERROR(RegisterCancellationCallback(
cancellation_manager_,
[cm = &cancellation_manager]() { cm->StartCancel(); }, &deregister_fn));
auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
BorrowedArgsCallFrame frame(args, &captured_func_->captured_inputs(),
@ -725,8 +727,9 @@ Status InstantiatedCapturedFunction::RunInstantiated(
CancellationManager cancellation_manager;
f_opts.cancellation_manager = &cancellation_manager;
std::function<void()> deregister_fn;
TF_RETURN_IF_ERROR(ConnectCancellationManagers(
cancellation_manager_, &cancellation_manager, &deregister_fn));
TF_RETURN_IF_ERROR(RegisterCancellationCallback(
cancellation_manager_,
[cm = &cancellation_manager]() { cm->StartCancel(); }, &deregister_fn));
auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
BorrowedArgsCallFrame frame(args, &captured_func_->captured_inputs(),
@ -776,8 +779,10 @@ void InstantiatedCapturedFunction::RunAsync(
auto cancellation_manager = absl::make_unique<CancellationManager>();
f_opts.cancellation_manager = cancellation_manager.get();
std::function<void()> deregister_fn;
Status s = ConnectCancellationManagers(
ctx->cancellation_manager(), cancellation_manager.get(), &deregister_fn);
Status s = RegisterCancellationCallback(
cancellation_manager_,
[cm = cancellation_manager.get()]() { cm->StartCancel(); },
&deregister_fn);
if (!s.ok()) {
done(s);
return;

View File

@ -319,18 +319,21 @@ Status AsGraphDef(OpKernelContext* ctx, const DatasetBase* dataset,
return Status::OK();
}
Status ConnectCancellationManagers(CancellationManager* parent,
CancellationManager* child,
std::function<void()>* deregister_fn) {
if (parent) {
CancellationToken token = parent->get_cancellation_token();
if (!parent->RegisterCallback(token, [child]() { child->StartCancel(); })) {
Status RegisterCancellationCallback(CancellationManager* cancellation_manager,
std::function<void()> register_fn,
std::function<void()>* deregister_fn) {
if (cancellation_manager) {
CancellationToken token = cancellation_manager->get_cancellation_token();
if (!cancellation_manager->RegisterCallback(token,
std::move(register_fn))) {
return errors::Cancelled("Operation was cancelled");
}
*deregister_fn = [parent, token]() { parent->DeregisterCallback(token); };
*deregister_fn = [cancellation_manager, token]() {
cancellation_manager->DeregisterCallback(token);
};
} else {
VLOG(1) << "Parent cancellation manager is not set. Cancellation will "
"not be propagated to the child cancellation manager.";
VLOG(1) << "Cancellation manager is not set. Cancellation callback will "
"not be registered.";
*deregister_fn = []() {};
}
return Status::OK();

View File

@ -119,12 +119,11 @@ Status AsGraphDef(OpKernelContext* ctx, const DatasetBase* dataset,
SerializationContext&& serialization_ctx,
GraphDef* graph_def);
// Creates a connection between "child" and "parent" cancellation managers so
// that parent cancellations are propagated to the child, returning a function
// that can be used to remove the connection.
Status ConnectCancellationManagers(CancellationManager* parent,
CancellationManager* child,
std::function<void()>* deregister_fn);
// Registers the given cancellation callback, returning a function that can be
// used to deregister the callback.
//
// `register_fn` is the callback handed to `cancellation_manager` (despite its
// name, it does not perform the registration itself). If
// `cancellation_manager` is null, nothing is registered and `*deregister_fn`
// is set to a no-op. If the callback cannot be registered, a Cancelled error
// is returned and `*deregister_fn` is left unassigned.
Status RegisterCancellationCallback(CancellationManager* cancellation_manager,
std::function<void()> register_fn,
std::function<void()>* deregister_fn);
// Returns Status::OK() if `expected` and `received` types match,
// errors::InvalidArgument otherwise.

View File

@ -386,6 +386,7 @@ tf_kernel_library(
deps = [
"//tensorflow/core:experimental_dataset_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core/kernels/data:dataset_utils",
],
)

View File

@ -333,7 +333,7 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
LOCKS_EXCLUDED(*mu_) {
// Get the next input element.
std::vector<Tensor> input_element;
bool end_of_input;
bool end_of_input = false;
Status status =
input_impl_->GetNext(ctx.get(), &input_element, &end_of_input);
bool return_early;

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
namespace tensorflow {
namespace data {
@ -96,16 +97,40 @@ class SleepDatasetOp : public UnaryDatasetOpKernel {
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params) {}
// Flags this iterator as cancelled and then removes the cancellation callback
// that Initialize() installed.
~Iterator() override {
{
// cancelled_ is guarded by mu_, so it is set inside a scoped lock that is
// released before the deregistration call below.
mutex_lock l(mu_);
cancelled_ = true;
}
// deregister_fn_ is only assigned by the registration in Initialize(); it is
// still empty if Initialize() never ran or the registration failed.
if (deregister_fn_) {
deregister_fn_();
}
}
// Registers a callback with the context's cancellation manager that sets
// cancelled_ (under mu_), keeps the matching deregistration function in
// deregister_fn_, and then creates the input iterator. A registration error is
// returned before the input iterator is created.
Status Initialize(IteratorContext* ctx) override {
TF_RETURN_IF_ERROR(RegisterCancellationCallback(
ctx->cancellation_manager(),
[this]() {
// Runs as the cancellation callback; the flag is guarded by mu_.
mutex_lock l(mu_);
cancelled_ = true;
},
&deregister_fn_));
return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
}
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_);
RecordStop(ctx);
ctx->env()->SleepForMicroseconds(dataset()->sleep_microseconds_);
bool cancelled = mu_.AwaitWithDeadline(
Condition(&cancelled_),
ctx->env()->NowNanos() +
dataset()->sleep_microseconds_ * EnvTime::kMicrosToNanos);
RecordStart(ctx);
if (cancelled) {
return errors::Cancelled("Operation was cancelled");
}
return input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
}
@ -125,8 +150,10 @@ class SleepDatasetOp : public UnaryDatasetOpKernel {
return RestoreInput(ctx, reader, input_impl_);
}
private:
std::unique_ptr<IteratorBase> input_impl_;
mutex mu_;
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
bool cancelled_ GUARDED_BY(mu_) = false;
std::function<void()> deregister_fn_;
};
const DatasetBase* const input_;

View File

@ -78,11 +78,13 @@ class ToTFRecordOp : public AsyncOpKernel {
CancellationManager cancellation_manager;
params.cancellation_manager = &cancellation_manager;
std::function<void()> deregister_fn;
OP_REQUIRES_OK_ASYNC(ctx,
ConnectCancellationManagers(
ctx->cancellation_manager(),
params.cancellation_manager, &deregister_fn),
done);
OP_REQUIRES_OK_ASYNC(
ctx,
RegisterCancellationCallback(
ctx->cancellation_manager(),
[cm = params.cancellation_manager]() { cm->StartCancel(); },
&deregister_fn),
done);
// Update the `done` callback to deregister the cancellation callback.
done = std::bind(

View File

@ -77,9 +77,10 @@ Status IteratorResource::GetNext(OpKernelContext* ctx,
params.thread_pool = &unbounded_thread_pool_;
params.cancellation_manager = &captured_state->cancellation_manager;
std::function<void()> deregister_fn;
TF_RETURN_IF_ERROR(ConnectCancellationManagers(ctx->cancellation_manager(),
params.cancellation_manager,
&deregister_fn));
TF_RETURN_IF_ERROR(RegisterCancellationCallback(
ctx->cancellation_manager(),
[cm = params.cancellation_manager]() { cm->StartCancel(); },
&deregister_fn));
auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
return captured_state->iterator->GetNext(IteratorContext(std::move(params)),
out_tensors, end_of_sequence);
@ -122,9 +123,10 @@ Status IteratorResource::Restore(OpKernelContext* ctx,
params.thread_pool = &unbounded_thread_pool_;
params.cancellation_manager = &captured_state->cancellation_manager;
std::function<void()> deregister_fn;
TF_RETURN_IF_ERROR(ConnectCancellationManagers(ctx->cancellation_manager(),
params.cancellation_manager,
&deregister_fn));
TF_RETURN_IF_ERROR(RegisterCancellationCallback(
ctx->cancellation_manager(),
[cm = params.cancellation_manager]() { cm->StartCancel(); },
&deregister_fn));
auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
IteratorContext iter_ctx(std::move(params));
return captured_state->iterator->Restore(&iter_ctx, reader);
@ -154,9 +156,10 @@ Status IteratorResource::SetIteratorFromDataset(OpKernelContext* ctx,
params.thread_pool = &unbounded_thread_pool_;
params.cancellation_manager = &new_state->cancellation_manager;
std::function<void()> deregister_fn;
TF_RETURN_IF_ERROR(ConnectCancellationManagers(ctx->cancellation_manager(),
params.cancellation_manager,
&deregister_fn));
TF_RETURN_IF_ERROR(RegisterCancellationCallback(
ctx->cancellation_manager(),
[cm = params.cancellation_manager]() { cm->StartCancel(); },
&deregister_fn));
{
auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
TF_RETURN_IF_ERROR(dataset->MakeIterator(IteratorContext(std::move(params)),
@ -456,11 +459,13 @@ class ToSingleElementOp : public AsyncOpKernel {
CancellationManager cancellation_manager;
params.cancellation_manager = &cancellation_manager;
std::function<void()> deregister_fn;
OP_REQUIRES_OK_ASYNC(ctx,
ConnectCancellationManagers(
ctx->cancellation_manager(),
params.cancellation_manager, &deregister_fn),
done);
OP_REQUIRES_OK_ASYNC(
ctx,
RegisterCancellationCallback(
ctx->cancellation_manager(),
[cm = params.cancellation_manager]() { cm->StartCancel(); },
&deregister_fn),
done);
// Update the `done` callback to deregister the cancellation callback.
done = std::bind(
@ -578,11 +583,13 @@ class ReduceDatasetOp : public AsyncOpKernel {
CancellationManager cancellation_manager;
params.cancellation_manager = &cancellation_manager;
std::function<void()> deregister_fn;
OP_REQUIRES_OK_ASYNC(ctx,
ConnectCancellationManagers(
ctx->cancellation_manager(),
params.cancellation_manager, &deregister_fn),
done);
OP_REQUIRES_OK_ASYNC(
ctx,
RegisterCancellationCallback(
ctx->cancellation_manager(),
[cm = params.cancellation_manager]() { cm->StartCancel(); },
&deregister_fn),
done);
// Update the `done` callback to deregister the cancellation callback.
done = std::bind(

View File

@ -73,6 +73,8 @@ class IteratorResource : public ResourceBase {
function_handle_cache(absl::make_unique<FunctionHandleCache>(flr)),
iterator(std::move(iterator)) {}
// Starts cancellation on this state's cancellation manager when the state is
// destroyed.
~State() { cancellation_manager.StartCancel(); }
std::shared_ptr<FunctionLibraryDefinition> flib_def;
FunctionLibraryRuntime* flr = nullptr; // not owned.
std::shared_ptr<ProcessFunctionLibraryRuntime> pflr;

View File

@ -110,9 +110,10 @@ class MultiDeviceIterator : public ResourceBase {
params.thread_pool = &unbounded_thread_pool_;
params.cancellation_manager = &cancellation_manager_;
std::function<void()> deregister_fn;
TF_RETURN_IF_ERROR(ConnectCancellationManagers(ctx->cancellation_manager(),
params.cancellation_manager,
&deregister_fn));
TF_RETURN_IF_ERROR(RegisterCancellationCallback(
ctx->cancellation_manager(),
[cm = params.cancellation_manager]() { cm->StartCancel(); },
&deregister_fn));
IteratorContext iter_ctx(std::move(params));
MultiDeviceIteratorCallback callback_new = std::bind(
[](const HostBufferElement& elem, MultiDeviceIteratorCallback callback,
@ -563,9 +564,11 @@ class MultiDeviceIteratorInitOp : public OpKernel {
params.resource_mgr = resource->resource_mgr();
params.cancellation_manager = resource->cancellation_manager();
std::function<void()> deregister_fn;
OP_REQUIRES_OK(ctx, ConnectCancellationManagers(ctx->cancellation_manager(),
params.cancellation_manager,
&deregister_fn));
OP_REQUIRES_OK(
ctx, RegisterCancellationCallback(
ctx->cancellation_manager(),
[cm = params.cancellation_manager]() { cm->StartCancel(); },
&deregister_fn));
auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
IteratorContext iter_ctx(std::move(params));

View File

@ -121,10 +121,14 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
}
~Iterator() override {
mutex_lock l(*mu_);
cancellation_manager_.StartCancel();
cond_var_->notify_all();
deregister_fn_();
{
mutex_lock l(*mu_);
cancelled_ = true;
cond_var_->notify_all();
}
if (deregister_fn_) {
deregister_fn_();
}
}
string BuildTraceMeName() override {
@ -148,9 +152,14 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
if (buffer_size_->value == model::kAutotune) {
buffer_size_->value = 0;
}
TF_RETURN_IF_ERROR(
ConnectCancellationManagers(ctx->cancellation_manager(),
&cancellation_manager_, &deregister_fn_));
TF_RETURN_IF_ERROR(RegisterCancellationCallback(
ctx->cancellation_manager(),
[this]() {
mutex_lock l(*mu_);
cancelled_ = true;
cond_var_->notify_all();
},
&deregister_fn_));
return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
}
@ -164,8 +173,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
// Wait until the next element in the buffer has been
// produced, or we are shutting down.
if (legacy_autotune_) {
while (!cancellation_manager_.IsCancelled() && buffer_.empty() &&
!prefetch_thread_finished_ &&
while (!cancelled_ && buffer_.empty() && !prefetch_thread_finished_ &&
auto_tuner_.buffer_limit() != 0) {
auto_tuner_.RecordEmpty();
buffer_size_->value = auto_tuner_.buffer_limit();
@ -174,15 +182,15 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
RecordStart(ctx);
}
} else {
while (!cancellation_manager_.IsCancelled() && buffer_.empty() &&
!prefetch_thread_finished_ && buffer_size_->value != 0) {
while (!cancelled_ && buffer_.empty() && !prefetch_thread_finished_ &&
buffer_size_->value != 0) {
RecordStop(ctx);
cond_var_->wait(l);
RecordStart(ctx);
}
}
if (cancellation_manager_.IsCancelled()) {
if (cancelled_) {
return errors::Cancelled(
"PrefetchDatasetOp::Dataset::Iterator::GetNext");
}
@ -377,14 +385,13 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
// 1. Wait for a slot in the buffer.
{
mutex_lock l(*mu_);
while (!cancellation_manager_.IsCancelled() &&
buffer_.size() >= buffer_limit()) {
while (!cancelled_ && buffer_.size() >= buffer_limit()) {
RecordStop(ctx.get());
cond_var_->wait(l);
RecordStart(ctx.get());
}
if (cancellation_manager_.IsCancelled()) {
if (cancelled_) {
prefetch_thread_finished_ = true;
cond_var_->notify_all();
return;
@ -486,8 +493,6 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
// If legacy_autotune_ is false, identifies the maximum size of the buffer.
const std::shared_ptr<model::SharedState> buffer_size_;
CancellationManager cancellation_manager_;
std::function<void()> deregister_fn_;
};
const DatasetBase* const input_;

View File

@ -19,17 +19,19 @@ from __future__ import print_function
import time
from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import sleep
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class SleepTest(test_base.DatasetTestBase):
class SleepTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.default_test_combinations())
def testSleep(self):
self.skipTest("b/123597912")
sleep_microseconds = 100
@ -44,6 +46,37 @@ class SleepTest(test_base.DatasetTestBase):
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element())
@combinations.generate(combinations.combine(tf_api_version=1, mode="graph"))
def testSleepCancellation(self):
# Exercises closing the session while a get-next op is pending on a dataset
# that sleeps for a very long time.
# int(1e6) * 1000 microseconds is 1000 seconds, far longer than the test.
sleep_microseconds = int(1e6) * 1000
ds = dataset_ops.Dataset.range(1)
ds = ds.apply(sleep.sleep(sleep_microseconds))
ds = ds.prefetch(1)
get_next = self.getNext(ds, requires_initialization=True)
with self.cached_session() as sess:
# Evaluate get_next() on a separate checked thread; assert_op_cancelled
# presumably expects the op to fail with a cancellation error -- confirm
# against the test base class.
thread = self.checkedThread(self.assert_op_cancelled, args=(get_next(),))
thread.start()
# Short pause before closing the session, presumably so the op is already
# running when the close happens.
time.sleep(0.2)
sess.close()
thread.join()
@combinations.generate(combinations.combine(tf_api_version=1, mode="graph"))
def testSleepBackgroundCancellation(self):
# Builds range(1) concatenated with a range(1) that sleeps for 1000 seconds,
# puts a prefetch on top, and reads only the first element. The test body
# asserts nothing beyond that first value.
ds = dataset_ops.Dataset.range(1)
sleep_microseconds = int(1e6) * 1000
# NOTE(review): the next assignment is immediately overwritten, and the sleep
# is applied to `ds` rather than `ds_sleep`. Both are range(1), so the
# resulting pipeline is the same, but `ds_sleep.apply(...)` looks like the
# intent -- confirm.
ds_sleep = dataset_ops.Dataset.range(1)
ds_sleep = ds.apply(sleep.sleep(sleep_microseconds))
ds = ds.concatenate(ds_sleep)
ds = ds.prefetch(1)
get_next = self.getNext(ds, requires_initialization=True)
with self.cached_session():
self.assertEqual(self.evaluate(get_next()), 0)
if __name__ == "__main__":
test.main()