[tf.data service] Skip round robin rounds when data isn't ready.
PiperOrigin-RevId: 356348715
Change-Id: I8ac227a098d49bd8a3fd6c96b93ac855df80a121

Parent: 0092ebe4c7
Commit: 2cc0ab3c0c
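The diff below implements the skip by buffering each round on a background
PrefetchThread and letting GetElement return a new skip_task signal when a
full round isn't buffered within kWaitBeforeSkipUs. A minimal sketch of the
consumer side of that protocol, assembled from the request/response fields
added below (the surrounding loop and the `task_runner`, `consumer_index`,
and `round_index` variables are assumptions for illustration):

    // Sketch only; modeled on RunConsumer in task_runner_test.cc below.
    GetElementRequest req;
    req.set_consumer_index(consumer_index);
    req.set_round_index(round_index);
    req.set_allow_skip(true);  // permit the worker to skip a slow round
    GetElementResponse resp;
    do {
      TF_RETURN_IF_ERROR(task_runner.GetNext(req, resp));
      // resp carries a compressed element, end_of_sequence, or skip_task
      // (round skipped, no element produced for this consumer).
    } while (resp.skip_task());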
@@ -57,6 +57,7 @@ tf_proto_library(
     ],
     visibility = [
         ":data_transfer_visibility",
+        "//tensorflow:internal",
     ],
 )
 
@@ -95,6 +96,7 @@ cc_library(
         ":dispatcher_cc_grpc_proto",
         ":grpc_util",
         ":worker_cc_grpc_proto",
+        ":worker_proto_cc",
         "//tensorflow/core:framework",
         "//tensorflow/core/platform:errors",
         "@com_google_absl//absl/container:flat_hash_set",
@@ -415,6 +417,7 @@ cc_library(
     hdrs = ["task_runner.h"],
     deps = [
         ":common_proto_cc",
+        ":worker_proto_cc",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core/data:compression_utils",
@@ -428,11 +431,14 @@ tf_cc_test(
     srcs = ["task_runner_test.cc"],
     deps = [
         ":task_runner",
+        ":worker_proto_cc",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "//tensorflow/core:testlib",
+        "//tensorflow/core/data:compression_utils",
+        "//tensorflow/core/data:dataset_proto_cc",
         "@com_google_absl//absl/memory",
     ],
 )
@@ -248,25 +248,14 @@ class GrpcDataTransferClient : public DataTransferClient {
     stub_ = WorkerService::NewStub(channel);
   }
 
-  Status GetElement(int64 task_id, absl::optional<int64> consumer_index,
-                    absl::optional<int64> round_index,
-                    CompressedElement& element,
-                    bool& end_of_sequence) override {
+  Status GetElement(const GetElementRequest& req,
+                    GetElementResponse& resp) override {
     {
       mutex_lock l(mu_);
       if (cancelled_) {
         return errors::Cancelled("Client was cancelled.");
       }
     }
-    GetElementRequest req;
-    req.set_task_id(task_id);
-    if (consumer_index.has_value()) {
-      req.set_consumer_index(consumer_index.value());
-    }
-    if (round_index.has_value()) {
-      req.set_round_index(round_index.value());
-    }
-    GetElementResponse resp;
     grpc::ClientContext ctx;
     {
       mutex_lock l(mu_);
@@ -280,10 +269,6 @@ class GrpcDataTransferClient : public DataTransferClient {
     if (!s.ok()) {
       return grpc_util::WrapError("Failed to get element", s);
     }
-    end_of_sequence = resp.end_of_sequence();
-    if (!end_of_sequence) {
-      element = std::move(*resp.mutable_compressed_element());
-    }
     return Status::OK();
   }
 
@@ -324,14 +309,10 @@ class GrpcTransferClientRegistrar {
 };
 static GrpcTransferClientRegistrar registrar;
 
-Status DataServiceWorkerClient::GetElement(int64 task_id,
-                                           absl::optional<int64> consumer_index,
-                                           absl::optional<int64> round_index,
-                                           CompressedElement& element,
-                                           bool& end_of_sequence) {
+Status DataServiceWorkerClient::GetElement(const GetElementRequest& req,
+                                           GetElementResponse& resp) {
   TF_RETURN_IF_ERROR(EnsureInitialized());
-  return client_->GetElement(task_id, consumer_index, round_index, element,
-                             end_of_sequence);
+  return client_->GetElement(req, resp);
 }
 
 Status DataServiceWorkerClient::EnsureInitialized() {
@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/core/data/service/data_transfer.h"
 #include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
 #include "tensorflow/core/data/service/worker.grpc.pb.h"
+#include "tensorflow/core/data/service/worker.pb.h"
 #include "tensorflow/core/framework/dataset.h"
 #include "tensorflow/core/framework/op_kernel.h"
 
@@ -144,14 +145,8 @@ class DataServiceWorkerClient : public DataServiceClientBase {
       : DataServiceClientBase(address, protocol),
         transfer_protocol_(transfer_protocol) {}
 
-  // Fetches the next element for the specified task_id. The optional
-  // `consumer_index` and `round_index` must be specified for tasks which use
-  // round-robin ordering. The element's compressed tensors will be stored in
-  // `element`. If no element is available, `end_of_sequence` will be `true`,
-  // and `element` will be left unchanged.
-  Status GetElement(int64 task_id, absl::optional<int64> consumer_index,
-                    absl::optional<int64> round_index,
-                    CompressedElement& element, bool& end_of_sequence);
+  // Fetches an element from the worker.
+  Status GetElement(const GetElementRequest& req, GetElementResponse& resp);
 
   // Makes a best effort to cancel all outstanding calls in progress for the
   // client, and causes further calls to return Cancelled status.
@@ -38,13 +38,9 @@ class DataTransferClient {
       std::function<Status(Config, std::unique_ptr<DataTransferClient>*)>;
   virtual ~DataTransferClient() = default;
 
-  // Fetches the next element for the specified task_id. The element's
-  // compressed tensors will be stored in `element`. If no element is available,
-  // `end_of_sequence` will be `true`, and `element` will be left unchanged.
-  virtual Status GetElement(int64 task_id, absl::optional<int64> consumer_index,
-                            absl::optional<int64> round_index,
-                            tensorflow::data::CompressedElement& element,
-                            bool& end_of_sequence) = 0;
+  // Fetches the next element.
+  virtual Status GetElement(const GetElementRequest& req,
+                            GetElementResponse& resp) = 0;
 
   // Makes a best effort to cancel all outstanding calls in progress for the
   // client, and causes further calls to return Cancelled status.
@@ -29,6 +29,43 @@ namespace {
 // Unavailable error. This prevents the server from hanging on shutdown when
 // some round-robin consumers exit earlier than others.
 const int64 kTimeoutUs = 60 * 1000 * 1000;  // 1 minute.
+// Time to wait before skipping a round if data still isn't available.
+const int64 kWaitBeforeSkipUs = 100 * 1000;  // 100ms.
+
+// Interprets `element` as a size-1 vector containing a CompressedElement, and
+// moves the element into `resp`. Returns an error if `element` is of unexpected
+// size, type, or shape.
+Status MoveCompressedElement(std::vector<Tensor>&& element,
+                             GetElementResponse& resp) {
+  if (element.size() != 1) {
+    return errors::FailedPrecondition(
+        "Expected dataset to produce a single scalar variant tensor, but the "
+        "dataset produced ",
+        element.size(), " outputs");
+  }
+  if (element[0].dtype() != DT_VARIANT) {
+    return errors::FailedPrecondition(
+        "Expected dataset to produce a single scalar variant tensor, but "
+        "the dataset produced a tensor with type ",
+        DataTypeString(element[0].dtype()));
+  }
+  if (!TensorShapeUtils::IsScalar(element[0].shape())) {
+    return errors::FailedPrecondition(
+        "Expected dataset to produce a single scalar variant tensor, but "
+        "the dataset produced a tensor with shape ",
+        element[0].shape());
+  }
+  Variant& variant = element[0].scalar<Variant>()();
+  CompressedElement* compressed = variant.get<CompressedElement>();
+  if (compressed == nullptr) {
+    return errors::FailedPrecondition(
+        "Expected dataset to produce a CompressedElement variant tensor, but "
+        "it produced ",
+        variant.TypeName());
+  }
+  *resp.mutable_compressed_element() = *compressed;
+  return Status::OK();
+}
 }  // namespace
 
 StandaloneTaskIterator::StandaloneTaskIterator(
@@ -71,62 +108,91 @@ FirstComeFirstServedTaskRunner::FirstComeFirstServedTaskRunner(
     std::unique_ptr<TaskIterator> iterator)
     : iterator_(std::move(iterator)) {}
 
-Status FirstComeFirstServedTaskRunner::GetNext(const Request& request,
-                                               std::vector<Tensor>& element,
-                                               bool& end_of_task) {
-  return iterator_->GetNext(element, end_of_task);
+Status FirstComeFirstServedTaskRunner::GetNext(const GetElementRequest& req,
+                                               GetElementResponse& resp) {
+  std::vector<Tensor> element;
+  bool end_of_task;
+  resp.set_skip_task(false);
+  TF_RETURN_IF_ERROR(iterator_->GetNext(element, end_of_task));
+  resp.set_end_of_sequence(end_of_task);
+  if (!end_of_task) {
+    return MoveCompressedElement(std::move(element), resp);
+  }
+  return Status::OK();
 }
 
 RoundRobinTaskRunner::RoundRobinTaskRunner(
     std::unique_ptr<TaskIterator> iterator, int64 num_consumers)
     : num_consumers_(num_consumers),
-      iterator_(std::move(iterator)),
-      buffer_(num_consumers_) {
+      buffer_(num_consumers_),
+      prefetch_thread_(std::move(iterator), num_consumers_) {
   VLOG(1) << "Creating task runner for distributing data round-robin to "
           << num_consumers << " consumers";
 }
 
-Status RoundRobinTaskRunner::GetNext(const Request& request,
-                                     std::vector<Tensor>& element,
-                                     bool& end_of_task) {
-  if (request.consumer_index < 0 || request.round_index < 0) {
+Status RoundRobinTaskRunner::ValidateRequest(const GetElementRequest& req) {
+  if (req.consumer_index() < 0 || req.round_index() < 0) {
     return errors::FailedPrecondition(
         "RoundRobinTaskRunner needs to know the consumer index and element "
        "index of each request.");
   }
-  if (request.consumer_index >= num_consumers_) {
+  if (req.consumer_index() >= num_consumers_) {
     return errors::FailedPrecondition(
-        "Requesting data for consumer index ", request.consumer_index,
+        "Requesting data for consumer index ", req.consumer_index(),
        ", but the task is configured for only ", num_consumers_, " consumers");
   }
-  VLOG(2) << "Received request from consumer index " << request.consumer_index
-          << " for round " << request.round_index;
-  {
-    mutex_lock l(mu_);
-    absl::flat_hash_set<int64>& round = requests_[request.round_index];
-    first_round_ = std::min(first_round_, request.round_index);
-    round.insert(request.consumer_index);
-    if (current_round_ < request.round_index &&
-        round.size() == num_consumers_) {
-      VLOG(1) << "Starting normal round with round index "
-              << request.round_index;
+  return Status::OK();
+}
+
+Status RoundRobinTaskRunner::PrepareFullRound(int64 wait_us)
+    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+  VLOG(1) << "Preparing full round for index " << current_round_;
   // This was the last request to arrive, time to start a new round.
-      TF_RETURN_IF_ERROR(FillBuffer());
-      VLOG(1) << "Finished preparing data for round " << request.round_index;
-      current_round_ = request.round_index;
+  TF_RETURN_IF_ERROR(prefetch_thread_.FillBuffer(wait_us, buffer_));
+  round_skipped_ = buffer_.empty();
   new_round_cv_.notify_all();
+  return Status::OK();
+}
+
+Status RoundRobinTaskRunner::PreparePartialRound()
+    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+  VLOG(1) << "Starting partial round for " << requests_[first_round_].size()
+          << " consumers";
+  current_round_ = first_round_;
+  new_round_cv_.notify_all();
+  // Indicates that we need a partial round to get consumers back in sync.
+  auto next_round_request = *(requests_[first_round_ + 1].begin());
+  if (next_round_request->skipped_previous_round()) {
+    VLOG(1) << "Skipping partial round";
+    round_skipped_ = true;
+    return Status::OK();
+  }
+  TF_RETURN_IF_ERROR(prefetch_thread_.FillBuffer(/*wait_us=*/-1, buffer_));
+  round_skipped_ = false;
+  return Status::OK();
+}
+
+Status RoundRobinTaskRunner::PrepareRound(const GetElementRequest& req) {
+  mutex_lock l(mu_);
+  absl::flat_hash_set<const GetElementRequest*>& round =
+      requests_[req.round_index()];
+  first_round_ = std::min(first_round_, req.round_index());
+  round.insert(&req);
+  if (current_round_ < req.round_index() && round.size() == num_consumers_) {
+    current_round_ = req.round_index();
+    int64 wait_us = kWaitBeforeSkipUs;
+    if (!req.allow_skip()) {
+      wait_us = -1;
+    }
+    TF_RETURN_IF_ERROR(PrepareFullRound(wait_us));
   }
   if (current_round_ < 0 &&
       requests_[first_round_].size() + requests_[first_round_ + 1].size() ==
           num_consumers_) {
-    VLOG(1) << "Starting partial round for " << requests_[first_round_].size()
-            << " consumers";
-    // Indicates that we need a partial round to get consumers back in sync.
-    TF_RETURN_IF_ERROR(FillBuffer());
-    current_round_ = first_round_;
-    new_round_cv_.notify_all();
+    TF_RETURN_IF_ERROR(PreparePartialRound());
   }
-  while (current_round_ < request.round_index) {
+  while (current_round_ < req.round_index()) {
+    TF_RETURN_IF_ERROR(prefetch_thread_.GetStatus());
     std::cv_status s =
         new_round_cv_.wait_for(l, std::chrono::microseconds(kTimeoutUs));
     if (s == std::cv_status::timeout) {
@@ -135,33 +201,118 @@ Status RoundRobinTaskRunner::GetNext(const Request& request,
           "Timeout waiting for other round-robin consumers to be ready.");
     }
   }
-  end_of_task = end_of_task_;
+  return prefetch_thread_.GetStatus();
 }
-  if (!end_of_task) {
-    element.clear();
+
+Status RoundRobinTaskRunner::GetNext(const GetElementRequest& req,
+                                     GetElementResponse& resp) {
+  TF_RETURN_IF_ERROR(ValidateRequest(req));
+  resp.set_end_of_sequence(false);
+  VLOG(2) << "Received request from consumer index " << req.consumer_index()
+          << " for round " << req.round_index();
+  TF_RETURN_IF_ERROR(PrepareRound(req));
   tf_shared_lock l(mu_);
-    for (auto& component : buffer_[request.consumer_index]) {
+  resp.set_skip_task(round_skipped_);
+  if (round_skipped_) {
+    VLOG(1) << "Buffer not ready, skipping round " << current_round_
+            << " for consumer " << req.consumer_index();
+    return Status::OK();
+  }
+  std::vector<Tensor> element;
+  for (auto& component : buffer_[req.consumer_index()]) {
     element.push_back(tensor::DeepCopy(component));
   }
+  if (VLOG_IS_ON(2)) {
+    int64 size = 0;
+    for (auto& component : element) {
+      size += component.TotalBytes();
     }
-  VLOG(2) << "Returning to consumer " << request.consumer_index << " for round "
-          << request.round_index;
+    VLOG(2) << "Returning to consumer " << req.consumer_index() << " for round "
+            << req.round_index() << ". element size " << size;
+  }
+  return MoveCompressedElement(std::move(element), resp);
+}
+
+PrefetchThread::PrefetchThread(std::unique_ptr<TaskIterator> iterator,
+                               int64 round_size)
+    : iterator_(std::move(iterator)), round_size_(round_size) {
+  thread_ = absl::WrapUnique(
+      Env::Default()->StartThread({}, "round-robin-prefetch", [&] { Run(); }));
+}
+
+PrefetchThread::~PrefetchThread() {
+  mutex_lock l(mu_);
+  cancelled_ = true;
+  cv_.notify_all();
+}
+
+void PrefetchThread::Run() {
+  while (true) {
+    {
+      mutex_lock l(mu_);
+      while (!cancelled_ && buffer_.size() >= round_size_) {
+        cv_.wait(l);
+      }
+      if (cancelled_) {
+        return;
+      }
+    }
+    std::vector<Tensor> element;
+    bool end_of_sequence;
+    Status s = iterator_->GetNext(element, end_of_sequence);
+    if (!s.ok()) {
+      mutex_lock l(mu_);
+      status_ = s;
+      cv_.notify_all();
+      return;
+    }
+    if (end_of_sequence) {
+      mutex_lock l(mu_);
+      status_ = errors::FailedPrecondition(
+          "Encountered end of sequence on a round-robin read iterator. "
+          "Please ensure that the dataset used for round-robin reading has "
+          "infinite cardinality, e.g. by adding a .repeat() transformation "
+          "at the end.");
+      cv_.notify_all();
+      return;
+    }
+    mutex_lock l(mu_);
+    buffer_.push_back(std::move(element));
+    cv_.notify_all();
+  }
+}
+
+Status PrefetchThread::FillBuffer(int64 wait_us,
+                                  std::vector<std::vector<Tensor>>& out) {
+  int64 start_us = Env::Default()->NowMicros();
+  out.clear();
+  mutex_lock l(mu_);
+  while (buffer_.size() < round_size_ && !cancelled_ && status_.ok()) {
+    int64 remaining_us = start_us + wait_us - Env::Default()->NowMicros();
+    if (wait_us >= 0 && remaining_us <= 0) {
+      break;
+    }
+    cv_.wait_for(l, std::chrono::microseconds(remaining_us));
+  }
+  TF_RETURN_IF_ERROR(status_);
+  if (cancelled_) {
+    return errors::Cancelled("Prefetch thread cancelled");
+  }
+  if (buffer_.size() < round_size_) {
+    DCHECK_GE(wait_us, 0);
+    return Status::OK();
+  }
+  for (auto& elem : buffer_) {
+    out.push_back(std::move(elem));
+  }
+  buffer_.clear();
+  cv_.notify_all();
   return Status::OK();
 }
 
-Status RoundRobinTaskRunner::FillBuffer() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
-  for (int i = 0; i < num_consumers_; ++i) {
-    buffer_[i].clear();
-    bool end_of_sequence;
-    TF_RETURN_IF_ERROR(iterator_->GetNext(buffer_[i], end_of_sequence));
-    if (end_of_sequence) {
-      return errors::FailedPrecondition(
-          "Encountered end of sequence on a round-robin read iterator. Please "
-          "ensure that the dataset used for round-robin reading has infinite "
-          "cardinality, e.g. by adding a .repeat() transformation at the end.");
-    }
-  }
-  return Status::OK();
+Status PrefetchThread::GetStatus() {
+  mutex_lock l(mu_);
+  return status_;
 }
 }  // namespace data
 }  // namespace tensorflow
@@ -16,6 +16,7 @@ limitations under the License.
 #define TENSORFLOW_CORE_DATA_SERVICE_TASK_RUNNER_H_
 
 #include "tensorflow/core/data/service/common.pb.h"
+#include "tensorflow/core/data/service/worker.pb.h"
 #include "tensorflow/core/data/standalone.h"
 #include "tensorflow/core/platform/status.h"
 
@@ -54,28 +55,14 @@ class StandaloneTaskIterator : public TaskIterator {
 // Interface for providing elements to task consumers.
 class TaskRunner {
  public:
-  struct Request {
-    // Optional consumer index indicating which consumer is making the request.
-    // Only needed for round-robin reads.
-    int64 consumer_index = -1;
-    // Optional round index indicating which round the consumer wants to read
-    // from. Consumers are expected to read from consecutive rounds, starting
-    // with round 0. The task runner will attempt to serve all consumer
-    // requests for a round from the same block of `num_consumers` iterator
-    // indices, where block `n` is defined as elements `n*num_consumers` to
-    // `(n+1)*num_consumers`.
-    int64 round_index = -1;
-  };
-
   // Creates a `TaskRunner` and stores it in `out`.
   static Status Create(const TaskDef& task_def,
                        std::unique_ptr<TaskIterator> iterator,
                        std::unique_ptr<TaskRunner>& out);
   virtual ~TaskRunner() = default;
-  // Gets the next element for the given request, storing the results in
-  // `element` and `end_of_task`.
-  virtual Status GetNext(const Request& request, std::vector<Tensor>& element,
-                         bool& end_of_task) = 0;
+  // Gets the next element for the given request.
+  virtual Status GetNext(const GetElementRequest& req,
+                         GetElementResponse& resp) = 0;
 };
 
 // A task runner which provides elements on a first-come first-served basis.
@@ -84,20 +71,52 @@ class FirstComeFirstServedTaskRunner : public TaskRunner {
  public:
   explicit FirstComeFirstServedTaskRunner(
       std::unique_ptr<TaskIterator> iterator);
-  Status GetNext(const Request& request, std::vector<Tensor>& element,
-                 bool& end_of_task) override;
+  Status GetNext(const GetElementRequest& req,
+                 GetElementResponse& resp) override;
 
  private:
   std::unique_ptr<TaskIterator> iterator_;
 };
 
+// Thread for prefetching a round worth of elements.
+class PrefetchThread {
+ public:
+  explicit PrefetchThread(std::unique_ptr<TaskIterator> iterator,
+                          int64 round_size);
+  ~PrefetchThread();
+  // Runs the prefetch thread. It runs until an error is encountered or the
+  // destructor is called.
+  void Run();
+  // Fills `out` with a round of data. Waits for up to `wait_us` microseconds
+  // before giving up and returning with `out` empty. A negative `wait_us`
+  // signals to wait indefinitely.
+  Status FillBuffer(int64 wait_us, std::vector<std::vector<Tensor>>& out);
+  // Returns the status for any failures encountered by the prefetch thread.
+  Status GetStatus();
+
+ private:
+  const std::unique_ptr<TaskIterator> iterator_;
+  const int64 round_size_;
+  mutex mu_;
+  // Buffered results for the next round.
+  std::vector<std::vector<Tensor>> buffer_ TF_GUARDED_BY(mu_);
+  // The status if the prefetch thread fails.
+  Status status_ TF_GUARDED_BY(mu_) = Status::OK();
+  // Thread which constantly tries to fill `buffer_` up with
+  // `num_consumers` elements.
+  std::unique_ptr<Thread> thread_;
+  // Condition variable notified when elements are added to or removed from
+  // `buffer_`, or when `status_` is changed.
+  condition_variable cv_;
+  bool cancelled_ TF_GUARDED_BY(mu_) = false;
+};
+
 // A task runner which enforces round-robin order for consuming a task's
-// elements. Requests must provide a consumer index and element index.
-// `RoundRobinTaskRunner` provides elements in a series of "rounds". In each
-// successive round, the runner waits to receive requests from all consumers.
-// These requests are blocked until all requests arrive. Once all requests
-// arrive, the runner hands out elements to consumers in order of their consumer
-// indices.
+// elements. `RoundRobinTaskRunner` provides elements in a series of "rounds".
+// In each successive round, the runner waits to receive requests from all
+// consumers. These requests are blocked until all requests arrive. Once all
+// requests arrive, the runner hands out elements to consumers in order of their
+// consumer indices.
 //
 // Consumers are expected to successively request consecutive element indices,
 // starting at 0. The same element can be requested multiple times by the same
@@ -113,28 +132,37 @@ class RoundRobinTaskRunner : public TaskRunner {
  public:
   RoundRobinTaskRunner(std::unique_ptr<TaskIterator> iterator,
                        int64 num_consumers);
-  Status GetNext(const Request& request, std::vector<Tensor>& element,
-                 bool& end_of_task) override;
+  Status GetNext(const GetElementRequest& req,
+                 GetElementResponse& resp) override;
 
  private:
-  // Fills `buffer_` with `num_consumers_` elements.
-  Status FillBuffer();
+  // Prepares a full round of data. `wait_us` indicates how long to wait before
+  // skipping if a full round of data is not yet ready.
+  Status PrepareFullRound(int64 wait_us) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
+  // Prepares a partial round to get consumers back in sync.
+  Status PreparePartialRound() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
+  Status ValidateRequest(const GetElementRequest& req);
+  // Prepares data for the next round, blocking until the round is ready to
+  // start.
+  Status PrepareRound(const GetElementRequest& req);
   const int64 num_consumers_;
-  std::unique_ptr<TaskIterator> iterator_;
   mutex mu_;
   // Condition variable notified whenever we start a new round of round-robin.
   condition_variable new_round_cv_;
-  // Map from round number to consumers waiting for data from that round.
-  absl::flat_hash_map<int64, absl::flat_hash_set<int64>> requests_
-      TF_GUARDED_BY(mu_);
+  // Map from round number to requests waiting for data from that round.
+  absl::flat_hash_map<int64, absl::flat_hash_set<const GetElementRequest*>>
+      requests_ TF_GUARDED_BY(mu_);
   // Index of the first round we plan to serve. At startup, this is the minimum
   // of all requested element indices.
   int64 first_round_ TF_GUARDED_BY(mu_) = kint64max;
   int64 current_round_ TF_GUARDED_BY(mu_) = -1;
+  bool round_skipped_ TF_GUARDED_BY(mu_) = false;
   // Buffered results for the current round.
   std::vector<std::vector<Tensor>> buffer_ TF_GUARDED_BY(mu_);
-  bool end_of_task_ TF_GUARDED_BY(mu_) = false;
+  // Thread which constantly tries to prepare `num_consumers` elements for the
+  // next round.
+  PrefetchThread prefetch_thread_;
 };
 
 }  // namespace data
@@ -13,6 +13,9 @@ limitations under the License.
 #include "tensorflow/core/data/service/task_runner.h"
 
 #include "absl/memory/memory.h"
+#include "tensorflow/core/data/compression_utils.h"
+#include "tensorflow/core/data/dataset.pb.h"
+#include "tensorflow/core/data/service/worker.pb.h"
 #include "tensorflow/core/framework/dataset.h"
 #include "tensorflow/core/framework/tensor_testutil.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
@@ -31,7 +34,10 @@ class TestTaskIterator : public TaskIterator {
   Status GetNext(std::vector<Tensor>& element, bool& end_of_sequence) override {
     end_of_sequence = index_ >= elements_.size();
     if (!end_of_sequence) {
-      element = elements_[index_];
+      CompressedElement compressed;
+      TF_RETURN_IF_ERROR(CompressElement(elements_[index_], &compressed));
+      element.emplace_back(DT_VARIANT, TensorShape({}));
+      element[0].scalar<Variant>()() = std::move(compressed);
       index_ = (index_ + 1) % elements_.size();
     }
     return Status::OK();
@@ -47,16 +53,22 @@ class TestTaskIterator : public TaskIterator {
 // Reads from the task runner, storing results in `*output`.
 Status RunConsumer(int64 consumer_index, int64 start_index, int64 end_index,
                    TaskRunner& task_runner, std::vector<int64>& output) {
-  bool end_of_sequence = false;
   for (int64 next_index = start_index; next_index < end_index; ++next_index) {
-    TaskRunner::Request request;
-    request.round_index = next_index;
-    request.consumer_index = consumer_index;
-    std::vector<Tensor> element;
-    TF_RETURN_IF_ERROR(task_runner.GetNext(request, element, end_of_sequence));
-    if (!end_of_sequence) {
-      output.push_back(element[0].flat<int64>()(0));
-    }
+    GetElementRequest request;
+    request.set_round_index(next_index);
+    request.set_consumer_index(consumer_index);
+    request.set_skipped_previous_round(false);
+    request.set_allow_skip(false);
+    GetElementResponse response;
+    do {
+      TF_RETURN_IF_ERROR(task_runner.GetNext(request, response));
+      if (!response.end_of_sequence()) {
+        std::vector<Tensor> uncompressed;
+        TF_RETURN_IF_ERROR(
+            UncompressElement(response.compressed_element(), &uncompressed));
+        output.push_back(uncompressed[0].flat<int64>()(0));
+      }
+    } while (response.skip_task());
   }
   return Status::OK();
 }
@@ -71,12 +83,13 @@ TEST(FirstComeFirstServedTaskRunner, GetNext) {
   }
   FirstComeFirstServedTaskRunner runner(
       absl::make_unique<TestTaskIterator>(elements));
-  TaskRunner::Request request;
+  GetElementRequest request;
+  GetElementResponse response;
   for (auto& expected_element : elements) {
+    TF_ASSERT_OK(runner.GetNext(request, response));
+    ASSERT_FALSE(response.end_of_sequence());
     std::vector<Tensor> element;
-    bool end_of_sequence;
-    TF_ASSERT_OK(runner.GetNext(request, element, end_of_sequence));
-    ASSERT_FALSE(end_of_sequence);
+    TF_ASSERT_OK(UncompressElement(response.compressed_element(), &element));
     ASSERT_EQ(element.size(), 1);
     test::ExpectEqual(element[0], expected_element[0]);
   }
@@ -23,6 +23,11 @@ message GetElementRequest {
   oneof optional_round_index {
     int64 round_index = 3;
   }
+  // Whether the previous round was skipped. This information is needed by the
+  // worker to recover after restarts.
+  bool skipped_previous_round = 4;
+  // Whether to skip the round if data isn't ready fast enough.
+  bool allow_skip = 5;
 }
 
 message GetElementResponse {
@@ -30,6 +35,8 @@ message GetElementResponse {
   CompressedElement compressed_element = 3;
   // Boolean to indicate whether the iterator has been exhausted.
   bool end_of_sequence = 2;
+  // Indicates whether the round was skipped.
+  bool skip_task = 4;
 }
 
 // Named GetWorkerTasks to avoid conflicting with GetTasks in dispatcher.proto
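Taken together, allow_skip and skipped_previous_round appear to form the two
halves of the skip handshake: a consumer opts in with allow_skip, and after
being told to skip round N it reports skipped_previous_round on its request
for round N+1, which PreparePartialRound (in task_runner.cc above) uses to
decide whether a restarted worker's catch-up partial round must itself be
skipped.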
@@ -180,8 +180,6 @@ Status DataServiceWorkerImpl::EnsureTaskInitialized(
 Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request,
                                          GetElementResponse* response) {
   VLOG(3) << "Received GetElement request for task " << request->task_id();
-  bool end_of_sequence = false;
-  std::vector<tensorflow::Tensor> outputs;
   Task* task;
   {
     mutex_lock l(mu_);
@@ -207,54 +205,15 @@ Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request,
     task = it->second.get();
     TF_RETURN_IF_ERROR(EnsureTaskInitialized(*task));
   }
-  TaskRunner::Request get_next_request;
-  if (request->optional_consumer_index_case() ==
-      GetElementRequest::kConsumerIndex) {
-    get_next_request.consumer_index = request->consumer_index();
-  }
-  if (request->optional_round_index_case() == GetElementRequest::kRoundIndex) {
-    get_next_request.round_index = request->round_index();
-  }
-  TF_RETURN_IF_ERROR(
-      task->task_runner->GetNext(get_next_request, outputs, end_of_sequence));
-  if (end_of_sequence) {
+  TF_RETURN_IF_ERROR(task->task_runner->GetNext(*request, *response));
+  if (response->end_of_sequence()) {
     mutex_lock l(mu_);
     VLOG(3) << "Reached end_of_sequence for task " << request->task_id();
     pending_completed_tasks_.insert(request->task_id());
     task_completion_cv_.notify_one();
-  }
-
-  if (!end_of_sequence) {
+  } else if (!response->skip_task()) {
     VLOG(3) << "Producing an element for task " << request->task_id();
-    if (outputs.size() != 1) {
-      return errors::FailedPrecondition(
-          "Expected dataset to produce a single scalar variant tensor, but the "
-          "dataset produced ",
-          outputs.size(), " outputs");
-    }
-    if (outputs[0].dtype() != DT_VARIANT) {
-      return errors::FailedPrecondition(
-          "Expected dataset to produce a single scalar variant tensor, but "
-          "the dataset produced a tensor with type ",
-          DataTypeString(outputs[0].dtype()));
-    }
-    if (!TensorShapeUtils::IsScalar(outputs[0].shape())) {
-      return errors::FailedPrecondition(
-          "Expected dataset to produce a single scalar variant tensor, but "
-          "the dataset produced a tensor with shape ",
-          outputs[0].shape());
-    }
-    Variant& variant = outputs[0].scalar<Variant>()();
-    CompressedElement* compressed = variant.get<CompressedElement>();
-    if (compressed == nullptr) {
-      return errors::FailedPrecondition(
-          "Expected dataset to produce a CompressedElement variant tensor, but "
-          "it produced ",
-          variant.TypeName());
-    }
-    *response->mutable_compressed_element() = *compressed;
   }
-  response->set_end_of_sequence(end_of_sequence);
 
   return Status::OK();
 }
@@ -168,6 +168,7 @@ tf_kernel_library(
         "//tensorflow/core/data/service:data_service",
         "//tensorflow/core/data/service:dispatcher_proto_cc",
         "//tensorflow/core/data/service:grpc_util",
+        "//tensorflow/core/data/service:worker_proto_cc",
         "//tensorflow/core/distributed_runtime/rpc:grpc_util",
         "//tensorflow/core/kernels/data:dataset_utils",
         "//tensorflow/core/kernels/data:name_utils",
@@ -26,6 +26,7 @@ limitations under the License.
 #include "tensorflow/core/data/service/data_service.h"
 #include "tensorflow/core/data/service/dispatcher.pb.h"
 #include "tensorflow/core/data/service/grpc_util.h"
+#include "tensorflow/core/data/service/worker.pb.h"
 #include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
 #include "tensorflow/core/framework/dataset.h"
 #include "tensorflow/core/framework/model.h"
@@ -295,6 +296,8 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
         });
       }
 
+      bool skip = true;
+      while (skip) {
       while ((results_.empty() || !results_.front().ready) &&
              !(job_finished_ && num_running_worker_threads_ == 0) &&
             !cancelled_ && status_.ok()) {
@@ -319,6 +322,12 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
           VLOG(3) << "Returning from GetNext with end_of_sequence";
           return Status::OK();
         }
+        skip = results_.front().skip;
+        if (skip) {
+          results_.pop();
+          worker_thread_cv_.notify_one();
+        }
+      }
       *end_of_sequence = results_.front().end_of_sequence;
       if (!*end_of_sequence) {
         out_tensors->swap(results_.front().element);
@@ -378,6 +387,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
     const std::unique_ptr<DataServiceWorkerClient> worker;
     // The next round to read from the task.
     int64 round = 0;
+    bool skipped_previous_round = false;
     // Indicates whether a worker thread is currently processing the task.
     bool in_use TF_GUARDED_BY(&Iterator::mu_) = false;
     // Indicates whether the worker has returned end_of_sequence for the task.
@@ -390,6 +400,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
     bool ready TF_GUARDED_BY(&Iterator::mu_) = false;
     std::vector<Tensor> element TF_GUARDED_BY(&Iterator::mu_);
     bool end_of_sequence TF_GUARDED_BY(&Iterator::mu_) = false;
+    bool skip TF_GUARDED_BY(&Iterator::mu_) = false;
   };
 
   // Periodically refresh the task list.
@@ -677,25 +688,24 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
         return profiler::TraceMeEncode(
             {{"address", task->info.worker_address()}});
       });
-      CompressedElement compressed;
-      bool end_of_sequence;
+      GetElementResponse resp;
       for (int num_retries = 0;; ++num_retries) {
-        absl::optional<int64> consumer_index = dataset()->consumer_index_;
-        absl::optional<int64> round_index;
+        GetElementRequest req;
         if (StrictRoundRobin()) {
-          round_index = task->round;
+          req.set_consumer_index(dataset()->consumer_index_.value());
+          req.set_round_index(task->round);
+          req.set_allow_skip(true);
           VLOG(3) << "Requesting element from consumer index "
-                  << consumer_index.value() << ", round "
-                  << round_index.value();
+                  << req.consumer_index() << ", round " << req.round_index();
           activity.AppendMetadata([&]() {
             return profiler::TraceMeEncode(
-                {{"consumer_index", consumer_index.value()},
-                 {"round_index", round_index.value()}});
+                {{"consumer_index", req.consumer_index()},
+                 {"round_index", req.round_index()}});
           });
         }
-        Status s =
-            task->worker->GetElement(task->info.task_id(), consumer_index,
-                                     round_index, compressed, end_of_sequence);
+        req.set_task_id(task->info.task_id());
+        req.set_skipped_previous_round(task->skipped_previous_round);
+        Status s = task->worker->GetElement(req, resp);
         if (s.ok()) {
           break;
         }
@@ -709,7 +719,6 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
         // If `UpdateTaskThreads` finds that the task has been cancelled, it
        // will set end_of_sequence to `true`.
         if (task->end_of_sequence || cancelled_) {
-          end_of_sequence = true;
           break;
         }
       }
@@ -734,21 +743,27 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
       }
 
       std::vector<Tensor> element;
-      if (!end_of_sequence) {
+      if (resp.has_compressed_element()) {
         Tensor tensor(DT_VARIANT, TensorShape{});
-        tensor.scalar<Variant>()() = std::move(compressed);
+        tensor.scalar<Variant>()() = std::move(resp.compressed_element());
         element.push_back(tensor);
       }
       mutex_lock l(mu_);
       result.ready = true;
-      result.end_of_sequence = end_of_sequence;
-      if (end_of_sequence) {
+      result.end_of_sequence = resp.end_of_sequence();
+      if (resp.has_compressed_element()) {
+        task->skipped_previous_round = false;
+        task->round++;
+        result.element = std::move(element);
+      } else if (resp.skip_task()) {
+        task->skipped_previous_round = true;
+        task->round++;
+        result.skip = true;
+      } else {
         task->end_of_sequence = true;
         finished_tasks_++;
-        return Status::OK();
       }
-      result.element = std::move(element);
-      if (enqueue_result) {
+      if (enqueue_result && !resp.end_of_sequence()) {
         results_.push(std::move(result));
       }
       get_next_cv_.notify_all();
@@ -291,8 +291,7 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
     cluster = self.create_cluster(num_workers=num_workers)
     # Round robin reads can cause slow cluster shutdown.
     data_service_test_base.GLOBAL_CLUSTERS.add(cluster)
-    num_elements = 100
-    ds = dataset_ops.Dataset.range(num_elements)
+    ds = dataset_ops.Dataset.range(10000000)
     ds = ds.repeat()
     consumers = []
     for consumer_index in range(num_consumers):
@@ -309,18 +308,15 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
         lambda x: x,
         cycle_length=num_consumers,
         num_parallel_calls=num_consumers)
-    ds = ds.take(2 * num_elements * num_workers)
+    ds = ds.take(1000)
     results = self.getDatasetOutput(ds, requires_initialization=True)
 
-    expected = []
-    round_index = 0
-    while len(expected) < len(results):
-      for _ in range(num_workers):
-        for consumer in range(num_consumers):
-          expected.append(
-              (round_index * num_consumers + consumer) % num_elements)
-      round_index += 1
-    self.assertEqual(results, expected)
+    for i in range(0, len(results), num_consumers):
+      self.assertEqual(0, results[i] % num_consumers)
+      # Check that each group of `num_consumers` results are consecutive.
+      for offset in range(1, num_consumers):
+        if i + offset < len(results):
+          self.assertEqual(results[i] + offset, results[i + offset])
 
   @combinations.generate(test_base.default_test_combinations())
   def testRoundRobinBucketizing(self):