diff --git a/tensorflow/core/data/service/BUILD b/tensorflow/core/data/service/BUILD index 19e7adf5b65..c1cb8589e11 100644 --- a/tensorflow/core/data/service/BUILD +++ b/tensorflow/core/data/service/BUILD @@ -504,6 +504,7 @@ cc_library( "//tensorflow/core/data:dataset_proto_cc", "//tensorflow/core/data:standalone", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/memory", tf_grpc_cc_dependency(), ], diff --git a/tensorflow/core/data/service/worker_impl.cc b/tensorflow/core/data/service/worker_impl.cc index b7253e3a5ab..2f81d0d5af0 100644 --- a/tensorflow/core/data/service/worker_impl.cc +++ b/tensorflow/core/data/service/worker_impl.cc @@ -191,8 +191,15 @@ Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request, } auto it = tasks_.find(request->task_id()); if (it == tasks_.end()) { - response->set_end_of_sequence(true); - return Status::OK(); + if (finished_tasks_.contains(request->task_id())) { + VLOG(3) << "Task is already finished"; + response->set_end_of_sequence(true); + return Status::OK(); + } else { + // Perhaps the workers hasn't gotten the task from the dispatcher yet. + // Return Unavailable so that the client knows to continue retrying. + return errors::Unavailable("Task ", request->task_id(), " not found"); + } } auto& task = it->second; TF_RETURN_IF_ERROR(EnsureTaskInitialized(*task)); @@ -362,6 +369,7 @@ Status DataServiceWorkerImpl::Heartbeat() TF_LOCKS_EXCLUDED(mu_) { VLOG(3) << "Deleting task " << task_id << " at the request of the dispatcher"; tasks_.erase(task_id); + finished_tasks_.insert(task_id); } return Status::OK(); } diff --git a/tensorflow/core/data/service/worker_impl.h b/tensorflow/core/data/service/worker_impl.h index 47b883d654e..80eb5b756a4 100644 --- a/tensorflow/core/data/service/worker_impl.h +++ b/tensorflow/core/data/service/worker_impl.h @@ -16,6 +16,7 @@ limitations under the License. #define TENSORFLOW_CORE_DATA_SERVICE_WORKER_IMPL_H_ #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "tensorflow/core/data/service/common.pb.h" #include "tensorflow/core/data/service/data_service.h" #include "tensorflow/core/data/service/dispatcher.grpc.pb.h" @@ -85,6 +86,8 @@ class DataServiceWorkerImpl { mutex mu_; // Information about tasks, keyed by task ids. absl::flat_hash_map<int64, std::unique_ptr<Task>> tasks_ TF_GUARDED_BY(mu_); + // Ids of tasks that have finished. + absl::flat_hash_set<int64> finished_tasks_ TF_GUARDED_BY(mu_); // Completed tasks which haven't yet been communicated to the dispatcher. absl::flat_hash_set<int64> pending_completed_tasks_ TF_GUARDED_BY(mu_); bool cancelled_ TF_GUARDED_BY(mu_) = false;