[tf.data service] Server-side implementation for round-robin data consumption.
PiperOrigin-RevId: 345743063
Change-Id: Id2e19d048cf49b37185dfd558d3ab010df23f9bc

parent 6b4ba7fb16
commit 80aa374b54
tensorflow/core/data/service/BUILD
@@ -373,6 +373,31 @@ cc_library(
    ],
)

cc_library(
    name = "task_runner",
    srcs = ["task_runner.cc"],
    hdrs = ["task_runner.h"],
    deps = [
        ":common_proto_cc",
        "//tensorflow/core:lib",
        "//tensorflow/core/data:standalone",
    ],
)

tf_cc_test(
    name = "task_runner_test",
    srcs = ["task_runner_test.cc"],
    deps = [
        ":task_runner",
        "//tensorflow/core:lib",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core:testlib",
        "//tensorflow/core/data:standalone",
        "@com_google_absl//absl/memory",
    ],
)

cc_library(
    name = "test_cluster",
    testonly = True,

@@ -465,6 +490,7 @@ cc_library(
        ":dispatcher_proto_cc",
        ":grpc_util",
        ":split_provider",
        ":task_runner",
        ":utils",
        ":worker_proto_cc",
        "//tensorflow/c:c_api_internal",
tensorflow/core/data/service/task_runner.cc (new file, 129 lines)
@@ -0,0 +1,129 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/data/service/task_runner.h"

#include "tensorflow/core/data/standalone.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
namespace data {

StandaloneTaskIterator::StandaloneTaskIterator(
    std::unique_ptr<standalone::Dataset> dataset,
    std::unique_ptr<standalone::Iterator> iterator)
    : dataset_(std::move(dataset)), iterator_(std::move(iterator)) {}

Status StandaloneTaskIterator::GetNext(std::vector<Tensor>& element,
                                       bool& end_of_sequence) {
  return iterator_->GetNext(&element, &end_of_sequence);
}

Status TaskRunner::Create(const TaskDef& task_def,
                          std::unique_ptr<TaskIterator> iterator,
                          std::unique_ptr<TaskRunner>& out) {
  if (task_def.optional_num_consumers_case() == TaskDef::kNumConsumers) {
    out = absl::make_unique<RoundRobinTaskRunner>(std::move(iterator),
                                                  task_def.num_consumers());
  } else {
    out =
        absl::make_unique<FirstComeFirstServedTaskRunner>(std::move(iterator));
  }
  return Status::OK();
}

FirstComeFirstServedTaskRunner::FirstComeFirstServedTaskRunner(
    std::unique_ptr<TaskIterator> iterator)
    : iterator_(std::move(iterator)) {}

Status FirstComeFirstServedTaskRunner::GetNext(const Request& request,
                                               std::vector<Tensor>& element,
                                               bool& end_of_task) {
  return iterator_->GetNext(element, end_of_task);
}

RoundRobinTaskRunner::RoundRobinTaskRunner(
    std::unique_ptr<TaskIterator> iterator, int64 num_consumers)
    : num_consumers_(num_consumers),
      iterator_(std::move(iterator)),
      buffer_(num_consumers_) {}

Status RoundRobinTaskRunner::GetNext(const Request& request,
                                     std::vector<Tensor>& element,
                                     bool& end_of_task) {
  if (request.consumer_index < 0 || request.round_index < 0) {
    return errors::FailedPrecondition(
        "RoundRobinTaskRunner needs to know the consumer index and element "
        "index of each request.");
  }
  if (request.consumer_index >= num_consumers_) {
    return errors::FailedPrecondition(
        "Requesting data for consumer index ", request.consumer_index,
        ", but the task is configured for only ", num_consumers_, " consumers");
  }

  mutex_lock l(mu_);
  absl::flat_hash_set<int64>& round = requests_[request.round_index];
  first_round_ = std::min(first_round_, request.round_index);
  round.insert(request.consumer_index);
  if (current_round_ < request.round_index && round.size() == num_consumers_) {
    // This was the last request to arrive, time to start a new round.
    TF_RETURN_IF_ERROR(FillBuffer());
    current_round_ = request.round_index;
    new_round_cv_.notify_all();
  }
  if (current_round_ < 0 &&
      requests_[first_round_].size() + requests_[first_round_ + 1].size() ==
          num_consumers_) {
    // Indicates that we need a partial round to get consumers back in sync.
    TF_RETURN_IF_ERROR(FillBuffer());
    current_round_ = first_round_;
    new_round_cv_.notify_all();
  }
  while (current_round_ < request.round_index) {
    new_round_cv_.wait(l);
  }
  Result& result = buffer_[request.consumer_index];
  end_of_task = result.end_of_task;
  if (!end_of_task) {
    element = std::move(result.element);
  }
  return Status::OK();
}

Status RoundRobinTaskRunner::FillBuffer() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  for (int i = 0; i < num_consumers_; ++i) {
    Result& result = buffer_[i];
    result.element.clear();
    TF_RETURN_IF_ERROR(iterator_->GetNext(result.element, result.end_of_task));
    if (buffer_[i].end_of_task && !buffer_[0].end_of_task) {
      std::vector<Tensor>& first_element = buffer_[0].element;
      // Pad out the round with empty elements.
      buffer_[i].element.clear();
      for (int c = 0; c < first_element.size(); ++c) {
        TensorShape shape = first_element[c].shape();
        if (shape.dims() > 0) {
          shape.set_dim(0, 0);
        }
        buffer_[i].element.push_back(Tensor(first_element[c].dtype(), shape));
      }
      buffer_[i].end_of_task = false;
    }
  }
  return Status::OK();
}

}  // namespace data
}  // namespace tensorflow
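One subtle piece of the implementation above is the padding in RoundRobinTaskRunner::FillBuffer: when the iterator is exhausted partway through a round, consumers whose slots would otherwise be missing instead receive an empty element shaped like the round's first element, and end_of_task is only reported once the first slot of a round is itself exhausted. A minimal, illustrative sketch of that padding (not part of this commit) against the public Tensor API:

#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"

// Illustrative sketch: build a padding element that mirrors `first` but has
// zero rows, matching what FillBuffer pushes into buffer_[i] when the
// iterator runs out mid-round.
tensorflow::Tensor MakePaddingLike(const tensorflow::Tensor& first) {
  tensorflow::TensorShape shape = first.shape();
  if (shape.dims() > 0) {
    shape.set_dim(0, 0);  // Keep dtype and trailing dims; empty the batch dim.
  }
  return tensorflow::Tensor(first.dtype(), shape);
}

This way every consumer in a round receives an element of the same structure, and consumers can recognize the zero-row tensors as padding.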
tensorflow/core/data/service/task_runner.h (new file, 143 lines)
@@ -0,0 +1,143 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_DATA_SERVICE_TASK_RUNNER_H_
#define TENSORFLOW_CORE_DATA_SERVICE_TASK_RUNNER_H_

#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/data/standalone.h"
#include "tensorflow/core/platform/status.h"

namespace tensorflow {
namespace data {

// Iterator over a task's elements.
class TaskIterator {
 public:
  virtual ~TaskIterator() = default;
  // If the iterator is not yet exhausted, `GetNext` stores the next element in
  // `element` and sets `end_of_sequence` to `false`. Otherwise, sets
  // `end_of_sequence` to `true`.
  virtual Status GetNext(std::vector<Tensor>& element,
                         bool& end_of_sequence) = 0;
};

// Implementation of TaskIterator wrapping a standalone iterator.
class StandaloneTaskIterator : public TaskIterator {
 public:
  // `dataset` should be the dataset that created `iterator`.
  // StandaloneTaskIterator takes ownership of the dataset to ensure it
  // lives as long as `iterator`.
  StandaloneTaskIterator(std::unique_ptr<standalone::Dataset> dataset,
                         std::unique_ptr<standalone::Iterator> iterator);
  Status GetNext(std::vector<Tensor>& element, bool& end_of_sequence) override;

 private:
  std::unique_ptr<standalone::Dataset> dataset_;
  std::unique_ptr<standalone::Iterator> iterator_;
};

// Interface for providing elements to task consumers.
class TaskRunner {
 public:
  struct Request {
    // Optional consumer index indicating which consumer is making the request.
    // Only needed for round-robin reads.
    int64 consumer_index = -1;
    // Optional round index indicating which round the consumer wants to read
    // from. Consumers are expected to read from consecutive rounds, starting
    // with round 0. The task runner will attempt to serve all consumer
    // requests for a round from the same block of `num_consumers` iterator
    // indices, where block `n` is defined as elements `n*num_consumers` to
    // `(n+1)*num_consumers`.
    int64 round_index = -1;
  };

  // Creates a `TaskRunner` and stores it in `out`.
  static Status Create(const TaskDef& task_def,
                       std::unique_ptr<TaskIterator> iterator,
                       std::unique_ptr<TaskRunner>& out);
  virtual ~TaskRunner() = default;
  // Gets the next element for the given request, storing the results in
  // `element` and `end_of_task`.
  virtual Status GetNext(const Request& request, std::vector<Tensor>& element,
                         bool& end_of_task) = 0;
};

// A task runner which provides elements on a first-come first-served basis.
// It does not consider which consumer is making the request.
class FirstComeFirstServedTaskRunner : public TaskRunner {
 public:
  explicit FirstComeFirstServedTaskRunner(
      std::unique_ptr<TaskIterator> iterator);
  Status GetNext(const Request& request, std::vector<Tensor>& element,
                 bool& end_of_task) override;

 private:
  std::unique_ptr<TaskIterator> iterator_;
};

// A task runner which enforces round-robin order for consuming a task's
// elements. Requests must provide a consumer index and element index.
// `RoundRobinTaskRunner` provides elements in a series of "rounds". In each
// successive round, the runner waits to receive requests from all consumers.
// These requests are blocked until all requests arrive. Once all requests
// arrive, the runner hands out elements to consumers in order of their
// consumer indices.
//
// Consumers are expected to successively request consecutive element indices,
// starting at 0. The same element can be requested multiple times by the same
// consumer, as long as the consumer hasn't yet requested the next element (at
// the start of each round we discard elements from the previous round).
//
// If the worker restarts mid-round, a situation arises where some consumers
// are requesting element index `n` while others are requesting element index
// `n + 1`. To remedy this, the first round after restart may be a partial
// round, where we only serve elements to consumers requesting data for element
// index `n`, blocking other consumers until the second round.
class RoundRobinTaskRunner : public TaskRunner {
 public:
  RoundRobinTaskRunner(std::unique_ptr<TaskIterator> iterator,
                       int64 num_consumers);
  Status GetNext(const Request& request, std::vector<Tensor>& element,
                 bool& end_of_task) override;

 private:
  struct Result {
    std::vector<Tensor> element;
    bool end_of_task = false;
  };
  // Fills `buffer_` with `num_consumers_` elements.
  Status FillBuffer();

  const int64 num_consumers_;
  std::unique_ptr<TaskIterator> iterator_;
  mutex mu_;
  // Condition variable notified whenever we start a new round of round-robin.
  condition_variable new_round_cv_;
  // Map from round number to consumers waiting for data from that round.
  absl::flat_hash_map<int64, absl::flat_hash_set<int64>> requests_
      TF_GUARDED_BY(mu_);
  // Index of the first round we plan to serve. At startup, this is the minimum
  // of all requested element indices.
  int64 first_round_ TF_GUARDED_BY(mu_) = kint64max;
  int64 current_round_ TF_GUARDED_BY(mu_) = -1;
  // Buffered results for the current round.
  std::vector<Result> buffer_ TF_GUARDED_BY(mu_);
};

}  // namespace data
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_DATA_SERVICE_TASK_RUNNER_H_
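The `Request` comment above defines rounds in terms of blocks of `num_consumers` elements: in the steady state, round `r` is served from elements `r*num_consumers` through `(n+1)*num_consumers - 1` of the iterator, and the consumer with index `c` receives element `r*num_consumers + c` (for example, with 3 consumers, round 2 covers elements 6, 7, 8 and consumer 1 gets element 7). The sketch below shows how the interface is intended to be driven; it is illustrative only, and `task_def` and `iterator` are assumed to come from the worker, as in `worker_impl.cc`.

#include "tensorflow/core/data/service/task_runner.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
namespace data {

// Illustrative sketch (not part of this commit): create a TaskRunner from a
// TaskDef and serve one request. Create() picks RoundRobinTaskRunner when the
// TaskDef carries `num_consumers`, and FirstComeFirstServedTaskRunner
// otherwise.
Status ServeOneElement(const TaskDef& task_def,
                       std::unique_ptr<TaskIterator> iterator,
                       std::vector<Tensor>& element, bool& end_of_task) {
  std::unique_ptr<TaskRunner> runner;
  TF_RETURN_IF_ERROR(TaskRunner::Create(task_def, std::move(iterator), runner));
  TaskRunner::Request request;
  // A round-robin task requires both indices; a first-come first-served task
  // ignores them. Note that for round-robin tasks GetNext blocks until every
  // consumer has requested the same round.
  request.consumer_index = 0;
  request.round_index = 0;
  return runner->GetNext(request, element, end_of_task);
}

}  // namespace data
}  // namespace tensorflow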
tensorflow/core/data/service/task_runner_test.cc (new file, 195 lines)
@@ -0,0 +1,195 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/data/service/task_runner.h"

#include "absl/memory/memory.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace data {
namespace {
class TestTaskIterator : public TaskIterator {
 public:
  explicit TestTaskIterator(const std::vector<std::vector<Tensor>>& elements)
      : elements_(elements), index_(0) {}

  Status GetNext(std::vector<Tensor>& element, bool& end_of_sequence) override {
    end_of_sequence = index_ >= elements_.size();
    if (!end_of_sequence) {
      element = elements_[index_++];
    }
    return Status::OK();
  }

 private:
  std::vector<std::vector<Tensor>> elements_;
  int64 index_;
};

// Reads from the task runner, storing results in `*output`.
Status RunConsumer(int64 consumer_index, int64 start_index,
                   TaskRunner& task_runner,
                   std::vector<std::vector<Tensor>>& output) {
  bool end_of_sequence = false;
  int64 next_index = start_index;
  while (!end_of_sequence) {
    TaskRunner::Request request;
    request.round_index = next_index++;
    request.consumer_index = consumer_index;
    std::vector<Tensor> element;
    TF_RETURN_IF_ERROR(task_runner.GetNext(request, element, end_of_sequence));
    if (!end_of_sequence) {
      output.push_back(element);
    }
  }
  return Status::OK();
}
}  // namespace

TEST(FirstComeFirstServedTaskRunner, GetNext) {
  std::vector<std::vector<Tensor>> elements;
  for (int64 i = 0; i < 10; ++i) {
    std::vector<Tensor> element;
    element.push_back(Tensor(i));
    elements.push_back(element);
  }
  FirstComeFirstServedTaskRunner runner(
      absl::make_unique<TestTaskIterator>(elements));
  TaskRunner::Request request;
  for (auto& expected_element : elements) {
    std::vector<Tensor> element;
    bool end_of_sequence;
    TF_ASSERT_OK(runner.GetNext(request, element, end_of_sequence));
    ASSERT_FALSE(end_of_sequence);
    ASSERT_EQ(element.size(), 1);
    test::ExpectEqual(element[0], expected_element[0]);
  }
  for (int i = 0; i < 2; ++i) {
    std::vector<Tensor> element;
    bool end_of_sequence;
    TF_ASSERT_OK(runner.GetNext(request, element, end_of_sequence));
    ASSERT_TRUE(end_of_sequence);
  }
}

class ConsumeParallelTest
    : public ::testing::Test,
      public ::testing::WithParamInterface<std::tuple<int64, int64>> {};

TEST_P(ConsumeParallelTest, ConsumeParallel) {
  int64 num_elements = std::get<0>(GetParam());
  int64 num_consumers = std::get<1>(GetParam());
  std::vector<std::vector<Tensor>> elements;
  for (int64 i = 0; i < num_elements; ++i) {
    std::vector<Tensor> element;
    element.push_back(Tensor(i));
    elements.push_back(element);
  }
  RoundRobinTaskRunner runner(absl::make_unique<TestTaskIterator>(elements),
                              num_consumers);
  std::vector<std::vector<std::vector<Tensor>>> per_consumer_results;
  std::vector<std::unique_ptr<Thread>> consumers;
  mutex mu;
  Status error;
  for (int consumer = 0; consumer < num_consumers; ++consumer) {
    mutex_lock l(mu);
    per_consumer_results.emplace_back();
    consumers.push_back(absl::WrapUnique(Env::Default()->StartThread(
        {}, absl::StrCat("consumer_", consumer), [&, consumer] {
          std::vector<std::vector<Tensor>> results;
          Status s = RunConsumer(consumer, /*start_index=*/0, runner, results);
          mutex_lock l(mu);
          if (!s.ok()) {
            error = s;
            return;
          }
          per_consumer_results[consumer] = std::move(results);
        })));
  }
  // Wait for all consumers to finish.
  consumers.clear();
  mutex_lock l(mu);
  TF_ASSERT_OK(error);
  for (int i = 0; i < num_elements; ++i) {
    int consumer = i % num_consumers;
    int round = i / num_consumers;
    Tensor expected = elements[i][0];
    test::ExpectEqual(per_consumer_results[consumer][round][0], expected);
  }
}

INSTANTIATE_TEST_SUITE_P(ConsumeParallelTests, ConsumeParallelTest,
                         // tuples represent <num_elements, num_consumers>
                         ::testing::Values(std::make_tuple(1000, 5),
                                           std::make_tuple(1003, 5),
                                           std::make_tuple(1000, 20),
                                           std::make_tuple(4, 20),
                                           std::make_tuple(0, 20)));

TEST(RoundRobinTaskRunner, ConsumeParallelPartialRound) {
  int64 num_elements = 20;
  int64 num_consumers = 5;
  std::vector<int64> starting_rounds = {12, 11, 11, 12, 12};
  int64 min_starting_round = 11;
  std::vector<std::vector<Tensor>> elements;
  for (int64 i = 0; i < num_elements; ++i) {
    std::vector<Tensor> element;
    element.push_back(Tensor(i));
    elements.push_back(element);
  }
  RoundRobinTaskRunner runner(absl::make_unique<TestTaskIterator>(elements),
                              num_consumers);
  std::vector<std::vector<std::vector<Tensor>>> per_consumer_results;
  std::vector<std::unique_ptr<Thread>> consumers;
  mutex mu;
  Status error;
  for (int consumer = 0; consumer < num_consumers; ++consumer) {
    mutex_lock l(mu);
    per_consumer_results.emplace_back();
    consumers.push_back(absl::WrapUnique(Env::Default()->StartThread(
        {}, absl::StrCat("consumer_", consumer), [&, consumer] {
          std::vector<std::vector<Tensor>> results;
          Status s =
              RunConsumer(consumer, starting_rounds[consumer], runner, results);
          mutex_lock l(mu);
          if (!s.ok()) {
            error = s;
            return;
          }
          per_consumer_results[consumer] = std::move(results);
        })));
  }
  // Wait for all consumers to finish.
  consumers.clear();
  mutex_lock l(mu);
  TF_ASSERT_OK(error);
  for (int consumer = 0; consumer < num_consumers; ++consumer) {
    auto& results = per_consumer_results[consumer];
    int start = consumer;
    int expected_elements = num_elements / num_consumers;
    if (starting_rounds[consumer] != min_starting_round) {
      start += num_consumers;
      expected_elements--;
    }
    ASSERT_EQ(results.size(), expected_elements);
    int index = 0;
    for (int i = start; i < num_elements; i += num_consumers) {
      Tensor expected = elements[i][0];
      test::ExpectEqual(results[index++][0], expected);
    }
  }
}
}  // namespace data
}  // namespace tensorflow
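To make the partial-round test concrete: with num_consumers = 5 and starting_rounds = {12, 11, 11, 12, 12}, the minimum starting round is 11, so the runner first serves a partial round 11 in which only consumers 1 and 2 receive data (elements 1 and 2 of the dataset; elements 0, 3, and 4 of that block are never delivered). From round 12 onward all five consumers are in sync, reading the block of elements 5 through 9, then 10 through 14, and so on, which matches the `start += num_consumers; expected_elements--;` bookkeeping the test asserts for the consumers that started at round 12.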
tensorflow/core/data/service/worker.proto
@@ -14,6 +14,15 @@ message ProcessTaskResponse {}

message GetElementRequest {
  // The task to fetch an element from.
  int64 task_id = 1;
  // Optional index to identify the consumer.
  oneof optional_consumer_index {
    int64 consumer_index = 2;
  }
  // Optional round index, indicating which round of round-robin the consumer
  // wants to read from. This is used to keep consumers in sync.
  oneof optional_round_index {
    int64 round_index = 3;
  }
}

message GetElementResponse {
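For reference, a client populating the new oneof fields might look like the following. This is an illustrative sketch, not part of the commit; it assumes the standard generated C++ proto API for worker.proto and the tensorflow.data package used by the other tf.data service protos.

#include <cstdint>

#include "tensorflow/core/data/service/worker.pb.h"

// Illustrative sketch: build a round-robin GetElementRequest. Setting
// consumer_index/round_index fills the corresponding oneof fields; leaving
// them unset keeps the old first-come first-served behavior on the worker.
tensorflow::data::GetElementRequest MakeRoundRobinRequest(
    int64_t task_id, int64_t consumer_index, int64_t round_index) {
  tensorflow::data::GetElementRequest req;
  req.set_task_id(task_id);
  req.set_consumer_index(consumer_index);
  req.set_round_index(round_index);
  return req;
}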
tensorflow/core/data/service/worker_impl.cc
@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/data/service/dispatcher.pb.h"
#include "tensorflow/core/data/service/grpc_util.h"
#include "tensorflow/core/data/service/split_provider.h"
#include "tensorflow/core/data/service/task_runner.h"
#include "tensorflow/core/data/service/utils.h"
#include "tensorflow/core/data/standalone.h"
#include "tensorflow/core/framework/tensor.pb.h"

@@ -123,11 +124,13 @@ Status DataServiceWorkerImpl::EnsureTaskInitialized(
    return Status::OK();
  }
  standalone::Dataset::Params params;
  std::unique_ptr<standalone::Dataset> dataset;
  std::unique_ptr<standalone::Iterator> iterator;

  switch (task.task_def.dataset_case()) {
    case TaskDef::kDatasetDef:
      TF_RETURN_IF_ERROR(standalone::Dataset::FromGraph(
          params, task.task_def.dataset_def().graph(), &task.dataset));
          params, task.task_def.dataset_def().graph(), &dataset));
      break;
    case TaskDef::kPath: {
      DatasetDef def;

@@ -139,7 +142,7 @@ Status DataServiceWorkerImpl::EnsureTaskInitialized(
          dispatcher_->GetDatasetDef(task.task_def.dataset_id(), def));
      }
      TF_RETURN_IF_ERROR(
          standalone::Dataset::FromGraph(params, def.graph(), &task.dataset));
          standalone::Dataset::FromGraph(params, def.graph(), &dataset));
      break;
    }
    case TaskDef::DATASET_NOT_SET:

@@ -151,17 +154,22 @@ Status DataServiceWorkerImpl::EnsureTaskInitialized(
      auto split_provider = absl::make_unique<DataServiceSplitProvider>(
          config_.dispatcher_address(), config_.protocol(),
          task.task_def.job_id(), config_.dispatcher_timeout_ms());
      TF_RETURN_IF_ERROR(task.dataset->MakeIterator(std::move(split_provider),
                                                    &task.iterator));
      TF_RETURN_IF_ERROR(
          dataset->MakeIterator(std::move(split_provider), &iterator));
      break;
    }
    case PARALLEL_EPOCHS:
      TF_RETURN_IF_ERROR(task.dataset->MakeIterator(&task.iterator));
      TF_RETURN_IF_ERROR(dataset->MakeIterator(&iterator));
      break;
    default:
      return errors::InvalidArgument("Unrecognized processing mode: ",
                                     task.task_def.processing_mode());
  }
  auto task_iterator = absl::make_unique<StandaloneTaskIterator>(
      std::move(dataset), std::move(iterator));
  TF_RETURN_IF_ERROR(TaskRunner::Create(task.task_def, std::move(task_iterator),
                                        task.task_runner));

  task.initialized = true;
  VLOG(3) << "Created iterator for task " << task.task_def.task_id();
  return Status::OK();

@@ -182,16 +190,25 @@ Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request,
        "Worker has not yet registered with dispatcher.");
  }
  auto it = tasks_.find(request->task_id());
  if (it == tasks_.end() || it->second->finished) {
  if (it == tasks_.end()) {
    response->set_end_of_sequence(true);
    return Status::OK();
  }
  auto& task = it->second;
  TF_RETURN_IF_ERROR(EnsureTaskInitialized(*task));
  TF_RETURN_IF_ERROR(task->iterator->GetNext(&outputs, &end_of_sequence));
  TaskRunner::Request get_next_request;
  if (request->optional_consumer_index_case() ==
      GetElementRequest::kConsumerIndex) {
    get_next_request.consumer_index = request->consumer_index();
  }
  if (request->optional_round_index_case() ==
      GetElementRequest::kRoundIndex) {
    get_next_request.round_index = request->round_index();
  }
  TF_RETURN_IF_ERROR(
      task->task_runner->GetNext(get_next_request, outputs, end_of_sequence));
  if (end_of_sequence) {
    VLOG(3) << "Reached end_of_sequence for task " << request->task_id();
    task->finished = true;
    pending_completed_tasks_.insert(request->task_id());
    task_completion_cv_.notify_one();
  }
tensorflow/core/data/service/worker_impl.h
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/data/service/data_service.h"
#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
#include "tensorflow/core/data/service/task_runner.h"
#include "tensorflow/core/data/service/worker.pb.h"
#include "tensorflow/core/data/standalone.h"
#include "tensorflow/core/lib/core/status.h"

@@ -60,11 +61,7 @@ class DataServiceWorkerImpl {
    TaskDef task_def;
    mutex mu;
    bool initialized TF_GUARDED_BY(mu) = false;
    bool finished = false;
    // TODO(aaudibert): Have standalone::Iterator own a reference to
    // standalone::Dataset so that we don't need to store the dataset here.
    std::unique_ptr<standalone::Dataset> dataset;
    std::unique_ptr<standalone::Iterator> iterator;
    std::unique_ptr<TaskRunner> task_runner;
  };

  // Sends task status to the dispatcher and checks for dispatcher commands.