[tf.data service] Keep track of finished tasks ids on workers.
Workers maintain a list of finished task ids so that they can differentiate between whether a task is unrecognized vs finished. PiperOrigin-RevId: 348108372 Change-Id: Ia8a4a3c378ab8896bd8e81924c39de97c9a66dfe
This commit is contained in:
parent
c25e6c9976
commit
4c376c88fb
@ -504,6 +504,7 @@ cc_library(
|
||||
"//tensorflow/core/data:dataset_proto_cc",
|
||||
"//tensorflow/core/data:standalone",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/memory",
|
||||
tf_grpc_cc_dependency(),
|
||||
],
|
||||
|
@ -191,8 +191,15 @@ Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request,
|
||||
}
|
||||
auto it = tasks_.find(request->task_id());
|
||||
if (it == tasks_.end()) {
|
||||
response->set_end_of_sequence(true);
|
||||
return Status::OK();
|
||||
if (finished_tasks_.contains(request->task_id())) {
|
||||
VLOG(3) << "Task is already finished";
|
||||
response->set_end_of_sequence(true);
|
||||
return Status::OK();
|
||||
} else {
|
||||
// Perhaps the workers hasn't gotten the task from the dispatcher yet.
|
||||
// Return Unavailable so that the client knows to continue retrying.
|
||||
return errors::Unavailable("Task ", request->task_id(), " not found");
|
||||
}
|
||||
}
|
||||
auto& task = it->second;
|
||||
TF_RETURN_IF_ERROR(EnsureTaskInitialized(*task));
|
||||
@ -362,6 +369,7 @@ Status DataServiceWorkerImpl::Heartbeat() TF_LOCKS_EXCLUDED(mu_) {
|
||||
VLOG(3) << "Deleting task " << task_id
|
||||
<< " at the request of the dispatcher";
|
||||
tasks_.erase(task_id);
|
||||
finished_tasks_.insert(task_id);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_CORE_DATA_SERVICE_WORKER_IMPL_H_
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "tensorflow/core/data/service/common.pb.h"
|
||||
#include "tensorflow/core/data/service/data_service.h"
|
||||
#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
|
||||
@ -85,6 +86,8 @@ class DataServiceWorkerImpl {
|
||||
mutex mu_;
|
||||
// Information about tasks, keyed by task ids.
|
||||
absl::flat_hash_map<int64, std::unique_ptr<Task>> tasks_ TF_GUARDED_BY(mu_);
|
||||
// Ids of tasks that have finished.
|
||||
absl::flat_hash_set<int64> finished_tasks_ TF_GUARDED_BY(mu_);
|
||||
// Completed tasks which haven't yet been communicated to the dispatcher.
|
||||
absl::flat_hash_set<int64> pending_completed_tasks_ TF_GUARDED_BY(mu_);
|
||||
bool cancelled_ TF_GUARDED_BY(mu_) = false;
|
||||
|
Loading…
Reference in New Issue
Block a user