[tf.data service] Skip round robin rounds when data isn't ready.

PiperOrigin-RevId: 356348715
Change-Id: I8ac227a098d49bd8a3fd6c96b93ac855df80a121
This commit is contained in:
Andrew Audibert 2021-02-08 14:05:34 -08:00 committed by TensorFlower Gardener
parent 0092ebe4c7
commit 2cc0ab3c0c
12 changed files with 404 additions and 256 deletions

View File

@ -57,6 +57,7 @@ tf_proto_library(
], ],
visibility = [ visibility = [
":data_transfer_visibility", ":data_transfer_visibility",
"//tensorflow:internal",
], ],
) )
@ -95,6 +96,7 @@ cc_library(
":dispatcher_cc_grpc_proto", ":dispatcher_cc_grpc_proto",
":grpc_util", ":grpc_util",
":worker_cc_grpc_proto", ":worker_cc_grpc_proto",
":worker_proto_cc",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core/platform:errors", "//tensorflow/core/platform:errors",
"@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/container:flat_hash_set",
@ -415,6 +417,7 @@ cc_library(
hdrs = ["task_runner.h"], hdrs = ["task_runner.h"],
deps = [ deps = [
":common_proto_cc", ":common_proto_cc",
":worker_proto_cc",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core/data:compression_utils", "//tensorflow/core/data:compression_utils",
@ -428,11 +431,14 @@ tf_cc_test(
srcs = ["task_runner_test.cc"], srcs = ["task_runner_test.cc"],
deps = [ deps = [
":task_runner", ":task_runner",
":worker_proto_cc",
"//tensorflow/core:framework", "//tensorflow/core:framework",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:test", "//tensorflow/core:test",
"//tensorflow/core:test_main", "//tensorflow/core:test_main",
"//tensorflow/core:testlib", "//tensorflow/core:testlib",
"//tensorflow/core/data:compression_utils",
"//tensorflow/core/data:dataset_proto_cc",
"@com_google_absl//absl/memory", "@com_google_absl//absl/memory",
], ],
) )

View File

@ -248,25 +248,14 @@ class GrpcDataTransferClient : public DataTransferClient {
stub_ = WorkerService::NewStub(channel); stub_ = WorkerService::NewStub(channel);
} }
Status GetElement(int64 task_id, absl::optional<int64> consumer_index, Status GetElement(const GetElementRequest& req,
absl::optional<int64> round_index, GetElementResponse& resp) override {
CompressedElement& element,
bool& end_of_sequence) override {
{ {
mutex_lock l(mu_); mutex_lock l(mu_);
if (cancelled_) { if (cancelled_) {
return errors::Cancelled("Client was cancelled."); return errors::Cancelled("Client was cancelled.");
} }
} }
GetElementRequest req;
req.set_task_id(task_id);
if (consumer_index.has_value()) {
req.set_consumer_index(consumer_index.value());
}
if (round_index.has_value()) {
req.set_round_index(round_index.value());
}
GetElementResponse resp;
grpc::ClientContext ctx; grpc::ClientContext ctx;
{ {
mutex_lock l(mu_); mutex_lock l(mu_);
@ -280,10 +269,6 @@ class GrpcDataTransferClient : public DataTransferClient {
if (!s.ok()) { if (!s.ok()) {
return grpc_util::WrapError("Failed to get element", s); return grpc_util::WrapError("Failed to get element", s);
} }
end_of_sequence = resp.end_of_sequence();
if (!end_of_sequence) {
element = std::move(*resp.mutable_compressed_element());
}
return Status::OK(); return Status::OK();
} }
@ -324,14 +309,10 @@ class GrpcTransferClientRegistrar {
}; };
static GrpcTransferClientRegistrar registrar; static GrpcTransferClientRegistrar registrar;
Status DataServiceWorkerClient::GetElement(int64 task_id, Status DataServiceWorkerClient::GetElement(const GetElementRequest& req,
absl::optional<int64> consumer_index, GetElementResponse& resp) {
absl::optional<int64> round_index,
CompressedElement& element,
bool& end_of_sequence) {
TF_RETURN_IF_ERROR(EnsureInitialized()); TF_RETURN_IF_ERROR(EnsureInitialized());
return client_->GetElement(task_id, consumer_index, round_index, element, return client_->GetElement(req, resp);
end_of_sequence);
} }
Status DataServiceWorkerClient::EnsureInitialized() { Status DataServiceWorkerClient::EnsureInitialized() {

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/data/service/data_transfer.h" #include "tensorflow/core/data/service/data_transfer.h"
#include "tensorflow/core/data/service/dispatcher.grpc.pb.h" #include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
#include "tensorflow/core/data/service/worker.grpc.pb.h" #include "tensorflow/core/data/service/worker.grpc.pb.h"
#include "tensorflow/core/data/service/worker.pb.h"
#include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_kernel.h"
@ -144,14 +145,8 @@ class DataServiceWorkerClient : public DataServiceClientBase {
: DataServiceClientBase(address, protocol), : DataServiceClientBase(address, protocol),
transfer_protocol_(transfer_protocol) {} transfer_protocol_(transfer_protocol) {}
// Fetches the next element for the specified task_id. The optional // Fetches an element from the worker.
// `consumer_index` and `round_index` must be specified for tasks which use Status GetElement(const GetElementRequest& req, GetElementResponse& resp);
// round-robin ordering. The element's compressed tensors will be stored in
// `element`. If no element is available, `end_of_sequence` will be `true`,
// and `element` will be left unchanged.
Status GetElement(int64 task_id, absl::optional<int64> consumer_index,
absl::optional<int64> round_index,
CompressedElement& element, bool& end_of_sequence);
// Makes a best effort to cancel all outstanding calls in progress for the // Makes a best effort to cancel all outstanding calls in progress for the
// client, and causes further calls to return Cancelled status. // client, and causes further calls to return Cancelled status.

View File

@ -38,13 +38,9 @@ class DataTransferClient {
std::function<Status(Config, std::unique_ptr<DataTransferClient>*)>; std::function<Status(Config, std::unique_ptr<DataTransferClient>*)>;
virtual ~DataTransferClient() = default; virtual ~DataTransferClient() = default;
// Fetches the next element for the specified task_id. The element's // Fetches the next element.
// compressed tensors will be stored in `element`. If no element is available, virtual Status GetElement(const GetElementRequest& req,
// `end_of_sequence` will be `true`, and `element` will be left unchanged. GetElementResponse& resp) = 0;
virtual Status GetElement(int64 task_id, absl::optional<int64> consumer_index,
absl::optional<int64> round_index,
tensorflow::data::CompressedElement& element,
bool& end_of_sequence) = 0;
// Makes a best effort to cancel all outstanding calls in progress for the // Makes a best effort to cancel all outstanding calls in progress for the
// client, and causes further calls to return Cancelled status. // client, and causes further calls to return Cancelled status.

View File

@ -29,6 +29,43 @@ namespace {
// Unavailable error. This prevents the server from hanging on shutdown when // Unavailable error. This prevents the server from hanging on shutdown when
// some round-robin consumers exit earlier than others. // some round-robin consumers exit earlier than others.
const int64 kTimeoutUs = 60 * 1000 * 1000; // 1 minute. const int64 kTimeoutUs = 60 * 1000 * 1000; // 1 minute.
// Time to wait before skipping a round if data still isn't available.
const int64 kWaitBeforeSkipUs = 100 * 1000; // 100ms.
// Interprets `element` as a size-1 vector containing a CompressedElement, and
// moves the element into `resp`. Returns an error if `element` is of unexpected
// size, type, or shape.
Status MoveCompressedElement(std::vector<Tensor>&& element,
GetElementResponse& resp) {
if (element.size() != 1) {
return errors::FailedPrecondition(
"Expected dataset to produce a single scalar variant tensor, but the "
"dataset produced ",
element.size(), " outputs");
}
if (element[0].dtype() != DT_VARIANT) {
return errors::FailedPrecondition(
"Expected dataset to produce a single scalar variant tensor, but "
"the dataset produced a tensor with type ",
DataTypeString(element[0].dtype()));
}
if (!TensorShapeUtils::IsScalar(element[0].shape())) {
return errors::FailedPrecondition(
"Expected dataset to produce a single scalar variant tensor, but "
"the dataset produced a tensor with shape ",
element[0].shape());
}
Variant& variant = element[0].scalar<Variant>()();
CompressedElement* compressed = variant.get<CompressedElement>();
if (compressed == nullptr) {
return errors::FailedPrecondition(
"Expected dataset to produce a CompressedElement variant tensor, but "
"it produced ",
variant.TypeName());
}
*resp.mutable_compressed_element() = *compressed;
return Status::OK();
}
} // namespace } // namespace
StandaloneTaskIterator::StandaloneTaskIterator( StandaloneTaskIterator::StandaloneTaskIterator(
@ -71,62 +108,91 @@ FirstComeFirstServedTaskRunner::FirstComeFirstServedTaskRunner(
std::unique_ptr<TaskIterator> iterator) std::unique_ptr<TaskIterator> iterator)
: iterator_(std::move(iterator)) {} : iterator_(std::move(iterator)) {}
Status FirstComeFirstServedTaskRunner::GetNext(const Request& request, Status FirstComeFirstServedTaskRunner::GetNext(const GetElementRequest& req,
std::vector<Tensor>& element, GetElementResponse& resp) {
bool& end_of_task) { std::vector<Tensor> element;
return iterator_->GetNext(element, end_of_task); bool end_of_task;
resp.set_skip_task(false);
TF_RETURN_IF_ERROR(iterator_->GetNext(element, end_of_task));
resp.set_end_of_sequence(end_of_task);
if (!end_of_task) {
return MoveCompressedElement(std::move(element), resp);
}
return Status::OK();
} }
RoundRobinTaskRunner::RoundRobinTaskRunner( RoundRobinTaskRunner::RoundRobinTaskRunner(
std::unique_ptr<TaskIterator> iterator, int64 num_consumers) std::unique_ptr<TaskIterator> iterator, int64 num_consumers)
: num_consumers_(num_consumers), : num_consumers_(num_consumers),
iterator_(std::move(iterator)), buffer_(num_consumers_),
buffer_(num_consumers_) { prefetch_thread_(std::move(iterator), num_consumers_) {
VLOG(1) << "Creating task runner for distributing data round-robin to " VLOG(1) << "Creating task runner for distributing data round-robin to "
<< num_consumers << " consumers"; << num_consumers << " consumers";
} }
Status RoundRobinTaskRunner::GetNext(const Request& request, Status RoundRobinTaskRunner::ValidateRequest(const GetElementRequest& req) {
std::vector<Tensor>& element, if (req.consumer_index() < 0 || req.round_index() < 0) {
bool& end_of_task) {
if (request.consumer_index < 0 || request.round_index < 0) {
return errors::FailedPrecondition( return errors::FailedPrecondition(
"RoundRobinTaskRunner needs to know the consumer index and element " "RoundRobinTaskRunner needs to know the consumer index and element "
"index of each request."); "index of each request.");
} }
if (request.consumer_index >= num_consumers_) { if (req.consumer_index() >= num_consumers_) {
return errors::FailedPrecondition( return errors::FailedPrecondition(
"Requesting data for consumer index ", request.consumer_index, "Requesting data for consumer index ", req.consumer_index(),
", but the task is configured for only ", num_consumers_, " consumers"); ", but the task is configured for only ", num_consumers_, " consumers");
} }
VLOG(2) << "Received request from consumer index " << request.consumer_index return Status::OK();
<< " for round " << request.round_index; }
{
mutex_lock l(mu_); Status RoundRobinTaskRunner::PrepareFullRound(int64 wait_us)
absl::flat_hash_set<int64>& round = requests_[request.round_index]; TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
first_round_ = std::min(first_round_, request.round_index); VLOG(1) << "Preparing full round for index " << current_round_;
round.insert(request.consumer_index);
if (current_round_ < request.round_index &&
round.size() == num_consumers_) {
VLOG(1) << "Starting normal round with round index "
<< request.round_index;
// This was the last request to arrive, time to start a new round. // This was the last request to arrive, time to start a new round.
TF_RETURN_IF_ERROR(FillBuffer()); TF_RETURN_IF_ERROR(prefetch_thread_.FillBuffer(wait_us, buffer_));
VLOG(1) << "Finished preparing data for round " << request.round_index; round_skipped_ = buffer_.empty();
current_round_ = request.round_index;
new_round_cv_.notify_all(); new_round_cv_.notify_all();
return Status::OK();
}
Status RoundRobinTaskRunner::PreparePartialRound()
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
VLOG(1) << "Starting partial round for " << requests_[first_round_].size()
<< " consumers";
current_round_ = first_round_;
new_round_cv_.notify_all();
// Indicates that we need a partial round to get consumers back in sync.
auto next_round_request = *(requests_[first_round_ + 1].begin());
if (next_round_request->skipped_previous_round()) {
VLOG(1) << "Skipping partial round";
round_skipped_ = true;
return Status::OK();
}
TF_RETURN_IF_ERROR(prefetch_thread_.FillBuffer(/*wait_us=*/-1, buffer_));
round_skipped_ = false;
return Status::OK();
}
Status RoundRobinTaskRunner::PrepareRound(const GetElementRequest& req) {
mutex_lock l(mu_);
absl::flat_hash_set<const GetElementRequest*>& round =
requests_[req.round_index()];
first_round_ = std::min(first_round_, req.round_index());
round.insert(&req);
if (current_round_ < req.round_index() && round.size() == num_consumers_) {
current_round_ = req.round_index();
int64 wait_us = kWaitBeforeSkipUs;
if (!req.allow_skip()) {
wait_us = -1;
}
TF_RETURN_IF_ERROR(PrepareFullRound(wait_us));
} }
if (current_round_ < 0 && if (current_round_ < 0 &&
requests_[first_round_].size() + requests_[first_round_ + 1].size() == requests_[first_round_].size() + requests_[first_round_ + 1].size() ==
num_consumers_) { num_consumers_) {
VLOG(1) << "Starting partial round for " << requests_[first_round_].size() TF_RETURN_IF_ERROR(PreparePartialRound());
<< " consumers";
// Indicates that we need a partial round to get consumers back in sync.
TF_RETURN_IF_ERROR(FillBuffer());
current_round_ = first_round_;
new_round_cv_.notify_all();
} }
while (current_round_ < request.round_index) { while (current_round_ < req.round_index()) {
TF_RETURN_IF_ERROR(prefetch_thread_.GetStatus());
std::cv_status s = std::cv_status s =
new_round_cv_.wait_for(l, std::chrono::microseconds(kTimeoutUs)); new_round_cv_.wait_for(l, std::chrono::microseconds(kTimeoutUs));
if (s == std::cv_status::timeout) { if (s == std::cv_status::timeout) {
@ -135,33 +201,118 @@ Status RoundRobinTaskRunner::GetNext(const Request& request,
"Timeout waiting for other round-robin consumers to be ready."); "Timeout waiting for other round-robin consumers to be ready.");
} }
} }
end_of_task = end_of_task_; return prefetch_thread_.GetStatus();
} }
if (!end_of_task) {
element.clear(); Status RoundRobinTaskRunner::GetNext(const GetElementRequest& req,
GetElementResponse& resp) {
TF_RETURN_IF_ERROR(ValidateRequest(req));
resp.set_end_of_sequence(false);
VLOG(2) << "Received request from consumer index " << req.consumer_index()
<< " for round " << req.round_index();
TF_RETURN_IF_ERROR(PrepareRound(req));
tf_shared_lock l(mu_); tf_shared_lock l(mu_);
for (auto& component : buffer_[request.consumer_index]) { resp.set_skip_task(round_skipped_);
if (round_skipped_) {
VLOG(1) << "Buffer not ready, skipping round " << current_round_
<< " for consumer " << req.consumer_index();
return Status::OK();
}
std::vector<Tensor> element;
for (auto& component : buffer_[req.consumer_index()]) {
element.push_back(tensor::DeepCopy(component)); element.push_back(tensor::DeepCopy(component));
} }
if (VLOG_IS_ON(2)) {
int64 size = 0;
for (auto& component : element) {
size += component.TotalBytes();
} }
VLOG(2) << "Returning to consumer " << request.consumer_index << " for round " VLOG(2) << "Returning to consumer " << req.consumer_index() << " for round "
<< request.round_index; << req.round_index() << ". element size " << size;
}
return MoveCompressedElement(std::move(element), resp);
}
PrefetchThread::PrefetchThread(std::unique_ptr<TaskIterator> iterator,
int64 round_size)
: iterator_(std::move(iterator)), round_size_(round_size) {
thread_ = absl::WrapUnique(
Env::Default()->StartThread({}, "round-robin-prefetch", [&] { Run(); }));
}
PrefetchThread::~PrefetchThread() {
mutex_lock l(mu_);
cancelled_ = true;
cv_.notify_all();
}
void PrefetchThread::Run() {
while (true) {
{
mutex_lock l(mu_);
while (!cancelled_ && buffer_.size() >= round_size_) {
cv_.wait(l);
}
if (cancelled_) {
return;
}
}
std::vector<Tensor> element;
bool end_of_sequence;
Status s = iterator_->GetNext(element, end_of_sequence);
if (!s.ok()) {
mutex_lock l(mu_);
status_ = s;
cv_.notify_all();
return;
}
if (end_of_sequence) {
mutex_lock l(mu_);
status_ = errors::FailedPrecondition(
"Encountered end of sequence on a round-robin read iterator. "
"Please ensure that the dataset used for round-robin reading has "
"infinite cardinality, e.g. by adding a .repeat() transformation "
"at the end.");
cv_.notify_all();
return;
}
mutex_lock l(mu_);
buffer_.push_back(std::move(element));
cv_.notify_all();
}
}
Status PrefetchThread::FillBuffer(int64 wait_us,
std::vector<std::vector<Tensor>>& out) {
int64 start_us = Env::Default()->NowMicros();
out.clear();
mutex_lock l(mu_);
while (buffer_.size() < round_size_ && !cancelled_ && status_.ok()) {
int64 remaining_us = start_us + wait_us - Env::Default()->NowMicros();
if (wait_us >= 0 && remaining_us <= 0) {
break;
}
cv_.wait_for(l, std::chrono::microseconds(remaining_us));
}
TF_RETURN_IF_ERROR(status_);
if (cancelled_) {
return errors::Cancelled("Prefetch thread cancelled");
}
if (buffer_.size() < round_size_) {
DCHECK_GE(wait_us, 0);
return Status::OK();
}
for (auto& elem : buffer_) {
out.push_back(std::move(elem));
}
buffer_.clear();
cv_.notify_all();
return Status::OK(); return Status::OK();
} }
Status RoundRobinTaskRunner::FillBuffer() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) { Status PrefetchThread::GetStatus() {
for (int i = 0; i < num_consumers_; ++i) { mutex_lock l(mu_);
buffer_[i].clear(); return status_;
bool end_of_sequence;
TF_RETURN_IF_ERROR(iterator_->GetNext(buffer_[i], end_of_sequence));
if (end_of_sequence) {
return errors::FailedPrecondition(
"Encountered end of sequence on a round-robin read iterator. Please "
"ensure that the dataset used for round-robin reading has infinite "
"cardinality, e.g. by adding a .repeat() transformation at the end.");
}
}
return Status::OK();
} }
} // namespace data } // namespace data
} // namespace tensorflow } // namespace tensorflow

View File

@ -16,6 +16,7 @@ limitations under the License.
#define TENSORFLOW_CORE_DATA_SERVICE_TASK_RUNNER_H_ #define TENSORFLOW_CORE_DATA_SERVICE_TASK_RUNNER_H_
#include "tensorflow/core/data/service/common.pb.h" #include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/data/service/worker.pb.h"
#include "tensorflow/core/data/standalone.h" #include "tensorflow/core/data/standalone.h"
#include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/status.h"
@ -54,28 +55,14 @@ class StandaloneTaskIterator : public TaskIterator {
// Interface for providing elements to task consumers. // Interface for providing elements to task consumers.
class TaskRunner { class TaskRunner {
public: public:
struct Request {
// Optional consumer index indicating which consumer is making the request.
// Only needed for round-robin reads.
int64 consumer_index = -1;
// Optional round index indicating which round the consumer wants to read
// from. Consumers are expected to read from consecutive rounds, starting
// with round 0. The task runner will attempt to serve all consumer
// requests for a round from the same block of `num_consumers` iterator
// indices, where block `n` is defined as elements `n*num_consumers` to
// `(n+1)*num_consumers`.
int64 round_index = -1;
};
// Creates a `TaskRunner` and stores it in `out`. // Creates a `TaskRunner` and stores it in `out`.
static Status Create(const TaskDef& task_def, static Status Create(const TaskDef& task_def,
std::unique_ptr<TaskIterator> iterator, std::unique_ptr<TaskIterator> iterator,
std::unique_ptr<TaskRunner>& out); std::unique_ptr<TaskRunner>& out);
virtual ~TaskRunner() = default; virtual ~TaskRunner() = default;
// Gets the next element for the given request, storing the results in // Gets the next element for the given request.
// `element` and `end_of_task`. virtual Status GetNext(const GetElementRequest& req,
virtual Status GetNext(const Request& request, std::vector<Tensor>& element, GetElementResponse& resp) = 0;
bool& end_of_task) = 0;
}; };
// A task runner which provides elements on a first-come first-served basis. // A task runner which provides elements on a first-come first-served basis.
@ -84,20 +71,52 @@ class FirstComeFirstServedTaskRunner : public TaskRunner {
public: public:
explicit FirstComeFirstServedTaskRunner( explicit FirstComeFirstServedTaskRunner(
std::unique_ptr<TaskIterator> iterator); std::unique_ptr<TaskIterator> iterator);
Status GetNext(const Request& request, std::vector<Tensor>& element, Status GetNext(const GetElementRequest& req,
bool& end_of_task) override; GetElementResponse& resp) override;
private: private:
std::unique_ptr<TaskIterator> iterator_; std::unique_ptr<TaskIterator> iterator_;
}; };
// Thread for prefetching a round worth of elements.
class PrefetchThread {
public:
explicit PrefetchThread(std::unique_ptr<TaskIterator> iterator,
int64 round_size);
~PrefetchThread();
// Runs the prefetch thread. It runs until an error is encountered or the
// destructor is called.
void Run();
// Fills `out` with a round of data. Waits for up to `wait_us` micoseconds
// before giving up and returning with `out` empty. A negative `wait_us`
// signals to wait indefinitely.
Status FillBuffer(int64 wait_us, std::vector<std::vector<Tensor>>& out);
// Returns the status for any failures encountered by the prefetch thread.
Status GetStatus();
private:
const std::unique_ptr<TaskIterator> iterator_;
const int64 round_size_;
mutex mu_;
// Buffered results for the next round.
std::vector<std::vector<Tensor>> buffer_ TF_GUARDED_BY(mu_);
// The status if the prefetch thread fails.
Status status_ TF_GUARDED_BY(mu_) = Status::OK();
// Thread which constantly tries to fill `buffer_` up with
// `num_consumers` elements.
std::unique_ptr<Thread> thread_;
// Condition variable notified when elements are added to or removed from
// `buffer_`, or when `status_` is changed.
condition_variable cv_;
bool cancelled_ TF_GUARDED_BY(mu_) = false;
};
// A task runner which enforces round-robin order for consuming a task's // A task runner which enforces round-robin order for consuming a task's
// elements. Requests must provide a consumer index and element index. // elements. `RoundRobinTaskRunner` provides elements in a series of "rounds".
// `RoundRobinTaskRunner` provides elements in a series of "rounds". In each // In each successive round, the runner waits to receive requests from all
// successive round, the runner waits to receive requests from all consumers. // consumers. These requests are blocked until all requests arrive. Once all
// These requests are blocked until all requests arrive. Once all requests // requests arrive, the runner hands out elements to consumers in order of their
// arrive, the runner hands out elements to consumers in order of their consumer // consumer indices.
// indices.
// //
// Consumers are expected to successively request consecutive element indices, // Consumers are expected to successively request consecutive element indices,
// starting at 0. The same element can be requested multiple times by the same // starting at 0. The same element can be requested multiple times by the same
@ -113,28 +132,37 @@ class RoundRobinTaskRunner : public TaskRunner {
public: public:
RoundRobinTaskRunner(std::unique_ptr<TaskIterator> iterator, RoundRobinTaskRunner(std::unique_ptr<TaskIterator> iterator,
int64 num_consumers); int64 num_consumers);
Status GetNext(const Request& request, std::vector<Tensor>& element,
bool& end_of_task) override; Status GetNext(const GetElementRequest& req,
GetElementResponse& resp) override;
private: private:
// Fills `buffer_` with `num_consumers_` elements. // Prepares a full round of data. `wait_us` indicates how long to wait before
Status FillBuffer(); // skipping if a full round of data is not yet ready.
Status PrepareFullRound(int64 wait_us) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Prepares a partial round to get consumers back in sync.
Status PreparePartialRound() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
Status ValidateRequest(const GetElementRequest& req);
// Prepares data for the next round, blocking until the round is ready to
// start.
Status PrepareRound(const GetElementRequest& req);
const int64 num_consumers_; const int64 num_consumers_;
std::unique_ptr<TaskIterator> iterator_;
mutex mu_; mutex mu_;
// Condition variable notified whenever we start a new round of round-robin. // Condition variable notified whenever we start a new round of round-robin.
condition_variable new_round_cv_; condition_variable new_round_cv_;
// Map from round number to consumers waiting for data from that round. // Map from round number to requests waiting for data from that round.
absl::flat_hash_map<int64, absl::flat_hash_set<int64>> requests_ absl::flat_hash_map<int64, absl::flat_hash_set<const GetElementRequest*>>
TF_GUARDED_BY(mu_); requests_ TF_GUARDED_BY(mu_);
// Index of the first round we plan to serve. At startup, this is the minimum // Index of the first round we plan to serve. At startup, this is the minimum
// of all requested element indices. // of all requested element indices.
int64 first_round_ TF_GUARDED_BY(mu_) = kint64max; int64 first_round_ TF_GUARDED_BY(mu_) = kint64max;
int64 current_round_ TF_GUARDED_BY(mu_) = -1; int64 current_round_ TF_GUARDED_BY(mu_) = -1;
bool round_skipped_ TF_GUARDED_BY(mu_) = false;
// Buffered results for the current round. // Buffered results for the current round.
std::vector<std::vector<Tensor>> buffer_ TF_GUARDED_BY(mu_); std::vector<std::vector<Tensor>> buffer_ TF_GUARDED_BY(mu_);
bool end_of_task_ TF_GUARDED_BY(mu_) = false; // Thread which constantly tries to prepare `num_consumers` elements for the
// next round.
PrefetchThread prefetch_thread_;
}; };
} // namespace data } // namespace data

