[tf.data] This CL makes several cancellation improvements:
- it triggers cancellation upon iterator resource deletion - it adds cancellation support for the sleep transformation - it fixes cancellation support for the prefetch transformation PiperOrigin-RevId: 277825820 Change-Id: I9a81cff872388209bbf9e4a1ae86ccfa14976799
This commit is contained in:
		
							parent
							
								
									e8d9db593e
								
							
						
					
					
						commit
						6b66d924e8
					
				@ -693,7 +693,12 @@ class DatasetBase : public core::RefCounted {
 | 
			
		||||
      (*iterator)->AddCleanupFunction(
 | 
			
		||||
          [model, prefix]() { model->RemoveNode(prefix); });
 | 
			
		||||
    }
 | 
			
		||||
    return (*iterator)->Initialize(ctx);
 | 
			
		||||
    Status s = (*iterator)->Initialize(ctx);
 | 
			
		||||
    if (!s.ok()) {
 | 
			
		||||
      // Reset the iterator to avoid returning an uninitialized iterator.
 | 
			
		||||
      iterator->reset();
 | 
			
		||||
    }
 | 
			
		||||
    return s;
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  Status MakeIterator(IteratorContext&& ctx, const string& output_prefix,
 | 
			
		||||
 | 
			
		||||
@ -653,8 +653,9 @@ Status InstantiatedCapturedFunction::Run(IteratorContext* ctx,
 | 
			
		||||
  CancellationManager cancellation_manager;
 | 
			
		||||
  f_opts.cancellation_manager = &cancellation_manager;
 | 
			
		||||
  std::function<void()> deregister_fn;
 | 
			
		||||
  TF_RETURN_IF_ERROR(ConnectCancellationManagers(
 | 
			
		||||
      cancellation_manager_, &cancellation_manager, &deregister_fn));
 | 
			
		||||
  TF_RETURN_IF_ERROR(RegisterCancellationCallback(
 | 
			
		||||
      cancellation_manager_,
 | 
			
		||||
      [cm = &cancellation_manager]() { cm->StartCancel(); }, &deregister_fn));
 | 
			
		||||
  auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
 | 
			
		||||
 | 
			
		||||
  OwnedArgsCallFrame frame(std::move(args), &captured_func_->captured_inputs(),
 | 
			
		||||
@ -689,8 +690,9 @@ Status InstantiatedCapturedFunction::RunWithBorrowedArgs(
 | 
			
		||||
  CancellationManager cancellation_manager;
 | 
			
		||||
  f_opts.cancellation_manager = &cancellation_manager;
 | 
			
		||||
  std::function<void()> deregister_fn;
 | 
			
		||||
  TF_RETURN_IF_ERROR(ConnectCancellationManagers(
 | 
			
		||||
      cancellation_manager_, &cancellation_manager, &deregister_fn));
 | 
			
		||||
  TF_RETURN_IF_ERROR(RegisterCancellationCallback(
 | 
			
		||||
      cancellation_manager_,
 | 
			
		||||
      [cm = &cancellation_manager]() { cm->StartCancel(); }, &deregister_fn));
 | 
			
		||||
  auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
 | 
			
		||||
 | 
			
		||||
  BorrowedArgsCallFrame frame(args, &captured_func_->captured_inputs(),
 | 
			
		||||
@ -725,8 +727,9 @@ Status InstantiatedCapturedFunction::RunInstantiated(
 | 
			
		||||
  CancellationManager cancellation_manager;
 | 
			
		||||
  f_opts.cancellation_manager = &cancellation_manager;
 | 
			
		||||
  std::function<void()> deregister_fn;
 | 
			
		||||
  TF_RETURN_IF_ERROR(ConnectCancellationManagers(
 | 
			
		||||
      cancellation_manager_, &cancellation_manager, &deregister_fn));
 | 
			
		||||
  TF_RETURN_IF_ERROR(RegisterCancellationCallback(
 | 
			
		||||
      cancellation_manager_,
 | 
			
		||||
      [cm = &cancellation_manager]() { cm->StartCancel(); }, &deregister_fn));
 | 
			
		||||
  auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
 | 
			
		||||
 | 
			
		||||
  BorrowedArgsCallFrame frame(args, &captured_func_->captured_inputs(),
 | 
			
		||||
@ -776,8 +779,10 @@ void InstantiatedCapturedFunction::RunAsync(
 | 
			
		||||
  auto cancellation_manager = absl::make_unique<CancellationManager>();
 | 
			
		||||
  f_opts.cancellation_manager = cancellation_manager.get();
 | 
			
		||||
  std::function<void()> deregister_fn;
 | 
			
		||||
  Status s = ConnectCancellationManagers(
 | 
			
		||||
      ctx->cancellation_manager(), cancellation_manager.get(), &deregister_fn);
 | 
			
		||||
  Status s = RegisterCancellationCallback(
 | 
			
		||||
      cancellation_manager_,
 | 
			
		||||
      [cm = cancellation_manager.get()]() { cm->StartCancel(); },
 | 
			
		||||
      &deregister_fn);
 | 
			
		||||
  if (!s.ok()) {
 | 
			
		||||
    done(s);
 | 
			
		||||
    return;
 | 
			
		||||
 | 
			
		||||
@ -319,18 +319,21 @@ Status AsGraphDef(OpKernelContext* ctx, const DatasetBase* dataset,
 | 
			
		||||
  return Status::OK();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Status ConnectCancellationManagers(CancellationManager* parent,
 | 
			
		||||
                                   CancellationManager* child,
 | 
			
		||||
                                   std::function<void()>* deregister_fn) {
 | 
			
		||||
  if (parent) {
 | 
			
		||||
    CancellationToken token = parent->get_cancellation_token();
 | 
			
		||||
    if (!parent->RegisterCallback(token, [child]() { child->StartCancel(); })) {
 | 
			
		||||
Status RegisterCancellationCallback(CancellationManager* cancellation_manager,
 | 
			
		||||
                                    std::function<void()> register_fn,
 | 
			
		||||
                                    std::function<void()>* deregister_fn) {
 | 
			
		||||
  if (cancellation_manager) {
 | 
			
		||||
    CancellationToken token = cancellation_manager->get_cancellation_token();
 | 
			
		||||
    if (!cancellation_manager->RegisterCallback(token,
 | 
			
		||||
                                                std::move(register_fn))) {
 | 
			
		||||
      return errors::Cancelled("Operation was cancelled");
 | 
			
		||||
    }
 | 
			
		||||
    *deregister_fn = [parent, token]() { parent->DeregisterCallback(token); };
 | 
			
		||||
    *deregister_fn = [cancellation_manager, token]() {
 | 
			
		||||
      cancellation_manager->DeregisterCallback(token);
 | 
			
		||||
    };
 | 
			
		||||
  } else {
 | 
			
		||||
    VLOG(1) << "Parent cancellation manager is not set. Cancellation will "
 | 
			
		||||
               "not be propagated to the child cancellation manager.";
 | 
			
		||||
    VLOG(1) << "Cancellation manager is not set. Cancellation callback will "
 | 
			
		||||
               "not be registered.";
 | 
			
		||||
    *deregister_fn = []() {};
 | 
			
		||||
  }
 | 
			
		||||
  return Status::OK();
 | 
			
		||||
 | 
			
		||||
@ -119,12 +119,11 @@ Status AsGraphDef(OpKernelContext* ctx, const DatasetBase* dataset,
 | 
			
		||||
                  SerializationContext&& serialization_ctx,
 | 
			
		||||
                  GraphDef* graph_def);
 | 
			
		||||
 | 
			
		||||
// Creates a connection between "child" and "parent" cancellation managers so
 | 
			
		||||
// that parent cancellations are propagated to the child, returning a function
 | 
			
		||||
// that can be used to remove the connection.
 | 
			
		||||
Status ConnectCancellationManagers(CancellationManager* parent,
 | 
			
		||||
                                   CancellationManager* child,
 | 
			
		||||
                                   std::function<void()>* deregister_fn);
 | 
			
		||||
// Registers the given cancellation callback, returning a function that can be
 | 
			
		||||
// used to deregister the callback.
 | 
			
		||||
Status RegisterCancellationCallback(CancellationManager* cancellation_manager,
 | 
			
		||||
                                    std::function<void()> register_fn,
 | 
			
		||||
                                    std::function<void()>* deregister_fn);
 | 
			
		||||
 | 
			
		||||
// Returns Status::OK() if `expected` and `received` types match,
 | 
			
		||||
// errors::InvalidArgument otherwise.
 | 
			
		||||
 | 
			
		||||
@ -386,6 +386,7 @@ tf_kernel_library(
 | 
			
		||||
    deps = [
 | 
			
		||||
        "//tensorflow/core:experimental_dataset_ops_op_lib",
 | 
			
		||||
        "//tensorflow/core:framework",
 | 
			
		||||
        "//tensorflow/core/kernels/data:dataset_utils",
 | 
			
		||||
    ],
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -333,7 +333,7 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
 | 
			
		||||
        LOCKS_EXCLUDED(*mu_) {
 | 
			
		||||
      // Get the next input element.
 | 
			
		||||
      std::vector<Tensor> input_element;
 | 
			
		||||
      bool end_of_input;
 | 
			
		||||
      bool end_of_input = false;
 | 
			
		||||
      Status status =
 | 
			
		||||
          input_impl_->GetNext(ctx.get(), &input_element, &end_of_input);
 | 
			
		||||
      bool return_early;
 | 
			
		||||
 | 
			
		||||
@ -15,6 +15,7 @@ limitations under the License.
 | 
			
		||||
#include "tensorflow/core/framework/dataset.h"
 | 
			
		||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
 | 
			
		||||
#include "tensorflow/core/framework/tensor.h"
 | 
			
		||||
#include "tensorflow/core/kernels/data/dataset_utils.h"
 | 
			
		||||
 | 
			
		||||
namespace tensorflow {
 | 
			
		||||
namespace data {
 | 
			
		||||
@ -96,16 +97,40 @@ class SleepDatasetOp : public UnaryDatasetOpKernel {
 | 
			
		||||
      explicit Iterator(const Params& params)
 | 
			
		||||
          : DatasetIterator<Dataset>(params) {}
 | 
			
		||||
 | 
			
		||||
      ~Iterator() override {
 | 
			
		||||
        {
 | 
			
		||||
          mutex_lock l(mu_);
 | 
			
		||||
          cancelled_ = true;
 | 
			
		||||
        }
 | 
			
		||||
        if (deregister_fn_) {
 | 
			
		||||
          deregister_fn_();
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      Status Initialize(IteratorContext* ctx) override {
 | 
			
		||||
        TF_RETURN_IF_ERROR(RegisterCancellationCallback(
 | 
			
		||||
            ctx->cancellation_manager(),
 | 
			
		||||
            [this]() {
 | 
			
		||||
              mutex_lock l(mu_);
 | 
			
		||||
              cancelled_ = true;
 | 
			
		||||
            },
 | 
			
		||||
            &deregister_fn_));
 | 
			
		||||
        return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      Status GetNextInternal(IteratorContext* ctx,
 | 
			
		||||
                             std::vector<Tensor>* out_tensors,
 | 
			
		||||
                             bool* end_of_sequence) override {
 | 
			
		||||
        mutex_lock l(mu_);
 | 
			
		||||
        RecordStop(ctx);
 | 
			
		||||
        ctx->env()->SleepForMicroseconds(dataset()->sleep_microseconds_);
 | 
			
		||||
        bool cancelled = mu_.AwaitWithDeadline(
 | 
			
		||||
            Condition(&cancelled_),
 | 
			
		||||
            ctx->env()->NowNanos() +
 | 
			
		||||
                dataset()->sleep_microseconds_ * EnvTime::kMicrosToNanos);
 | 
			
		||||
        RecordStart(ctx);
 | 
			
		||||
        if (cancelled) {
 | 
			
		||||
          return errors::Cancelled("Operation was cancelled");
 | 
			
		||||
        }
 | 
			
		||||
        return input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
@ -125,8 +150,10 @@ class SleepDatasetOp : public UnaryDatasetOpKernel {
 | 
			
		||||
        return RestoreInput(ctx, reader, input_impl_);
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
     private:
 | 
			
		||||
      std::unique_ptr<IteratorBase> input_impl_;
 | 
			
		||||
      mutex mu_;
 | 
			
		||||
      std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
 | 
			
		||||
      bool cancelled_ GUARDED_BY(mu_) = false;
 | 
			
		||||
      std::function<void()> deregister_fn_;
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    const DatasetBase* const input_;
 | 
			
		||||
 | 
			
		||||
@ -78,11 +78,13 @@ class ToTFRecordOp : public AsyncOpKernel {
 | 
			
		||||
          CancellationManager cancellation_manager;
 | 
			
		||||
          params.cancellation_manager = &cancellation_manager;
 | 
			
		||||
          std::function<void()> deregister_fn;
 | 
			
		||||
          OP_REQUIRES_OK_ASYNC(ctx,
 | 
			
		||||
                               ConnectCancellationManagers(
 | 
			
		||||
                                   ctx->cancellation_manager(),
 | 
			
		||||
                                   params.cancellation_manager, &deregister_fn),
 | 
			
		||||
                               done);
 | 
			
		||||
          OP_REQUIRES_OK_ASYNC(
 | 
			
		||||
              ctx,
 | 
			
		||||
              RegisterCancellationCallback(
 | 
			
		||||
                  ctx->cancellation_manager(),
 | 
			
		||||
                  [cm = params.cancellation_manager]() { cm->StartCancel(); },
 | 
			
		||||
                  &deregister_fn),
 | 
			
		||||
              done);
 | 
			
		||||
 | 
			
		||||
          // Update the `done` callback to deregister the cancellation callback.
 | 
			
		||||
          done = std::bind(
 | 
			
		||||
 | 
			
		||||
@ -77,9 +77,10 @@ Status IteratorResource::GetNext(OpKernelContext* ctx,
 | 
			
		||||
    params.thread_pool = &unbounded_thread_pool_;
 | 
			
		||||
    params.cancellation_manager = &captured_state->cancellation_manager;
 | 
			
		||||
    std::function<void()> deregister_fn;
 | 
			
		||||
    TF_RETURN_IF_ERROR(ConnectCancellationManagers(ctx->cancellation_manager(),
 | 
			
		||||
                                                   params.cancellation_manager,
 | 
			
		||||
                                                   &deregister_fn));
 | 
			
		||||
    TF_RETURN_IF_ERROR(RegisterCancellationCallback(
 | 
			
		||||
        ctx->cancellation_manager(),
 | 
			
		||||
        [cm = params.cancellation_manager]() { cm->StartCancel(); },
 | 
			
		||||
        &deregister_fn));
 | 
			
		||||
    auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
 | 
			
		||||
    return captured_state->iterator->GetNext(IteratorContext(std::move(params)),
 | 
			
		||||
                                             out_tensors, end_of_sequence);
 | 
			
		||||
@ -122,9 +123,10 @@ Status IteratorResource::Restore(OpKernelContext* ctx,
 | 
			
		||||
    params.thread_pool = &unbounded_thread_pool_;
 | 
			
		||||
    params.cancellation_manager = &captured_state->cancellation_manager;
 | 
			
		||||
    std::function<void()> deregister_fn;
 | 
			
		||||
    TF_RETURN_IF_ERROR(ConnectCancellationManagers(ctx->cancellation_manager(),
 | 
			
		||||
                                                   params.cancellation_manager,
 | 
			
		||||
                                                   &deregister_fn));
 | 
			
		||||
    TF_RETURN_IF_ERROR(RegisterCancellationCallback(
 | 
			
		||||
        ctx->cancellation_manager(),
 | 
			
		||||
        [cm = params.cancellation_manager]() { cm->StartCancel(); },
 | 
			
		||||
        &deregister_fn));
 | 
			
		||||
    auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
 | 
			
		||||
    IteratorContext iter_ctx(std::move(params));
 | 
			
		||||
    return captured_state->iterator->Restore(&iter_ctx, reader);
 | 
			
		||||
@ -154,9 +156,10 @@ Status IteratorResource::SetIteratorFromDataset(OpKernelContext* ctx,
 | 
			
		||||
  params.thread_pool = &unbounded_thread_pool_;
 | 
			
		||||
  params.cancellation_manager = &new_state->cancellation_manager;
 | 
			
		||||
  std::function<void()> deregister_fn;
 | 
			
		||||
  TF_RETURN_IF_ERROR(ConnectCancellationManagers(ctx->cancellation_manager(),
 | 
			
		||||
                                                 params.cancellation_manager,
 | 
			
		||||
                                                 &deregister_fn));
 | 
			
		||||
  TF_RETURN_IF_ERROR(RegisterCancellationCallback(
 | 
			
		||||
      ctx->cancellation_manager(),
 | 
			
		||||
      [cm = params.cancellation_manager]() { cm->StartCancel(); },
 | 
			
		||||
      &deregister_fn));
 | 
			
		||||
  {
 | 
			
		||||
    auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
 | 
			
		||||
    TF_RETURN_IF_ERROR(dataset->MakeIterator(IteratorContext(std::move(params)),
 | 
			
		||||
@ -456,11 +459,13 @@ class ToSingleElementOp : public AsyncOpKernel {
 | 
			
		||||
          CancellationManager cancellation_manager;
 | 
			
		||||
          params.cancellation_manager = &cancellation_manager;
 | 
			
		||||
          std::function<void()> deregister_fn;
 | 
			
		||||
          OP_REQUIRES_OK_ASYNC(ctx,
 | 
			
		||||
                               ConnectCancellationManagers(
 | 
			
		||||
                                   ctx->cancellation_manager(),
 | 
			
		||||
                                   params.cancellation_manager, &deregister_fn),
 | 
			
		||||
                               done);
 | 
			
		||||
          OP_REQUIRES_OK_ASYNC(
 | 
			
		||||
              ctx,
 | 
			
		||||
              RegisterCancellationCallback(
 | 
			
		||||
                  ctx->cancellation_manager(),
 | 
			
		||||
                  [cm = params.cancellation_manager]() { cm->StartCancel(); },
 | 
			
		||||
                  &deregister_fn),
 | 
			
		||||
              done);
 | 
			
		||||
 | 
			
		||||
          // Update the `done` callback to deregister the cancellation callback.
 | 
			
		||||
          done = std::bind(
 | 
			
		||||
@ -578,11 +583,13 @@ class ReduceDatasetOp : public AsyncOpKernel {
 | 
			
		||||
          CancellationManager cancellation_manager;
 | 
			
		||||
          params.cancellation_manager = &cancellation_manager;
 | 
			
		||||
          std::function<void()> deregister_fn;
 | 
			
		||||
          OP_REQUIRES_OK_ASYNC(ctx,
 | 
			
		||||
                               ConnectCancellationManagers(
 | 
			
		||||
                                   ctx->cancellation_manager(),
 | 
			
		||||
                                   params.cancellation_manager, &deregister_fn),
 | 
			
		||||
                               done);
 | 
			
		||||
          OP_REQUIRES_OK_ASYNC(
 | 
			
		||||
              ctx,
 | 
			
		||||
              RegisterCancellationCallback(
 | 
			
		||||
                  ctx->cancellation_manager(),
 | 
			
		||||
                  [cm = params.cancellation_manager]() { cm->StartCancel(); },
 | 
			
		||||
                  &deregister_fn),
 | 
			
		||||
              done);
 | 
			
		||||
 | 
			
		||||
          // Update the `done` callback to deregister the cancellation callback.
 | 
			
		||||
          done = std::bind(
 | 
			
		||||
 | 
			
		||||
@ -73,6 +73,8 @@ class IteratorResource : public ResourceBase {
 | 
			
		||||
          function_handle_cache(absl::make_unique<FunctionHandleCache>(flr)),
 | 
			
		||||
          iterator(std::move(iterator)) {}
 | 
			
		||||
 | 
			
		||||
    ~State() { cancellation_manager.StartCancel(); }
 | 
			
		||||
 | 
			
		||||
    std::shared_ptr<FunctionLibraryDefinition> flib_def;
 | 
			
		||||
    FunctionLibraryRuntime* flr = nullptr;  // not owned.
 | 
			
		||||
    std::shared_ptr<ProcessFunctionLibraryRuntime> pflr;
 | 
			
		||||
 | 
			
		||||
@ -110,9 +110,10 @@ class MultiDeviceIterator : public ResourceBase {
 | 
			
		||||
    params.thread_pool = &unbounded_thread_pool_;
 | 
			
		||||
    params.cancellation_manager = &cancellation_manager_;
 | 
			
		||||
    std::function<void()> deregister_fn;
 | 
			
		||||
    TF_RETURN_IF_ERROR(ConnectCancellationManagers(ctx->cancellation_manager(),
 | 
			
		||||
                                                   params.cancellation_manager,
 | 
			
		||||
                                                   &deregister_fn));
 | 
			
		||||
    TF_RETURN_IF_ERROR(RegisterCancellationCallback(
 | 
			
		||||
        ctx->cancellation_manager(),
 | 
			
		||||
        [cm = params.cancellation_manager]() { cm->StartCancel(); },
 | 
			
		||||
        &deregister_fn));
 | 
			
		||||
    IteratorContext iter_ctx(std::move(params));
 | 
			
		||||
    MultiDeviceIteratorCallback callback_new = std::bind(
 | 
			
		||||
        [](const HostBufferElement& elem, MultiDeviceIteratorCallback callback,
 | 
			
		||||
@ -563,9 +564,11 @@ class MultiDeviceIteratorInitOp : public OpKernel {
 | 
			
		||||
    params.resource_mgr = resource->resource_mgr();
 | 
			
		||||
    params.cancellation_manager = resource->cancellation_manager();
 | 
			
		||||
    std::function<void()> deregister_fn;
 | 
			
		||||
    OP_REQUIRES_OK(ctx, ConnectCancellationManagers(ctx->cancellation_manager(),
 | 
			
		||||
                                                    params.cancellation_manager,
 | 
			
		||||
                                                    &deregister_fn));
 | 
			
		||||
    OP_REQUIRES_OK(
 | 
			
		||||
        ctx, RegisterCancellationCallback(
 | 
			
		||||
                 ctx->cancellation_manager(),
 | 
			
		||||
                 [cm = params.cancellation_manager]() { cm->StartCancel(); },
 | 
			
		||||
                 &deregister_fn));
 | 
			
		||||
    auto cleanup = gtl::MakeCleanup(std::move(deregister_fn));
 | 
			
		||||
 | 
			
		||||
    IteratorContext iter_ctx(std::move(params));
 | 
			
		||||
 | 
			
		||||
@ -121,10 +121,14 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    ~Iterator() override {
 | 
			
		||||
      mutex_lock l(*mu_);
 | 
			
		||||
      cancellation_manager_.StartCancel();
 | 
			
		||||
      cond_var_->notify_all();
 | 
			
		||||
      deregister_fn_();
 | 
			
		||||
      {
 | 
			
		||||
        mutex_lock l(*mu_);
 | 
			
		||||
        cancelled_ = true;
 | 
			
		||||
        cond_var_->notify_all();
 | 
			
		||||
      }
 | 
			
		||||
      if (deregister_fn_) {
 | 
			
		||||
        deregister_fn_();
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    string BuildTraceMeName() override {
 | 
			
		||||
@ -148,9 +152,14 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
 | 
			
		||||
      if (buffer_size_->value == model::kAutotune) {
 | 
			
		||||
        buffer_size_->value = 0;
 | 
			
		||||
      }
 | 
			
		||||
      TF_RETURN_IF_ERROR(
 | 
			
		||||
          ConnectCancellationManagers(ctx->cancellation_manager(),
 | 
			
		||||
                                      &cancellation_manager_, &deregister_fn_));
 | 
			
		||||
      TF_RETURN_IF_ERROR(RegisterCancellationCallback(
 | 
			
		||||
          ctx->cancellation_manager(),
 | 
			
		||||
          [this]() {
 | 
			
		||||
            mutex_lock l(*mu_);
 | 
			
		||||
            cancelled_ = true;
 | 
			
		||||
            cond_var_->notify_all();
 | 
			
		||||
          },
 | 
			
		||||
          &deregister_fn_));
 | 
			
		||||
      return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@ -164,8 +173,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
 | 
			
		||||
        // Wait until the next element in the buffer has been
 | 
			
		||||
        // produced, or we are shutting down.
 | 
			
		||||
        if (legacy_autotune_) {
 | 
			
		||||
          while (!cancellation_manager_.IsCancelled() && buffer_.empty() &&
 | 
			
		||||
                 !prefetch_thread_finished_ &&
 | 
			
		||||
          while (!cancelled_ && buffer_.empty() && !prefetch_thread_finished_ &&
 | 
			
		||||
                 auto_tuner_.buffer_limit() != 0) {
 | 
			
		||||
            auto_tuner_.RecordEmpty();
 | 
			
		||||
            buffer_size_->value = auto_tuner_.buffer_limit();
 | 
			
		||||
@ -174,15 +182,15 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
 | 
			
		||||
            RecordStart(ctx);
 | 
			
		||||
          }
 | 
			
		||||
        } else {
 | 
			
		||||
          while (!cancellation_manager_.IsCancelled() && buffer_.empty() &&
 | 
			
		||||
                 !prefetch_thread_finished_ && buffer_size_->value != 0) {
 | 
			
		||||
          while (!cancelled_ && buffer_.empty() && !prefetch_thread_finished_ &&
 | 
			
		||||
                 buffer_size_->value != 0) {
 | 
			
		||||
            RecordStop(ctx);
 | 
			
		||||
            cond_var_->wait(l);
 | 
			
		||||
            RecordStart(ctx);
 | 
			
		||||
          }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if (cancellation_manager_.IsCancelled()) {
 | 
			
		||||
        if (cancelled_) {
 | 
			
		||||
          return errors::Cancelled(
 | 
			
		||||
              "PrefetchDatasetOp::Dataset::Iterator::GetNext");
 | 
			
		||||
        }
 | 
			
		||||
@ -377,14 +385,13 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
 | 
			
		||||
        // 1. Wait for a slot in the buffer.
 | 
			
		||||
        {
 | 
			
		||||
          mutex_lock l(*mu_);
 | 
			
		||||
          while (!cancellation_manager_.IsCancelled() &&
 | 
			
		||||
                 buffer_.size() >= buffer_limit()) {
 | 
			
		||||
          while (!cancelled_ && buffer_.size() >= buffer_limit()) {
 | 
			
		||||
            RecordStop(ctx.get());
 | 
			
		||||
            cond_var_->wait(l);
 | 
			
		||||
            RecordStart(ctx.get());
 | 
			
		||||
          }
 | 
			
		||||
 | 
			
		||||
          if (cancellation_manager_.IsCancelled()) {
 | 
			
		||||
          if (cancelled_) {
 | 
			
		||||
            prefetch_thread_finished_ = true;
 | 
			
		||||
            cond_var_->notify_all();
 | 
			
		||||
            return;
 | 
			
		||||
@ -486,8 +493,6 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
 | 
			
		||||
 | 
			
		||||
    // If legacy_autotune_ is false, identifies the maximum size of the buffer.
 | 
			
		||||
    const std::shared_ptr<model::SharedState> buffer_size_;
 | 
			
		||||
 | 
			
		||||
    CancellationManager cancellation_manager_;
 | 
			
		||||
    std::function<void()> deregister_fn_;
 | 
			
		||||
  };
 | 
			
		||||
  const DatasetBase* const input_;
 | 
			
		||||
 | 
			
		||||
@ -19,17 +19,19 @@ from __future__ import print_function
 | 
			
		||||
 | 
			
		||||
import time
 | 
			
		||||
 | 
			
		||||
from absl.testing import parameterized
 | 
			
		||||
 | 
			
		||||
from tensorflow.python.data.experimental.ops import sleep
 | 
			
		||||
from tensorflow.python.data.kernel_tests import test_base
 | 
			
		||||
from tensorflow.python.data.ops import dataset_ops
 | 
			
		||||
from tensorflow.python.framework import combinations
 | 
			
		||||
from tensorflow.python.framework import errors
 | 
			
		||||
from tensorflow.python.framework import test_util
 | 
			
		||||
from tensorflow.python.platform import test
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@test_util.run_all_in_graph_and_eager_modes
 | 
			
		||||
class SleepTest(test_base.DatasetTestBase):
 | 
			
		||||
class SleepTest(test_base.DatasetTestBase, parameterized.TestCase):
 | 
			
		||||
 | 
			
		||||
  @combinations.generate(test_base.default_test_combinations())
 | 
			
		||||
  def testSleep(self):
 | 
			
		||||
    self.skipTest("b/123597912")
 | 
			
		||||
    sleep_microseconds = 100
 | 
			
		||||
@ -44,6 +46,37 @@ class SleepTest(test_base.DatasetTestBase):
 | 
			
		||||
    with self.assertRaises(errors.OutOfRangeError):
 | 
			
		||||
      self.evaluate(next_element())
 | 
			
		||||
 | 
			
		||||
  @combinations.generate(combinations.combine(tf_api_version=1, mode="graph"))
 | 
			
		||||
  def testSleepCancellation(self):
 | 
			
		||||
    sleep_microseconds = int(1e6) * 1000
 | 
			
		||||
    ds = dataset_ops.Dataset.range(1)
 | 
			
		||||
    ds = ds.apply(sleep.sleep(sleep_microseconds))
 | 
			
		||||
    ds = ds.prefetch(1)
 | 
			
		||||
    get_next = self.getNext(ds, requires_initialization=True)
 | 
			
		||||
 | 
			
		||||
    with self.cached_session() as sess:
 | 
			
		||||
      thread = self.checkedThread(self.assert_op_cancelled, args=(get_next(),))
 | 
			
		||||
      thread.start()
 | 
			
		||||
      time.sleep(0.2)
 | 
			
		||||
      sess.close()
 | 
			
		||||
      thread.join()
 | 
			
		||||
 | 
			
		||||
  @combinations.generate(combinations.combine(tf_api_version=1, mode="graph"))
 | 
			
		||||
  def testSleepBackgroundCancellation(self):
 | 
			
		||||
    ds = dataset_ops.Dataset.range(1)
 | 
			
		||||
 | 
			
		||||
    sleep_microseconds = int(1e6) * 1000
 | 
			
		||||
    ds_sleep = dataset_ops.Dataset.range(1)
 | 
			
		||||
    ds_sleep = ds.apply(sleep.sleep(sleep_microseconds))
 | 
			
		||||
 | 
			
		||||
    ds = ds.concatenate(ds_sleep)
 | 
			
		||||
    ds = ds.prefetch(1)
 | 
			
		||||
 | 
			
		||||
    get_next = self.getNext(ds, requires_initialization=True)
 | 
			
		||||
 | 
			
		||||
    with self.cached_session():
 | 
			
		||||
      self.assertEqual(self.evaluate(get_next()), 0)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == "__main__":
 | 
			
		||||
  test.main()
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user