[tf.data service] Server-side implementation for round-robin data consumption.

PiperOrigin-RevId: 345743063
Change-Id: Id2e19d048cf49b37185dfd558d3ab010df23f9bc
This commit is contained in:
Andrew Audibert 2020-12-04 13:18:24 -08:00 committed by TensorFlower Gardener
parent 6b4ba7fb16
commit 80aa374b54
7 changed files with 529 additions and 13 deletions

View File

@ -373,6 +373,31 @@ cc_library(
], ],
) )
cc_library(
name = "task_runner",
srcs = ["task_runner.cc"],
hdrs = ["task_runner.h"],
deps = [
":common_proto_cc",
"//tensorflow/core:lib",
"//tensorflow/core/data:standalone",
],
)
tf_cc_test(
name = "task_runner_test",
srcs = ["task_runner_test.cc"],
deps = [
":task_runner",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/data:standalone",
"@com_google_absl//absl/memory",
],
)
cc_library( cc_library(
name = "test_cluster", name = "test_cluster",
testonly = True, testonly = True,
@ -465,6 +490,7 @@ cc_library(
":dispatcher_proto_cc", ":dispatcher_proto_cc",
":grpc_util", ":grpc_util",
":split_provider", ":split_provider",
":task_runner",
":utils", ":utils",
":worker_proto_cc", ":worker_proto_cc",
"//tensorflow/c:c_api_internal", "//tensorflow/c:c_api_internal",

View File

@ -0,0 +1,129 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/data/service/task_runner.h"
#include "tensorflow/core/data/standalone.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
namespace data {
StandaloneTaskIterator::StandaloneTaskIterator(
std::unique_ptr<standalone::Dataset> dataset,
std::unique_ptr<standalone::Iterator> iterator)
: dataset_(std::move(dataset)), iterator_(std::move(iterator)) {}
Status StandaloneTaskIterator::GetNext(std::vector<Tensor>& element,
bool& end_of_sequence) {
return iterator_->GetNext(&element, &end_of_sequence);
}
Status TaskRunner::Create(const TaskDef& task_def,
std::unique_ptr<TaskIterator> iterator,
std::unique_ptr<TaskRunner>& out) {
if (task_def.optional_num_consumers_case() == TaskDef::kNumConsumers) {
out = absl::make_unique<RoundRobinTaskRunner>(std::move(iterator),
task_def.num_consumers());
} else {
out =
absl::make_unique<FirstComeFirstServedTaskRunner>(std::move(iterator));
}
return Status::OK();
}
FirstComeFirstServedTaskRunner::FirstComeFirstServedTaskRunner(
std::unique_ptr<TaskIterator> iterator)
: iterator_(std::move(iterator)) {}
Status FirstComeFirstServedTaskRunner::GetNext(const Request& request,
std::vector<Tensor>& element,
bool& end_of_task) {
return iterator_->GetNext(element, end_of_task);
}
RoundRobinTaskRunner::RoundRobinTaskRunner(
std::unique_ptr<TaskIterator> iterator, int64 num_consumers)
: num_consumers_(num_consumers),
iterator_(std::move(iterator)),
buffer_(num_consumers_) {}
Status RoundRobinTaskRunner::GetNext(const Request& request,
std::vector<Tensor>& element,
bool& end_of_task) {
if (request.consumer_index < 0 || request.round_index < 0) {
return errors::FailedPrecondition(
"RoundRobinTaskRunner needs to know the consumer index and element "
"index of each request.");
}
if (request.consumer_index >= num_consumers_) {
return errors::FailedPrecondition(
"Requesting data for consumer index ", request.consumer_index,
", but the task is configured for only ", num_consumers_, " consumers");
}
mutex_lock l(mu_);
absl::flat_hash_set<int64>& round = requests_[request.round_index];
first_round_ = std::min(first_round_, request.round_index);
round.insert(request.consumer_index);
if (current_round_ < request.round_index && round.size() == num_consumers_) {
// This was the last request to arrive, time to start a new round.
TF_RETURN_IF_ERROR(FillBuffer());
current_round_ = request.round_index;
new_round_cv_.notify_all();
}
if (current_round_ < 0 &&
requests_[first_round_].size() + requests_[first_round_ + 1].size() ==
num_consumers_) {
// Indicates that we need a partial round to get consumers back in sync.
TF_RETURN_IF_ERROR(FillBuffer());
current_round_ = first_round_;
new_round_cv_.notify_all();
}
while (current_round_ < request.round_index) {
new_round_cv_.wait(l);
}
Result& result = buffer_[request.consumer_index];
end_of_task = result.end_of_task;
if (!end_of_task) {
element = std::move(result.element);
}
return Status::OK();
}
Status RoundRobinTaskRunner::FillBuffer() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
for (int i = 0; i < num_consumers_; ++i) {
Result& result = buffer_[i];
result.element.clear();
TF_RETURN_IF_ERROR(iterator_->GetNext(result.element, result.end_of_task));
if (buffer_[i].end_of_task && !buffer_[0].end_of_task) {
std::vector<Tensor>& first_element = buffer_[0].element;
// Pad out the round with empty elements.
buffer_[i].element.clear();
for (int c = 0; c < first_element.size(); ++c) {
TensorShape shape = first_element[c].shape();
if (shape.dims() > 0) {
shape.set_dim(0, 0);
}
buffer_[i].element.push_back(Tensor(first_element[c].dtype(), shape));
}
buffer_[i].end_of_task = false;
}
}
return Status::OK();
}
} // namespace data
} // namespace tensorflow

View File

@ -0,0 +1,143 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_DATA_SERVICE_TASK_RUNNER_H_
#define TENSORFLOW_CORE_DATA_SERVICE_TASK_RUNNER_H_
#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/data/standalone.h"
#include "tensorflow/core/platform/status.h"
namespace tensorflow {
namespace data {
// Iterator over a task's elements.
class TaskIterator {
public:
virtual ~TaskIterator() = default;
// If the iterator is not yet exhausted, `GetNext` stores the next element in
// `element` and sets `end_of_sequence` to `false`. Otherwise, sets
// `end_of_sequence to `true`.
virtual Status GetNext(std::vector<Tensor>& element,
bool& end_of_sequence) = 0;
};
// Implementation of TaskIterator wrapping a standalone iterator.
class StandaloneTaskIterator : public TaskIterator {
public:
// `dataset` should be the dataset that created `iterator`.
// StandaloneTaskIterator takes ownership of the dataset to ensures it
// lives as long as `iterator`.
StandaloneTaskIterator(std::unique_ptr<standalone::Dataset> dataset,
std::unique_ptr<standalone::Iterator> iterator);
Status GetNext(std::vector<Tensor>& element, bool& end_of_sequence) override;
private:
std::unique_ptr<standalone::Dataset> dataset_;
std::unique_ptr<standalone::Iterator> iterator_;
};
// Interface for providing elements to task consumers.
class TaskRunner {
public:
struct Request {
// Optional consumer index indicating which consumer is making the request.
// Only needed for round-robin reads.
int64 consumer_index = -1;
// Optional round index indicating which round the consumer wants to read
// from. Consumers are expected to read from consecutive rounds, starting
// with round 0. The task runner will attempt to serve all consumer
// requests for a round from the same block of `num_consumers` iterator
// indices, where block `n` is defined as elements `n*num_consumers` to
// `(n+1)*num_consumers`.
int64 round_index = -1;
};
// Creates a `TaskRunner` and stores it in `out`.
static Status Create(const TaskDef& task_def,
std::unique_ptr<TaskIterator> iterator,
std::unique_ptr<TaskRunner>& out);
virtual ~TaskRunner() = default;
// Gets the next element for the given request, storing the results in
// `element` and `end_of_task`.
virtual Status GetNext(const Request& request, std::vector<Tensor>& element,
bool& end_of_task) = 0;
};
// A task runner which provides elements on a first-come first-served basis.
// It does not consider which consumer is making the request.
class FirstComeFirstServedTaskRunner : public TaskRunner {
public:
explicit FirstComeFirstServedTaskRunner(
std::unique_ptr<TaskIterator> iterator);
Status GetNext(const Request& request, std::vector<Tensor>& element,
bool& end_of_task) override;
private:
std::unique_ptr<TaskIterator> iterator_;
};
// A task runner which enforces round-robin order for consuming a task's
// elements. Requests must provide a consumer index and element index.
// `RoundRobinTaskRunner` provides elements in a series of "rounds". In each
// successive round, the runner waits to receive requests from all consumers.
// These requests are blocked until all requests arrive. Once all requests
// arrive, the runner hands out elements to consumers in order of their consumer
// indices.
//
// Consumers are expected to successively request consecutive element indices,
// starting at 0. The same element can be requested multiple times by the same
// consumer, as long as the consumer hasn't yet requested the next element (at
// the start of each round we discard elements from the previous round).
//
// If the worker restarts mid-round, a situation arises where some consumers
// are requesting element index `n` while others are requesting element index
// `n + 1`. To remedy this, the first round after restart may be a partial
// round, where we only serve elements to consumers requesting data for element
// index `n`, blocking other consumers until the second round.
class RoundRobinTaskRunner : public TaskRunner {
public:
RoundRobinTaskRunner(std::unique_ptr<TaskIterator> iterator,
int64 num_consumers);
Status GetNext(const Request& request, std::vector<Tensor>& element,
bool& end_of_task) override;
private:
struct Result {
std::vector<Tensor> element;
bool end_of_task = false;
};
// Fills `buffer_` with `num_consumers_` elements.
Status FillBuffer();
const int64 num_consumers_;
std::unique_ptr<TaskIterator> iterator_;
mutex mu_;
// Condition variable notified whenever we start a new round of round-robin.
condition_variable new_round_cv_;
// Map from round number to consumers waiting for data from that round.
absl::flat_hash_map<int64, absl::flat_hash_set<int64>> requests_
TF_GUARDED_BY(mu_);
// Index of the first round we plan to serve. At startup, this is the minimum
// of all requested element indices.
int64 first_round_ TF_GUARDED_BY(mu_) = kint64max;
int64 current_round_ TF_GUARDED_BY(mu_) = -1;
// Buffered results for the current round.
std::vector<Result> buffer_ TF_GUARDED_BY(mu_);
};
} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_DATA_SERVICE_TASK_RUNNER_H_

View File

@ -0,0 +1,195 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/data/service/task_runner.h"
#include "absl/memory/memory.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace data {
namespace {
class TestTaskIterator : public TaskIterator {
public:
explicit TestTaskIterator(const std::vector<std::vector<Tensor>>& elements)
: elements_(elements), index_(0) {}
Status GetNext(std::vector<Tensor>& element, bool& end_of_sequence) override {
end_of_sequence = index_ >= elements_.size();
if (!end_of_sequence) {
element = elements_[index_++];
}
return Status::OK();
}
private:
std::vector<std::vector<Tensor>> elements_;
int64 index_;
};
// Reads from the task runner, storing results in `*output`.
Status RunConsumer(int64 consumer_index, int64 start_index,
TaskRunner& task_runner,
std::vector<std::vector<Tensor>>& output) {
bool end_of_sequence = false;
int64 next_index = start_index;
while (!end_of_sequence) {
TaskRunner::Request request;
request.round_index = next_index++;
request.consumer_index = consumer_index;
std::vector<Tensor> element;
TF_RETURN_IF_ERROR(task_runner.GetNext(request, element, end_of_sequence));
if (!end_of_sequence) {
output.push_back(element);
}
}
return Status::OK();
}
} // namespace
TEST(FirstComeFirstServedTaskRunner, GetNext) {
std::vector<std::vector<Tensor>> elements;
for (int64 i = 0; i < 10; ++i) {
std::vector<Tensor> element;
element.push_back(Tensor(i));
elements.push_back(element);
}
FirstComeFirstServedTaskRunner runner(
absl::make_unique<TestTaskIterator>(elements));
TaskRunner::Request request;
for (auto& expected_element : elements) {
std::vector<Tensor> element;
bool end_of_sequence;
TF_ASSERT_OK(runner.GetNext(request, element, end_of_sequence));
ASSERT_FALSE(end_of_sequence);
ASSERT_EQ(element.size(), 1);
test::ExpectEqual(element[0], expected_element[0]);
}
for (int i = 0; i < 2; ++i) {
std::vector<Tensor> element;
bool end_of_sequence;
TF_ASSERT_OK(runner.GetNext(request, element, end_of_sequence));
ASSERT_TRUE(end_of_sequence);
}
}
class ConsumeParallelTest
: public ::testing::Test,
public ::testing::WithParamInterface<std::tuple<int64, int64>> {};
TEST_P(ConsumeParallelTest, ConsumeParallel) {
int64 num_elements = std::get<0>(GetParam());
int64 num_consumers = std::get<1>(GetParam());
std::vector<std::vector<Tensor>> elements;
for (int64 i = 0; i < num_elements; ++i) {
std::vector<Tensor> element;
element.push_back(Tensor(i));
elements.push_back(element);
}
RoundRobinTaskRunner runner(absl::make_unique<TestTaskIterator>(elements),
num_consumers);
std::vector<std::vector<std::vector<Tensor>>> per_consumer_results;
std::vector<std::unique_ptr<Thread>> consumers;
mutex mu;
Status error;
for (int consumer = 0; consumer < num_consumers; ++consumer) {
mutex_lock l(mu);
per_consumer_results.emplace_back();
consumers.push_back(absl::WrapUnique(Env::Default()->StartThread(
{}, absl::StrCat("consumer_", consumer), [&, consumer] {
std::vector<std::vector<Tensor>> results;
Status s = RunConsumer(consumer, /*start_index=*/0, runner, results);
mutex_lock l(mu);
if (!s.ok()) {
error = s;
return;
}
per_consumer_results[consumer] = std::move(results);
})));
}
// Wait for all consumers to finish;
consumers.clear();
mutex_lock l(mu);
TF_ASSERT_OK(error);
for (int i = 0; i < num_elements; ++i) {
int consumer = i % num_consumers;
int round = i / num_consumers;
Tensor expected = elements[i][0];
test::ExpectEqual(per_consumer_results[consumer][round][0], expected);
}
}
INSTANTIATE_TEST_SUITE_P(ConsumeParallelTests, ConsumeParallelTest,
// tuples represent <num_elements, num_consumers>
::testing::Values(std::make_tuple(1000, 5),
std::make_tuple(1003, 5),
std::make_tuple(1000, 20),
std::make_tuple(4, 20),
std::make_tuple(0, 20)));
TEST(RoundRobinTaskRunner, ConsumeParallelPartialRound) {
int64 num_elements = 20;
int64 num_consumers = 5;
std::vector<int64> starting_rounds = {12, 11, 11, 12, 12};
int64 min_starting_round = 11;
std::vector<std::vector<Tensor>> elements;
for (int64 i = 0; i < num_elements; ++i) {
std::vector<Tensor> element;
element.push_back(Tensor(i));
elements.push_back(element);
}
RoundRobinTaskRunner runner(absl::make_unique<TestTaskIterator>(elements),
num_consumers);
std::vector<std::vector<std::vector<Tensor>>> per_consumer_results;
std::vector<std::unique_ptr<Thread>> consumers;
mutex mu;
Status error;
for (int consumer = 0; consumer < num_consumers; ++consumer) {
mutex_lock l(mu);
per_consumer_results.emplace_back();
consumers.push_back(absl::WrapUnique(Env::Default()->StartThread(
{}, absl::StrCat("consumer_", consumer), [&, consumer] {
std::vector<std::vector<Tensor>> results;
Status s =
RunConsumer(consumer, starting_rounds[consumer], runner, results);
mutex_lock l(mu);
if (!s.ok()) {
error = s;
return;
}
per_consumer_results[consumer] = std::move(results);
})));
}
// Wait for all consumers to finish;
consumers.clear();
mutex_lock l(mu);
TF_ASSERT_OK(error);
for (int consumer = 0; consumer < num_consumers; ++consumer) {
auto& results = per_consumer_results[consumer];
int start = consumer;
int expected_elements = num_elements / num_consumers;
if (starting_rounds[consumer] != min_starting_round) {
start += num_consumers;
expected_elements--;
}
ASSERT_EQ(results.size(), expected_elements);
int index = 0;
for (int i = start; i < num_elements; i += num_consumers) {
Tensor expected = elements[i][0];
test::ExpectEqual(results[index++][0], expected);
}
}
}
} // namespace data
} // namespace tensorflow