View File

@ -13,6 +13,9 @@ limitations under the License.
#include "tensorflow/core/data/service/task_runner.h" #include "tensorflow/core/data/service/task_runner.h"
#include "absl/memory/memory.h" #include "absl/memory/memory.h"
#include "tensorflow/core/data/compression_utils.h"
#include "tensorflow/core/data/dataset.pb.h"
#include "tensorflow/core/data/service/worker.pb.h"
#include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/status_test_util.h"
@ -31,7 +34,10 @@ class TestTaskIterator : public TaskIterator {
Status GetNext(std::vector<Tensor>& element, bool& end_of_sequence) override { Status GetNext(std::vector<Tensor>& element, bool& end_of_sequence) override {
end_of_sequence = index_ >= elements_.size(); end_of_sequence = index_ >= elements_.size();
if (!end_of_sequence) { if (!end_of_sequence) {
element = elements_[index_]; CompressedElement compressed;
TF_RETURN_IF_ERROR(CompressElement(elements_[index_], &compressed));
element.emplace_back(DT_VARIANT, TensorShape({}));
element[0].scalar<Variant>()() = std::move(compressed);
index_ = (index_ + 1) % elements_.size(); index_ = (index_ + 1) % elements_.size();
} }
return Status::OK(); return Status::OK();
@ -47,16 +53,22 @@ class TestTaskIterator : public TaskIterator {
// Reads from the task runner, storing results in `*output`. // Reads from the task runner, storing results in `*output`.
Status RunConsumer(int64 consumer_index, int64 start_index, int64 end_index, Status RunConsumer(int64 consumer_index, int64 start_index, int64 end_index,
TaskRunner& task_runner, std::vector<int64>& output) { TaskRunner& task_runner, std::vector<int64>& output) {
bool end_of_sequence = false;
for (int64 next_index = start_index; next_index < end_index; ++next_index) { for (int64 next_index = start_index; next_index < end_index; ++next_index) {
TaskRunner::Request request; GetElementRequest request;
request.round_index = next_index; request.set_round_index(next_index);
request.consumer_index = consumer_index; request.set_consumer_index(consumer_index);
std::vector<Tensor> element; request.set_skipped_previous_round(false);
TF_RETURN_IF_ERROR(task_runner.GetNext(request, element, end_of_sequence)); request.set_allow_skip(false);
if (!end_of_sequence) { GetElementResponse response;
output.push_back(element[0].flat<int64>()(0)); do {
TF_RETURN_IF_ERROR(task_runner.GetNext(request, response));
if (!response.end_of_sequence()) {
std::vector<Tensor> uncompressed;
TF_RETURN_IF_ERROR(
UncompressElement(response.compressed_element(), &uncompressed));
output.push_back(uncompressed[0].flat<int64>()(0));
} }
} while (response.skip_task());
} }
return Status::OK(); return Status::OK();
} }
@ -71,12 +83,13 @@ TEST(FirstComeFirstServedTaskRunner, GetNext) {
} }
FirstComeFirstServedTaskRunner runner( FirstComeFirstServedTaskRunner runner(
absl::make_unique<TestTaskIterator>(elements)); absl::make_unique<TestTaskIterator>(elements));
TaskRunner::Request request; GetElementRequest request;
GetElementResponse response;
for (auto& expected_element : elements) { for (auto& expected_element : elements) {
TF_ASSERT_OK(runner.GetNext(request, response));
ASSERT_FALSE(response.end_of_sequence());
std::vector<Tensor> element; std::vector<Tensor> element;
bool end_of_sequence; TF_ASSERT_OK(UncompressElement(response.compressed_element(), &element));
TF_ASSERT_OK(runner.GetNext(request, element, end_of_sequence));
ASSERT_FALSE(end_of_sequence);
ASSERT_EQ(element.size(), 1); ASSERT_EQ(element.size(), 1);
test::ExpectEqual(element[0], expected_element[0]); test::ExpectEqual(element[0], expected_element[0]);
} }

View File

@ -23,6 +23,11 @@ message GetElementRequest {
oneof optional_round_index { oneof optional_round_index {
int64 round_index = 3; int64 round_index = 3;
} }
// Whether the previous round was skipped. This information is needed by the
// worker to recover after restarts.
bool skipped_previous_round = 4;
// Whether to skip the round if data isn't ready fast enough.
bool allow_skip = 5;
} }
message GetElementResponse { message GetElementResponse {
@ -30,6 +35,8 @@ message GetElementResponse {
CompressedElement compressed_element = 3; CompressedElement compressed_element = 3;
// Boolean to indicate whether the iterator has been exhausted. // Boolean to indicate whether the iterator has been exhausted.
bool end_of_sequence = 2; bool end_of_sequence = 2;
// Indicates whether the round was skipped.
bool skip_task = 4;
} }
// Named GetWorkerTasks to avoid conflicting with GetTasks in dispatcher.proto // Named GetWorkerTasks to avoid conflicting with GetTasks in dispatcher.proto

View File

@ -180,8 +180,6 @@ Status DataServiceWorkerImpl::EnsureTaskInitialized(
Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request, Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request,
GetElementResponse* response) { GetElementResponse* response) {
VLOG(3) << "Received GetElement request for task " << request->task_id(); VLOG(3) << "Received GetElement request for task " << request->task_id();
bool end_of_sequence = false;
std::vector<tensorflow::Tensor> outputs;
Task* task; Task* task;
{ {
mutex_lock l(mu_); mutex_lock l(mu_);
@ -207,54 +205,15 @@ Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request,
task = it->second.get(); task = it->second.get();
TF_RETURN_IF_ERROR(EnsureTaskInitialized(*task)); TF_RETURN_IF_ERROR(EnsureTaskInitialized(*task));
} }
TaskRunner::Request get_next_request; TF_RETURN_IF_ERROR(task->task_runner->GetNext(*request, *response));
if (request->optional_consumer_index_case() == if (response->end_of_sequence()) {
GetElementRequest::kConsumerIndex) {
get_next_request.consumer_index = request->consumer_index();
}
if (request->optional_round_index_case() == GetElementRequest::kRoundIndex) {
get_next_request.round_index = request->round_index();
}
TF_RETURN_IF_ERROR(
task->task_runner->GetNext(get_next_request, outputs, end_of_sequence));
if (end_of_sequence) {
mutex_lock l(mu_); mutex_lock l(mu_);
VLOG(3) << "Reached end_of_sequence for task " << request->task_id(); VLOG(3) << "Reached end_of_sequence for task " << request->task_id();
pending_completed_tasks_.insert(request->task_id()); pending_completed_tasks_.insert(request->task_id());
task_completion_cv_.notify_one(); task_completion_cv_.notify_one();
} } else if (!response->skip_task()) {
if (!end_of_sequence) {
VLOG(3) << "Producing an element for task " << request->task_id(); VLOG(3) << "Producing an element for task " << request->task_id();
if (outputs.size() != 1) {
return errors::FailedPrecondition(
"Expected dataset to produce a single scalar variant tensor, but the "
"dataset produced ",
outputs.size(), " outputs");
} }
if (outputs[0].dtype() != DT_VARIANT) {
return errors::FailedPrecondition(
"Expected dataset to produce a single scalar variant tensor, but "
"the dataset produced a tensor with type ",
DataTypeString(outputs[0].dtype()));
}
if (!TensorShapeUtils::IsScalar(outputs[0].shape())) {
return errors::FailedPrecondition(
"Expected dataset to produce a single scalar variant tensor, but "
"the dataset produced a tensor with shape ",
outputs[0].shape());
}
Variant& variant = outputs[0].scalar<Variant>()();
CompressedElement* compressed = variant.get<CompressedElement>();
if (compressed == nullptr) {
return errors::FailedPrecondition(
"Expected dataset to produce a CompressedElement variant tensor, but "
"it produced ",
variant.TypeName());
}
*response->mutable_compressed_element() = *compressed;
}
response->set_end_of_sequence(end_of_sequence);
return Status::OK(); return Status::OK();
} }

View File

@ -168,6 +168,7 @@ tf_kernel_library(
"//tensorflow/core/data/service:data_service", "//tensorflow/core/data/service:data_service",
"//tensorflow/core/data/service:dispatcher_proto_cc", "//tensorflow/core/data/service:dispatcher_proto_cc",
"//tensorflow/core/data/service:grpc_util", "//tensorflow/core/data/service:grpc_util",
"//tensorflow/core/data/service:worker_proto_cc",
"//tensorflow/core/distributed_runtime/rpc:grpc_util", "//tensorflow/core/distributed_runtime/rpc:grpc_util",
"//tensorflow/core/kernels/data:dataset_utils", "//tensorflow/core/kernels/data:dataset_utils",
"//tensorflow/core/kernels/data:name_utils", "//tensorflow/core/kernels/data:name_utils",

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/data/service/data_service.h" #include "tensorflow/core/data/service/data_service.h"
#include "tensorflow/core/data/service/dispatcher.pb.h" #include "tensorflow/core/data/service/dispatcher.pb.h"
#include "tensorflow/core/data/service/grpc_util.h" #include "tensorflow/core/data/service/grpc_util.h"
#include "tensorflow/core/data/service/worker.pb.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/model.h" #include "tensorflow/core/framework/model.h"
@ -295,6 +296,8 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
}); });
} }
bool skip = true;
while (skip) {
while ((results_.empty() || !results_.front().ready) && while ((results_.empty() || !results_.front().ready) &&
!(job_finished_ && num_running_worker_threads_ == 0) && !(job_finished_ && num_running_worker_threads_ == 0) &&
!cancelled_ && status_.ok()) { !cancelled_ && status_.ok()) {
@ -319,6 +322,12 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
VLOG(3) << "Returning from GetNext with end_of_sequence"; VLOG(3) << "Returning from GetNext with end_of_sequence";
return Status::OK(); return Status::OK();
} }
skip = results_.front().skip;
if (skip) {
results_.pop();
worker_thread_cv_.notify_one();
}
}
*end_of_sequence = results_.front().end_of_sequence; *end_of_sequence = results_.front().end_of_sequence;
if (!*end_of_sequence) { if (!*end_of_sequence) {
out_tensors->swap(results_.front().element); out_tensors->swap(results_.front().element);
@ -378,6 +387,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
const std::unique_ptr<DataServiceWorkerClient> worker; const std::unique_ptr<DataServiceWorkerClient> worker;
// The next round to read from the task. // The next round to read from the task.
int64 round = 0; int64 round = 0;
bool skipped_previous_round = false;
// Indicates whether a worker thread is currently processing the task. // Indicates whether a worker thread is currently processing the task.
bool in_use TF_GUARDED_BY(&Iterator::mu_) = false; bool in_use TF_GUARDED_BY(&Iterator::mu_) = false;
// Indicates whether the worker has returned end_of_sequence for the task. // Indicates whether the worker has returned end_of_sequence for the task.
@ -390,6 +400,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
bool ready TF_GUARDED_BY(&Iterator::mu_) = false; bool ready TF_GUARDED_BY(&Iterator::mu_) = false;
std::vector<Tensor> element TF_GUARDED_BY(&Iterator::mu_); std::vector<Tensor> element TF_GUARDED_BY(&Iterator::mu_);
bool end_of_sequence TF_GUARDED_BY(&Iterator::mu_) = false; bool end_of_sequence TF_GUARDED_BY(&Iterator::mu_) = false;
bool skip TF_GUARDED_BY(&Iterator::mu_) = false;
}; };
// Periodically refresh the task list. // Periodically refresh the task list.
@ -677,25 +688,24 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
return profiler::TraceMeEncode( return profiler::TraceMeEncode(
{{"address", task->info.worker_address()}}); {{"address", task->info.worker_address()}});
}); });
CompressedElement compressed; GetElementResponse resp;
bool end_of_sequence;
for (int num_retries = 0;; ++num_retries) { for (int num_retries = 0;; ++num_retries) {
absl::optional<int64> consumer_index = dataset()->consumer_index_; GetElementRequest req;
absl::optional<int64> round_index;
if (StrictRoundRobin()) { if (StrictRoundRobin()) {
round_index = task->round; req.set_consumer_index(dataset()->consumer_index_.value());
req.set_round_index(task->round);
req.set_allow_skip(true);
VLOG(3) << "Requesting element from consumer index " VLOG(3) << "Requesting element from consumer index "
<< consumer_index.value() << ", round " << req.consumer_index() << ", round " << req.round_index();
<< round_index.value();
activity.AppendMetadata([&]() { activity.AppendMetadata([&]() {
return profiler::TraceMeEncode( return profiler::TraceMeEncode(
{{"consumer_index", consumer_index.value()}, {{"consumer_index", req.consumer_index()},
{"round_index", round_index.value()}}); {"round_index", req.round_index()}});
}); });
} }
Status s = req.set_task_id(task->info.task_id());
task->worker->GetElement(task->info.task_id(), consumer_index, req.set_skipped_previous_round(task->skipped_previous_round);
round_index, compressed, end_of_sequence); Status s = task->worker->GetElement(req, resp);
if (s.ok()) { if (s.ok()) {
break; break;
} }
@ -709,7 +719,6 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
// If `UpdateTaskThreads` finds that the task has been cancelled, it // If `UpdateTaskThreads` finds that the task has been cancelled, it
// will set end_of_sequence to `true`. // will set end_of_sequence to `true`.
if (task->end_of_sequence || cancelled_) { if (task->end_of_sequence || cancelled_) {
end_of_sequence = true;
break; break;
} }
} }
@ -734,21 +743,27 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
} }
std::vector<Tensor> element; std::vector<Tensor> element;
if (!end_of_sequence) { if (resp.has_compressed_element()) {
Tensor tensor(DT_VARIANT, TensorShape{}); Tensor tensor(DT_VARIANT, TensorShape{});
tensor.scalar<Variant>()() = std::move(compressed); tensor.scalar<Variant>()() = std::move(resp.compressed_element());
element.push_back(tensor); element.push_back(tensor);
} }
mutex_lock l(mu_); mutex_lock l(mu_);
result.ready = true; result.ready = true;
result.end_of_sequence = end_of_sequence; result.end_of_sequence = resp.end_of_sequence();
if (end_of_sequence) { if (resp.has_compressed_element()) {
task->skipped_previous_round = false;
task->round++;
result.element = std::move(element);
} else if (resp.skip_task()) {
task->skipped_previous_round = true;
task->round++;
result.skip = true;
} else {
task->end_of_sequence = true; task->end_of_sequence = true;
finished_tasks_++; finished_tasks_++;
return Status::OK();
} }
result.element = std::move(element); if (enqueue_result && !resp.end_of_sequence()) {
if (enqueue_result) {
results_.push(std::move(result)); results_.push(std::move(result));
} }
get_next_cv_.notify_all(); get_next_cv_.notify_all();

