[tf.data service] Skip round robin rounds when data isn't ready.

PiperOrigin-RevId: 356348715
Change-Id: I8ac227a098d49bd8a3fd6c96b93ac855df80a121
This commit is contained in:
Andrew Audibert 2021-02-08 14:05:34 -08:00 committed by TensorFlower Gardener
parent 0092ebe4c7
commit 2cc0ab3c0c
12 changed files with 404 additions and 256 deletions

View File

@ -57,6 +57,7 @@ tf_proto_library(
],
visibility = [
":data_transfer_visibility",
"//tensorflow:internal",
],
)
@ -95,6 +96,7 @@ cc_library(
":dispatcher_cc_grpc_proto",
":grpc_util",
":worker_cc_grpc_proto",
":worker_proto_cc",
"//tensorflow/core:framework",
"//tensorflow/core/platform:errors",
"@com_google_absl//absl/container:flat_hash_set",
@ -415,6 +417,7 @@ cc_library(
hdrs = ["task_runner.h"],
deps = [
":common_proto_cc",
":worker_proto_cc",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/data:compression_utils",
@ -428,11 +431,14 @@ tf_cc_test(
srcs = ["task_runner_test.cc"],
deps = [
":task_runner",
":worker_proto_cc",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/data:compression_utils",
"//tensorflow/core/data:dataset_proto_cc",
"@com_google_absl//absl/memory",
],
)

View File

@ -248,25 +248,14 @@ class GrpcDataTransferClient : public DataTransferClient {
stub_ = WorkerService::NewStub(channel);
}
Status GetElement(int64 task_id, absl::optional<int64> consumer_index,
absl::optional<int64> round_index,
CompressedElement& element,
bool& end_of_sequence) override {
Status GetElement(const GetElementRequest& req,
GetElementResponse& resp) override {
{
mutex_lock l(mu_);
if (cancelled_) {
return errors::Cancelled("Client was cancelled.");
}
}
GetElementRequest req;
req.set_task_id(task_id);
if (consumer_index.has_value()) {
req.set_consumer_index(consumer_index.value());
}
if (round_index.has_value()) {
req.set_round_index(round_index.value());
}
GetElementResponse resp;
grpc::ClientContext ctx;
{
mutex_lock l(mu_);
@ -280,10 +269,6 @@ class GrpcDataTransferClient : public DataTransferClient {
if (!s.ok()) {
return grpc_util::WrapError("Failed to get element", s);
}
end_of_sequence = resp.end_of_sequence();
if (!end_of_sequence) {
element = std::move(*resp.mutable_compressed_element());
}
return Status::OK();
}
@ -324,14 +309,10 @@ class GrpcTransferClientRegistrar {
};
static GrpcTransferClientRegistrar registrar;
Status DataServiceWorkerClient::GetElement(int64 task_id,
absl::optional<int64> consumer_index,
absl::optional<int64> round_index,
CompressedElement& element,
bool& end_of_sequence) {
Status DataServiceWorkerClient::GetElement(const GetElementRequest& req,
GetElementResponse& resp) {
TF_RETURN_IF_ERROR(EnsureInitialized());
return client_->GetElement(task_id, consumer_index, round_index, element,
end_of_sequence);
return client_->GetElement(req, resp);
}
Status DataServiceWorkerClient::EnsureInitialized() {

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/data/service/data_transfer.h"
#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
#include "tensorflow/core/data/service/worker.grpc.pb.h"
#include "tensorflow/core/data/service/worker.pb.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op_kernel.h"
@ -144,14 +145,8 @@ class DataServiceWorkerClient : public DataServiceClientBase {
: DataServiceClientBase(address, protocol),
transfer_protocol_(transfer_protocol) {}
// Fetches the next element for the specified task_id. The optional
// `consumer_index` and `round_index` must be specified for tasks which use
// round-robin ordering. The element's compressed tensors will be stored in
// `element`. If no element is available, `end_of_sequence` will be `true`,
// and `element` will be left unchanged.
Status GetElement(int64 task_id, absl::optional<int64> consumer_index,
absl::optional<int64> round_index,
CompressedElement& element, bool& end_of_sequence);
// Fetches an element from the worker.
Status GetElement(const GetElementRequest& req, GetElementResponse& resp);
// Makes a best effort to cancel all outstanding calls in progress for the
// client, and causes further calls to return Cancelled status.

View File

@ -38,13 +38,9 @@ class DataTransferClient {
std::function<Status(Config, std::unique_ptr<DataTransferClient>*)>;
virtual ~DataTransferClient() = default;
// Fetches the next element for the specified task_id. The element's
// compressed tensors will be stored in `element`. If no element is available,
// `end_of_sequence` will be `true`, and `element` will be left unchanged.
virtual Status GetElement(int64 task_id, absl::optional<int64> consumer_index,
absl::optional<int64> round_index,
tensorflow::data::CompressedElement& element,
bool& end_of_sequence) = 0;
// Fetches the next element.
virtual Status GetElement(const GetElementRequest& req,
GetElementResponse& resp) = 0;
// Makes a best effort to cancel all outstanding calls in progress for the
// client, and causes further calls to return Cancelled status.

View File

@ -29,6 +29,43 @@ namespace {
// Unavailable error. This prevents the server from hanging on shutdown when
// some round-robin consumers exit earlier than others.
const int64 kTimeoutUs = 60 * 1000 * 1000; // 1 minute.
// Time to wait before skipping a round if data still isn't available.
const int64 kWaitBeforeSkipUs = 100 * 1000; // 100ms.
// Interprets `element` as a size-1 vector containing a CompressedElement, and
// moves the element into `resp`. Returns an error if `element` is of unexpected
// size, type, or shape.
Status MoveCompressedElement(std::vector<Tensor>&& element,
GetElementResponse& resp) {
if (element.size() != 1) {
return errors::FailedPrecondition(
"Expected dataset to produce a single scalar variant tensor, but the "
"dataset produced ",
element.size(), " outputs");
}
if (element[0].dtype() != DT_VARIANT) {
return errors::FailedPrecondition(
"Expected dataset to produce a single scalar variant tensor, but "
"the dataset produced a tensor with type ",
DataTypeString(element[0].dtype()));
}
if (!TensorShapeUtils::IsScalar(element[0].shape())) {
return errors::FailedPrecondition(
"Expected dataset to produce a single scalar variant tensor, but "
"the dataset produced a tensor with shape ",
element[0].shape());
}
Variant& variant = element[0].scalar<Variant>()();
CompressedElement* compressed = variant.get<CompressedElement>();
if (compressed == nullptr) {
return errors::FailedPrecondition(
"Expected dataset to produce a CompressedElement variant tensor, but "
"it produced ",
variant.TypeName());
}
*resp.mutable_compressed_element() = *compressed;
return Status::OK();
}
} // namespace
StandaloneTaskIterator::StandaloneTaskIterator(
@ -71,97 +108,211 @@ FirstComeFirstServedTaskRunner::FirstComeFirstServedTaskRunner(
std::unique_ptr<TaskIterator> iterator)
: iterator_(std::move(iterator)) {}
Status FirstComeFirstServedTaskRunner::GetNext(const Request& request,
std::vector<Tensor>& element,
bool& end_of_task) {
return iterator_->GetNext(element, end_of_task);
Status FirstComeFirstServedTaskRunner::GetNext(const GetElementRequest& req,
GetElementResponse& resp) {
std::vector<Tensor> element;
bool end_of_task;
resp.set_skip_task(false);
TF_RETURN_IF_ERROR(iterator_->GetNext(element, end_of_task));
resp.set_end_of_sequence(end_of_task);
if (!end_of_task) {
return MoveCompressedElement(std::move(element), resp);
}
return Status::OK();
}
RoundRobinTaskRunner::RoundRobinTaskRunner(
std::unique_ptr<TaskIterator> iterator, int64 num_consumers)
: num_consumers_(num_consumers),
iterator_(std::move(iterator)),
buffer_(num_consumers_) {
buffer_(num_consumers_),
prefetch_thread_(std::move(iterator), num_consumers_) {
VLOG(1) << "Creating task runner for distributing data round-robin to "
<< num_consumers << " consumers";
}
Status RoundRobinTaskRunner::GetNext(const Request& request,
std::vector<Tensor>& element,
bool& end_of_task) {
if (request.consumer_index < 0 || request.round_index < 0) {
Status RoundRobinTaskRunner::ValidateRequest(const GetElementRequest& req) {
if (req.consumer_index() < 0 || req.round_index() < 0) {
return errors::FailedPrecondition(
"RoundRobinTaskRunner needs to know the consumer index and element "
"index of each request.");
}
if (request.consumer_index >= num_consumers_) {
if (req.consumer_index() >= num_consumers_) {
return errors::FailedPrecondition(
"Requesting data for consumer index ", request.consumer_index,
"Requesting data for consumer index ", req.consumer_index(),
", but the task is configured for only ", num_consumers_, " consumers");
}
VLOG(2) << "Received request from consumer index " << request.consumer_index
<< " for round " << request.round_index;
{
mutex_lock l(mu_);
absl::flat_hash_set<int64>& round = requests_[request.round_index];
first_round_ = std::min(first_round_, request.round_index);
round.insert(request.consumer_index);
if (current_round_ < request.round_index &&
round.size() == num_consumers_) {
VLOG(1) << "Starting normal round with round index "
<< request.round_index;
// This was the last request to arrive, time to start a new round.
TF_RETURN_IF_ERROR(FillBuffer());
VLOG(1) << "Finished preparing data for round " << request.round_index;
current_round_ = request.round_index;
new_round_cv_.notify_all();
}
if (current_round_ < 0 &&
requests_[first_round_].size() + requests_[first_round_ + 1].size() ==
num_consumers_) {
VLOG(1) << "Starting partial round for " << requests_[first_round_].size()
<< " consumers";
// Indicates that we need a partial round to get consumers back in sync.
TF_RETURN_IF_ERROR(FillBuffer());
current_round_ = first_round_;
new_round_cv_.notify_all();
}
while (current_round_ < request.round_index) {
std::cv_status s =
new_round_cv_.wait_for(l, std::chrono::microseconds(kTimeoutUs));
if (s == std::cv_status::timeout) {
// Clients will retry Unavailable.
return errors::Unavailable(
"Timeout waiting for other round-robin consumers to be ready.");
}
}
end_of_task = end_of_task_;
}
if (!end_of_task) {
element.clear();
tf_shared_lock l(mu_);
for (auto& component : buffer_[request.consumer_index]) {
element.push_back(tensor::DeepCopy(component));
}
}
VLOG(2) << "Returning to consumer " << request.consumer_index << " for round "
<< request.round_index;
return Status::OK();
}
Status RoundRobinTaskRunner::FillBuffer() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
for (int i = 0; i < num_consumers_; ++i) {
buffer_[i].clear();
bool end_of_sequence;
TF_RETURN_IF_ERROR(iterator_->GetNext(buffer_[i], end_of_sequence));
if (end_of_sequence) {
return errors::FailedPrecondition(
"Encountered end of sequence on a round-robin read iterator. Please "
"ensure that the dataset used for round-robin reading has infinite "
"cardinality, e.g. by adding a .repeat() transformation at the end.");
Status RoundRobinTaskRunner::PrepareFullRound(int64 wait_us)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
VLOG(1) << "Preparing full round for index " << current_round_;
// This was the last request to arrive, time to start a new round.
TF_RETURN_IF_ERROR(prefetch_thread_.FillBuffer(wait_us, buffer_));
round_skipped_ = buffer_.empty();
new_round_cv_.notify_all();
return Status::OK();
}
Status RoundRobinTaskRunner::PreparePartialRound()
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
VLOG(1) << "Starting partial round for " << requests_[first_round_].size()
<< " consumers";
current_round_ = first_round_;
new_round_cv_.notify_all();
// Indicates that we need a partial round to get consumers back in sync.
auto next_round_request = *(requests_[first_round_ + 1].begin());
if (next_round_request->skipped_previous_round()) {
VLOG(1) << "Skipping partial round";
round_skipped_ = true;
return Status::OK();
}
TF_RETURN_IF_ERROR(prefetch_thread_.FillBuffer(/*wait_us=*/-1, buffer_));
round_skipped_ = false;
return Status::OK();
}
Status RoundRobinTaskRunner::PrepareRound(const GetElementRequest& req) {
mutex_lock l(mu_);
absl::flat_hash_set<const GetElementRequest*>& round =
requests_[req.round_index()];
first_round_ = std::min(first_round_, req.round_index());
round.insert(&req);
if (current_round_ < req.round_index() && round.size() == num_consumers_) {
current_round_ = req.round_index();
int64 wait_us = kWaitBeforeSkipUs;
if (!req.allow_skip()) {
wait_us = -1;
}
TF_RETURN_IF_ERROR(PrepareFullRound(wait_us));
}
if (current_round_ < 0 &&
requests_[first_round_].size() + requests_[first_round_ + 1].size() ==
num_consumers_) {
TF_RETURN_IF_ERROR(PreparePartialRound());
}
while (current_round_ < req.round_index()) {
TF_RETURN_IF_ERROR(prefetch_thread_.GetStatus());
std::cv_status s =
new_round_cv_.wait_for(l, std::chrono::microseconds(kTimeoutUs));
if (s == std::cv_status::timeout) {
// Clients will retry Unavailable.
return errors::Unavailable(
"Timeout waiting for other round-robin consumers to be ready.");
}
}
return prefetch_thread_.GetStatus();
}
Status RoundRobinTaskRunner::GetNext(const GetElementRequest& req,
GetElementResponse& resp) {
TF_RETURN_IF_ERROR(ValidateRequest(req));
resp.set_end_of_sequence(false);
VLOG(2) << "Received request from consumer index " << req.consumer_index()
<< " for round " << req.round_index();
TF_RETURN_IF_ERROR(PrepareRound(req));
tf_shared_lock l(mu_);
resp.set_skip_task(round_skipped_);
if (round_skipped_) {
VLOG(1) << "Buffer not ready, skipping round " << current_round_
<< " for consumer " << req.consumer_index();
return Status::OK();
}
std::vector<Tensor> element;
for (auto& component : buffer_[req.consumer_index()]) {
element.push_back(tensor::DeepCopy(component));
}
if (VLOG_IS_ON(2)) {
int64 size = 0;
for (auto& component : element) {
size += component.TotalBytes();
}
VLOG(2) << "Returning to consumer " << req.consumer_index() << " for round "
<< req.round_index() << ". element size " << size;
}
return MoveCompressedElement(std::move(element), resp);
}
PrefetchThread::PrefetchThread(std::unique_ptr<TaskIterator> iterator,
int64 round_size)
: iterator_(std::move(iterator)), round_size_(round_size) {
thread_ = absl::WrapUnique(
Env::Default()->StartThread({}, "round-robin-prefetch", [&] { Run(); }));
}
PrefetchThread::~PrefetchThread() {
mutex_lock l(mu_);
cancelled_ = true;
cv_.notify_all();
}
void PrefetchThread::Run() {
while (true) {
{
mutex_lock l(mu_);
while (!cancelled_ && buffer_.size() >= round_size_) {
cv_.wait(l);
}
if (cancelled_) {
return;
}
}
std::vector<Tensor> element;
bool end_of_sequence;
Status s = iterator_->GetNext(element, end_of_sequence);
if (!s.ok()) {
mutex_lock l(mu_);
status_ = s;
cv_.notify_all();
return;
}
if (end_of_sequence) {
mutex_lock l(mu_);
status_ = errors::FailedPrecondition(
"Encountered end of sequence on a round-robin read iterator. "
"Please ensure that the dataset used for round-robin reading has "
"infinite cardinality, e.g. by adding a .repeat() transformation "
"at the end.");
cv_.notify_all();
return;
}
mutex_lock l(mu_);
buffer_.push_back(std::move(element));
cv_.notify_all();
}
}
Status PrefetchThread::FillBuffer(int64 wait_us,
std::vector<std::vector<Tensor>>& out) {
int64 start_us = Env::Default()->NowMicros();
out.clear();
mutex_lock l(mu_);
while (buffer_.size() < round_size_ && !cancelled_ && status_.ok()) {
int64 remaining_us = start_us + wait_us - Env::Default()->NowMicros();
if (wait_us >= 0 && remaining_us <= 0) {
break;
}
cv_.wait_for(l, std::chrono::microseconds(remaining_us));
}
TF_RETURN_IF_ERROR(status_);
if (cancelled_) {
return errors::Cancelled("Prefetch thread cancelled");
}
if (buffer_.size() < round_size_) {
DCHECK_GE(wait_us, 0);
return Status::OK();
}
for (auto& elem : buffer_) {
out.push_back(std::move(elem));
}
buffer_.clear();
cv_.notify_all();
return Status::OK();
}
Status PrefetchThread::GetStatus() {
mutex_lock l(mu_);
return status_;
}
} // namespace data
} // namespace tensorflow

View File

@ -16,6 +16,7 @@ limitations under the License.
#define TENSORFLOW_CORE_DATA_SERVICE_TASK_RUNNER_H_
#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/data/service/worker.pb.h"
#include "tensorflow/core/data/standalone.h"
#include "tensorflow/core/platform/status.h"
@ -54,28 +55,14 @@ class StandaloneTaskIterator : public TaskIterator {
// Interface for providing elements to task consumers.
class TaskRunner {
public:
struct Request {
// Optional consumer index indicating which consumer is making the request.
// Only needed for round-robin reads.
int64 consumer_index = -1;
// Optional round index indicating which round the consumer wants to read
// from. Consumers are expected to read from consecutive rounds, starting
// with round 0. The task runner will attempt to serve all consumer
// requests for a round from the same block of `num_consumers` iterator
// indices, where block `n` is defined as elements `n*num_consumers` to
// `(n+1)*num_consumers`.
int64 round_index = -1;
};
// Creates a `TaskRunner` and stores it in `out`.
static Status Create(const TaskDef& task_def,
std::unique_ptr<TaskIterator> iterator,
std::unique_ptr<TaskRunner>& out);
virtual ~TaskRunner() = default;
// Gets the next element for the given request, storing the results in
// `element` and `end_of_task`.
virtual Status GetNext(const Request& request, std::vector<Tensor>& element,
bool& end_of_task) = 0;
// Gets the next element for the given request.
virtual Status GetNext(const GetElementRequest& req,
GetElementResponse& resp) = 0;
};
// A task runner which provides elements on a first-come first-served basis.
@ -84,20 +71,52 @@ class FirstComeFirstServedTaskRunner : public TaskRunner {
public:
explicit FirstComeFirstServedTaskRunner(
std::unique_ptr<TaskIterator> iterator);
Status GetNext(const Request& request, std::vector<Tensor>& element,
bool& end_of_task) override;
Status GetNext(const GetElementRequest& req,
GetElementResponse& resp) override;
private:
std::unique_ptr<TaskIterator> iterator_;
};
// Thread for prefetching a round worth of elements.
class PrefetchThread {
public:
explicit PrefetchThread(std::unique_ptr<TaskIterator> iterator,
int64 round_size);
~PrefetchThread();
// Runs the prefetch thread. It runs until an error is encountered or the
// destructor is called.
void Run();
// Fills `out` with a round of data. Waits for up to `wait_us` micoseconds
// before giving up and returning with `out` empty. A negative `wait_us`
// signals to wait indefinitely.
Status FillBuffer(int64 wait_us, std::vector<std::vector<Tensor>>& out);
// Returns the status for any failures encountered by the prefetch thread.
Status GetStatus();
private:
const std::unique_ptr<TaskIterator> iterator_;
const int64 round_size_;
mutex mu_;
// Buffered results for the next round.
std::vector<std::vector<Tensor>> buffer_ TF_GUARDED_BY(mu_);
// The status if the prefetch thread fails.
Status status_ TF_GUARDED_BY(mu_) = Status::OK();
// Thread which constantly tries to fill `buffer_` up with
// `num_consumers` elements.
std::unique_ptr<Thread> thread_;
// Condition variable notified when elements are added to or removed from
// `buffer_`, or when `status_` is changed.
condition_variable cv_;
bool cancelled_ TF_GUARDED_BY(mu_) = false;
};
// A task runner which enforces round-robin order for consuming a task's
// elements. Requests must provide a consumer index and element index.
// `RoundRobinTaskRunner` provides elements in a series of "rounds". In each
// successive round, the runner waits to receive requests from all consumers.
// These requests are blocked until all requests arrive. Once all requests
// arrive, the runner hands out elements to consumers in order of their consumer
// indices.
// elements. `RoundRobinTaskRunner` provides elements in a series of "rounds".
// In each successive round, the runner waits to receive requests from all
// consumers. These requests are blocked until all requests arrive. Once all
// requests arrive, the runner hands out elements to consumers in order of their
// consumer indices.
//
// Consumers are expected to successively request consecutive element indices,
// starting at 0. The same element can be requested multiple times by the same
@ -113,28 +132,37 @@ class RoundRobinTaskRunner : public TaskRunner {
public:
RoundRobinTaskRunner(std::unique_ptr<TaskIterator> iterator,
int64 num_consumers);
Status GetNext(const Request& request, std::vector<Tensor>& element,
bool& end_of_task) override;
Status GetNext(const GetElementRequest& req,
GetElementResponse& resp) override;
private:
// Fills `buffer_` with `num_consumers_` elements.
Status FillBuffer();
// Prepares a full round of data. `wait_us` indicates how long to wait before
// skipping if a full round of data is not yet ready.
Status PrepareFullRound(int64 wait_us) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Prepares a partial round to get consumers back in sync.
Status PreparePartialRound() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
Status ValidateRequest(const GetElementRequest& req);
// Prepares data for the next round, blocking until the round is ready to
// start.
Status PrepareRound(const GetElementRequest& req);
const int64 num_consumers_;
std::unique_ptr<TaskIterator> iterator_;
mutex mu_;
// Condition variable notified whenever we start a new round of round-robin.
condition_variable new_round_cv_;
// Map from round number to consumers waiting for data from that round.
absl::flat_hash_map<int64, absl::flat_hash_set<int64>> requests_
TF_GUARDED_BY(mu_);
// Map from round number to requests waiting for data from that round.
absl::flat_hash_map<int64, absl::flat_hash_set<const GetElementRequest*>>
requests_ TF_GUARDED_BY(mu_);
// Index of the first round we plan to serve. At startup, this is the minimum
// of all requested element indices.
int64 first_round_ TF_GUARDED_BY(mu_) = kint64max;
int64 current_round_ TF_GUARDED_BY(mu_) = -1;
bool round_skipped_ TF_GUARDED_BY(mu_) = false;
// Buffered results for the current round.
std::vector<std::vector<Tensor>> buffer_ TF_GUARDED_BY(mu_);
bool end_of_task_ TF_GUARDED_BY(mu_) = false;
// Thread which constantly tries to prepare `num_consumers` elements for the
// next round.
PrefetchThread prefetch_thread_;
};
} // namespace data

View File

@ -13,6 +13,9 @@ limitations under the License.
#include "tensorflow/core/data/service/task_runner.h"
#include "absl/memory/memory.h"
#include "tensorflow/core/data/compression_utils.h"
#include "tensorflow/core/data/dataset.pb.h"
#include "tensorflow/core/data/service/worker.pb.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@ -31,7 +34,10 @@ class TestTaskIterator : public TaskIterator {
Status GetNext(std::vector<Tensor>& element, bool& end_of_sequence) override {
end_of_sequence = index_ >= elements_.size();
if (!end_of_sequence) {
element = elements_[index_];
CompressedElement compressed;
TF_RETURN_IF_ERROR(CompressElement(elements_[index_], &compressed));
element.emplace_back(DT_VARIANT, TensorShape({}));
element[0].scalar<Variant>()() = std::move(compressed);
index_ = (index_ + 1) % elements_.size();
}
return Status::OK();
@ -47,16 +53,22 @@ class TestTaskIterator : public TaskIterator {
// Reads from the task runner, storing results in `*output`.
Status RunConsumer(int64 consumer_index, int64 start_index, int64 end_index,
TaskRunner& task_runner, std::vector<int64>& output) {
bool end_of_sequence = false;
for (int64 next_index = start_index; next_index < end_index; ++next_index) {
TaskRunner::Request request;
request.round_index = next_index;
request.consumer_index = consumer_index;
std::vector<Tensor> element;
TF_RETURN_IF_ERROR(task_runner.GetNext(request, element, end_of_sequence));
if (!end_of_sequence) {
output.push_back(element[0].flat<int64>()(0));
}
GetElementRequest request;
request.set_round_index(next_index);
request.set_consumer_index(consumer_index);
request.set_skipped_previous_round(false);
request.set_allow_skip(false);
GetElementResponse response;
do {
TF_RETURN_IF_ERROR(task_runner.GetNext(request, response));
if (!response.end_of_sequence()) {
std::vector<Tensor> uncompressed;
TF_RETURN_IF_ERROR(
UncompressElement(response.compressed_element(), &uncompressed));
output.push_back(uncompressed[0].flat<int64>()(0));
}
} while (response.skip_task());
}
return Status::OK();
}
@ -71,12 +83,13 @@ TEST(FirstComeFirstServedTaskRunner, GetNext) {
}
FirstComeFirstServedTaskRunner runner(
absl::make_unique<TestTaskIterator>(elements));
TaskRunner::Request request;
GetElementRequest request;
GetElementResponse response;
for (auto& expected_element : elements) {
TF_ASSERT_OK(runner.GetNext(request, response));
ASSERT_FALSE(response.end_of_sequence());
std::vector<Tensor> element;
bool end_of_sequence;
TF_ASSERT_OK(runner.GetNext(request, element, end_of_sequence));
ASSERT_FALSE(end_of_sequence);
TF_ASSERT_OK(UncompressElement(response.compressed_element(), &element));
ASSERT_EQ(element.size(), 1);
test::ExpectEqual(element[0], expected_element[0]);
}

View File

@ -23,6 +23,11 @@ message GetElementRequest {
oneof optional_round_index {
int64 round_index = 3;
}
// Whether the previous round was skipped. This information is needed by the
// worker to recover after restarts.
bool skipped_previous_round = 4;
// Whether to skip the round if data isn't ready fast enough.
bool allow_skip = 5;
}
message GetElementResponse {
@ -30,6 +35,8 @@ message GetElementResponse {
CompressedElement compressed_element = 3;
// Boolean to indicate whether the iterator has been exhausted.
bool end_of_sequence = 2;
// Indicates whether the round was skipped.
bool skip_task = 4;
}
// Named GetWorkerTasks to avoid conflicting with GetTasks in dispatcher.proto

View File

@ -180,8 +180,6 @@ Status DataServiceWorkerImpl::EnsureTaskInitialized(
Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request,
GetElementResponse* response) {
VLOG(3) << "Received GetElement request for task " << request->task_id();
bool end_of_sequence = false;
std::vector<tensorflow::Tensor> outputs;
Task* task;
{
mutex_lock l(mu_);
@ -207,54 +205,15 @@ Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request,
task = it->second.get();
TF_RETURN_IF_ERROR(EnsureTaskInitialized(*task));
}
TaskRunner::Request get_next_request;
if (request->optional_consumer_index_case() ==
GetElementRequest::kConsumerIndex) {
get_next_request.consumer_index = request->consumer_index();
}
if (request->optional_round_index_case() == GetElementRequest::kRoundIndex) {
get_next_request.round_index = request->round_index();
}
TF_RETURN_IF_ERROR(
task->task_runner->GetNext(get_next_request, outputs, end_of_sequence));
if (end_of_sequence) {
TF_RETURN_IF_ERROR(task->task_runner->GetNext(*request, *response));
if (response->end_of_sequence()) {
mutex_lock l(mu_);
VLOG(3) << "Reached end_of_sequence for task " << request->task_id();
pending_completed_tasks_.insert(request->task_id());
task_completion_cv_.notify_one();
}
if (!end_of_sequence) {
} else if (!response->skip_task()) {
VLOG(3) << "Producing an element for task " << request->task_id();
if (outputs.size() != 1) {
return errors::FailedPrecondition(
"Expected dataset to produce a single scalar variant tensor, but the "
"dataset produced ",
outputs.size(), " outputs");
}
if (outputs[0].dtype() != DT_VARIANT) {
return errors::FailedPrecondition(
"Expected dataset to produce a single scalar variant tensor, but "
"the dataset produced a tensor with type ",
DataTypeString(outputs[0].dtype()));
}
if (!TensorShapeUtils::IsScalar(outputs[0].shape())) {
return errors::FailedPrecondition(
"Expected dataset to produce a single scalar variant tensor, but "
"the dataset produced a tensor with shape ",
outputs[0].shape());
}
Variant& variant = outputs[0].scalar<Variant>()();
CompressedElement* compressed = variant.get<CompressedElement>();
if (compressed == nullptr) {
return errors::FailedPrecondition(
"Expected dataset to produce a CompressedElement variant tensor, but "
"it produced ",
variant.TypeName());
}
*response->mutable_compressed_element() = *compressed;
}
response->set_end_of_sequence(end_of_sequence);
return Status::OK();
}

View File

@ -168,6 +168,7 @@ tf_kernel_library(
"//tensorflow/core/data/service:data_service",
"//tensorflow/core/data/service:dispatcher_proto_cc",
"//tensorflow/core/data/service:grpc_util",
"//tensorflow/core/data/service:worker_proto_cc",
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
"//tensorflow/core/kernels/data:dataset_utils",
"//tensorflow/core/kernels/data:name_utils",

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/data/service/data_service.h"
#include "tensorflow/core/data/service/dispatcher.pb.h"
#include "tensorflow/core/data/service/grpc_util.h"
#include "tensorflow/core/data/service/worker.pb.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/model.h"
@ -295,29 +296,37 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
});
}
while ((results_.empty() || !results_.front().ready) &&
!(job_finished_ && num_running_worker_threads_ == 0) &&
!cancelled_ && status_.ok()) {
VLOG(3) << "Blocking in GetNext. results_.size():" << results_.size()
<< " results_.front().ready:"
<< (!results_.empty() && results_.front().ready)
<< " job_finished_:" << job_finished_
<< " num_running_worker_threads_:"
<< num_running_worker_threads_;
get_next_cv_.wait(l);
}
if (cancelled_) {
VLOG(3) << "Returning from GetNext due to cancellation";
return errors::Cancelled("Data service iterator was cancelled");
}
if (!status_.ok()) {
VLOG(3) << "Returning from GetNext with error " << status_;
return status_;
}
if (results_.empty()) {
*end_of_sequence = true;
VLOG(3) << "Returning from GetNext with end_of_sequence";
return Status::OK();
bool skip = true;
while (skip) {
while ((results_.empty() || !results_.front().ready) &&
!(job_finished_ && num_running_worker_threads_ == 0) &&
!cancelled_ && status_.ok()) {
VLOG(3) << "Blocking in GetNext. results_.size():" << results_.size()
<< " results_.front().ready:"
<< (!results_.empty() && results_.front().ready)
<< " job_finished_:" << job_finished_
<< " num_running_worker_threads_:"
<< num_running_worker_threads_;
get_next_cv_.wait(l);
}
if (cancelled_) {
VLOG(3) << "Returning from GetNext due to cancellation";
return errors::Cancelled("Data service iterator was cancelled");
}
if (!status_.ok()) {
VLOG(3) << "Returning from GetNext with error " << status_;
return status_;
}
if (results_.empty()) {
*end_of_sequence = true;
VLOG(3) << "Returning from GetNext with end_of_sequence";
return Status::OK();
}
skip = results_.front().skip;
if (skip) {
results_.pop();
worker_thread_cv_.notify_one();
}
}
*end_of_sequence = results_.front().end_of_sequence;
if (!*end_of_sequence) {
@ -378,6 +387,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
const std::unique_ptr<DataServiceWorkerClient> worker;
// The next round to read from the task.
int64 round = 0;
bool skipped_previous_round = false;
// Indicates whether a worker thread is currently processing the task.
bool in_use TF_GUARDED_BY(&Iterator::mu_) = false;
// Indicates whether the worker has returned end_of_sequence for the task.
@ -390,6 +400,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
bool ready TF_GUARDED_BY(&Iterator::mu_) = false;
std::vector<Tensor> element TF_GUARDED_BY(&Iterator::mu_);
bool end_of_sequence TF_GUARDED_BY(&Iterator::mu_) = false;
bool skip TF_GUARDED_BY(&Iterator::mu_) = false;
};
// Periodically refresh the task list.
@ -677,25 +688,24 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
return profiler::TraceMeEncode(
{{"address", task->info.worker_address()}});
});
CompressedElement compressed;
bool end_of_sequence;
GetElementResponse resp;
for (int num_retries = 0;; ++num_retries) {
absl::optional<int64> consumer_index = dataset()->consumer_index_;
absl::optional<int64> round_index;
GetElementRequest req;
if (StrictRoundRobin()) {
round_index = task->round;
req.set_consumer_index(dataset()->consumer_index_.value());
req.set_round_index(task->round);
req.set_allow_skip(true);
VLOG(3) << "Requesting element from consumer index "
<< consumer_index.value() << ", round "
<< round_index.value();
<< req.consumer_index() << ", round " << req.round_index();
activity.AppendMetadata([&]() {
return profiler::TraceMeEncode(
{{"consumer_index", consumer_index.value()},
{"round_index", round_index.value()}});
{{"consumer_index", req.consumer_index()},
{"round_index", req.round_index()}});
});
}
Status s =
task->worker->GetElement(task->info.task_id(), consumer_index,
round_index, compressed, end_of_sequence);
req.set_task_id(task->info.task_id());
req.set_skipped_previous_round(task->skipped_previous_round);
Status s = task->worker->GetElement(req, resp);
if (s.ok()) {
break;
}
@ -709,7 +719,6 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
// If `UpdateTaskThreads` finds that the task has been cancelled, it
// will set end_of_sequence to `true`.
if (task->end_of_sequence || cancelled_) {
end_of_sequence = true;
break;
}
}
@ -734,21 +743,27 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
}
std::vector<Tensor> element;
if (!end_of_sequence) {
if (resp.has_compressed_element()) {
Tensor tensor(DT_VARIANT, TensorShape{});
tensor.scalar<Variant>()() = std::move(compressed);
tensor.scalar<Variant>()() = std::move(resp.compressed_element());
element.push_back(tensor);
}
mutex_lock l(mu_);
result.ready = true;
result.end_of_sequence = end_of_sequence;
if (end_of_sequence) {
result.end_of_sequence = resp.end_of_sequence();
if (resp.has_compressed_element()) {
task->skipped_previous_round = false;
task->round++;
result.element = std::move(element);
} else if (resp.skip_task()) {
task->skipped_previous_round = true;
task->round++;
result.skip = true;
} else {
task->end_of_sequence = true;
finished_tasks_++;
return Status::OK();
}
result.element = std::move(element);
if (enqueue_result) {
if (enqueue_result && !resp.end_of_sequence()) {
results_.push(std::move(result));
}
get_next_cv_.notify_all();

View File

@ -291,8 +291,7 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
cluster = self.create_cluster(num_workers=num_workers)
# Round robin reads can cause slow cluster shutdown.
data_service_test_base.GLOBAL_CLUSTERS.add(cluster)
num_elements = 100
ds = dataset_ops.Dataset.range(num_elements)
ds = dataset_ops.Dataset.range(10000000)
ds = ds.repeat()
consumers = []
for consumer_index in range(num_consumers):
@ -309,18 +308,15 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
lambda x: x,
cycle_length=num_consumers,
num_parallel_calls=num_consumers)
ds = ds.take(2 * num_elements * num_workers)
ds = ds.take(1000)
results = self.getDatasetOutput(ds, requires_initialization=True)
expected = []
round_index = 0
while len(expected) < len(results):
for _ in range(num_workers):
for consumer in range(num_consumers):
expected.append(
(round_index * num_consumers + consumer) % num_elements)
round_index += 1
self.assertEqual(results, expected)
for i in range(0, len(results), num_consumers):
self.assertEqual(0, results[i] % num_consumers)
# Check that each group of `num_consumers` results are consecutive.
for offset in range(1, num_consumers):
if i + offset < len(results):
self.assertEqual(results[i] + offset, results[i + offset])
@combinations.generate(test_base.default_test_combinations())
def testRoundRobinBucketizing(self):