View File

@ -14,6 +14,15 @@ message ProcessTaskResponse {}
message GetElementRequest { message GetElementRequest {
// The task to fetch an element from. // The task to fetch an element from.
int64 task_id = 1; int64 task_id = 1;
// Optional index to indentify the consumer.
oneof optional_consumer_index {
int64 consumer_index = 2;
}
// Optional round index, indicating which round of round-robin the consumer
// wants to read from. This is used to keep consumers in sync.
oneof optional_round_index {
int64 round_index = 3;
}
} }
message GetElementResponse { message GetElementResponse {

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/data/service/dispatcher.pb.h" #include "tensorflow/core/data/service/dispatcher.pb.h"
#include "tensorflow/core/data/service/grpc_util.h" #include "tensorflow/core/data/service/grpc_util.h"
#include "tensorflow/core/data/service/split_provider.h" #include "tensorflow/core/data/service/split_provider.h"
#include "tensorflow/core/data/service/task_runner.h"
#include "tensorflow/core/data/service/utils.h" #include "tensorflow/core/data/service/utils.h"
#include "tensorflow/core/data/standalone.h" #include "tensorflow/core/data/standalone.h"
#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor.pb.h"
@ -123,11 +124,13 @@ Status DataServiceWorkerImpl::EnsureTaskInitialized(
return Status::OK(); return Status::OK();
} }
standalone::Dataset::Params params; standalone::Dataset::Params params;
std::unique_ptr<standalone::Dataset> dataset;
std::unique_ptr<standalone::Iterator> iterator;
switch (task.task_def.dataset_case()) { switch (task.task_def.dataset_case()) {
case TaskDef::kDatasetDef: case TaskDef::kDatasetDef:
TF_RETURN_IF_ERROR(standalone::Dataset::FromGraph( TF_RETURN_IF_ERROR(standalone::Dataset::FromGraph(
params, task.task_def.dataset_def().graph(), &task.dataset)); params, task.task_def.dataset_def().graph(), &dataset));
break; break;
case TaskDef::kPath: { case TaskDef::kPath: {
DatasetDef def; DatasetDef def;
@ -139,7 +142,7 @@ Status DataServiceWorkerImpl::EnsureTaskInitialized(
dispatcher_->GetDatasetDef(task.task_def.dataset_id(), def)); dispatcher_->GetDatasetDef(task.task_def.dataset_id(), def));
} }
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
standalone::Dataset::FromGraph(params, def.graph(), &task.dataset)); standalone::Dataset::FromGraph(params, def.graph(), &dataset));
break; break;
} }
case TaskDef::DATASET_NOT_SET: case TaskDef::DATASET_NOT_SET:
@ -151,17 +154,22 @@ Status DataServiceWorkerImpl::EnsureTaskInitialized(
auto split_provider = absl::make_unique<DataServiceSplitProvider>( auto split_provider = absl::make_unique<DataServiceSplitProvider>(
config_.dispatcher_address(), config_.protocol(), config_.dispatcher_address(), config_.protocol(),
task.task_def.job_id(), config_.dispatcher_timeout_ms()); task.task_def.job_id(), config_.dispatcher_timeout_ms());
TF_RETURN_IF_ERROR(task.dataset->MakeIterator(std::move(split_provider), TF_RETURN_IF_ERROR(
&task.iterator)); dataset->MakeIterator(std::move(split_provider), &iterator));
break; break;
} }
case PARALLEL_EPOCHS: case PARALLEL_EPOCHS:
TF_RETURN_IF_ERROR(task.dataset->MakeIterator(&task.iterator)); TF_RETURN_IF_ERROR(dataset->MakeIterator(&iterator));
break; break;
default: default:
return errors::InvalidArgument("Unrecognized processing mode: ", return errors::InvalidArgument("Unrecognized processing mode: ",
task.task_def.processing_mode()); task.task_def.processing_mode());
} }
auto task_iterator = absl::make_unique<StandaloneTaskIterator>(
std::move(dataset), std::move(iterator));
TF_RETURN_IF_ERROR(TaskRunner::Create(task.task_def, std::move(task_iterator),
task.task_runner));
task.initialized = true; task.initialized = true;
VLOG(3) << "Created iterator for task " << task.task_def.task_id(); VLOG(3) << "Created iterator for task " << task.task_def.task_id();
return Status::OK(); return Status::OK();
@ -182,16 +190,25 @@ Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request,
"Worker has not yet registered with dispatcher."); "Worker has not yet registered with dispatcher.");
} }
auto it = tasks_.find(request->task_id()); auto it = tasks_.find(request->task_id());
if (it == tasks_.end() || it->second->finished) { if (it == tasks_.end()) {
response->set_end_of_sequence(true); response->set_end_of_sequence(true);
return Status::OK(); return Status::OK();
} }
auto& task = it->second; auto& task = it->second;
TF_RETURN_IF_ERROR(EnsureTaskInitialized(*task)); TF_RETURN_IF_ERROR(EnsureTaskInitialized(*task));
TF_RETURN_IF_ERROR(task->iterator->GetNext(&outputs, &end_of_sequence)); TaskRunner::Request get_next_request;
if (request->optional_consumer_index_case() ==
GetElementRequest::kConsumerIndex) {
get_next_request.consumer_index = request->consumer_index();
}
if (request->optional_round_index_case() ==
GetElementRequest::kRoundIndex) {
get_next_request.round_index = request->round_index();
}
TF_RETURN_IF_ERROR(
task->task_runner->GetNext(get_next_request, outputs, end_of_sequence));
if (end_of_sequence) { if (end_of_sequence) {
VLOG(3) << "Reached end_of_sequence for task " << request->task_id(); VLOG(3) << "Reached end_of_sequence for task " << request->task_id();
task->finished = true;
pending_completed_tasks_.insert(request->task_id()); pending_completed_tasks_.insert(request->task_id());
task_completion_cv_.notify_one(); task_completion_cv_.notify_one();
} }

View File

@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/data/service/common.pb.h" #include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/data/service/data_service.h" #include "tensorflow/core/data/service/data_service.h"
#include "tensorflow/core/data/service/dispatcher.grpc.pb.h" #include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
#include "tensorflow/core/data/service/task_runner.h"
#include "tensorflow/core/data/service/worker.pb.h" #include "tensorflow/core/data/service/worker.pb.h"
#include "tensorflow/core/data/standalone.h" #include "tensorflow/core/data/standalone.h"
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status.h"
@ -60,11 +61,7 @@ class DataServiceWorkerImpl {
TaskDef task_def; TaskDef task_def;
mutex mu; mutex mu;
bool initialized TF_GUARDED_BY(mu) = false; bool initialized TF_GUARDED_BY(mu) = false;
bool finished = false; std::unique_ptr<TaskRunner> task_runner;
// TODO(aaudibert): Have standalone::Iterator own a reference to
// standalone::Dataset so that we don't need to store the dataset here.
std::unique_ptr<standalone::Dataset> dataset;
std::unique_ptr<standalone::Iterator> iterator;
}; };
// Sends task status to the dispatcher and checks for dispatcher commands. // Sends task status to the dispatcher and checks for dispatcher commands.