View File

@ -291,8 +291,7 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
cluster = self.create_cluster(num_workers=num_workers) cluster = self.create_cluster(num_workers=num_workers)
# Round robin reads can cause slow cluster shutdown. # Round robin reads can cause slow cluster shutdown.
data_service_test_base.GLOBAL_CLUSTERS.add(cluster) data_service_test_base.GLOBAL_CLUSTERS.add(cluster)
num_elements = 100 ds = dataset_ops.Dataset.range(10000000)
ds = dataset_ops.Dataset.range(num_elements)
ds = ds.repeat() ds = ds.repeat()
consumers = [] consumers = []
for consumer_index in range(num_consumers): for consumer_index in range(num_consumers):
@ -309,18 +308,15 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
lambda x: x, lambda x: x,
cycle_length=num_consumers, cycle_length=num_consumers,
num_parallel_calls=num_consumers) num_parallel_calls=num_consumers)
ds = ds.take(2 * num_elements * num_workers) ds = ds.take(1000)
results = self.getDatasetOutput(ds, requires_initialization=True) results = self.getDatasetOutput(ds, requires_initialization=True)
expected = [] for i in range(0, len(results), num_consumers):
round_index = 0 self.assertEqual(0, results[i] % num_consumers)
while len(expected) < len(results): # Check that each group of `num_consumers` results are consecutive.
for _ in range(num_workers): for offset in range(1, num_consumers):
for consumer in range(num_consumers): if i + offset < len(results):
expected.append( self.assertEqual(results[i] + offset, results[i + offset])
(round_index * num_consumers + consumer) % num_elements)
round_index += 1
self.assertEqual(results, expected)
@combinations.generate(test_base.default_test_combinations()) @combinations.generate(test_base.default_test_combinations())
def testRoundRobinBucketizing(self): def testRoundRobinBucketizing(self):