[tf.data service] Skip round robin rounds when data isn't ready.
PiperOrigin-RevId: 356348715 Change-Id: I8ac227a098d49bd8a3fd6c96b93ac855df80a121
This commit is contained in:
parent
0092ebe4c7
commit
2cc0ab3c0c
@ -57,6 +57,7 @@ tf_proto_library(
|
||||
],
|
||||
visibility = [
|
||||
":data_transfer_visibility",
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
)
|
||||
|
||||
@ -95,6 +96,7 @@ cc_library(
|
||||
":dispatcher_cc_grpc_proto",
|
||||
":grpc_util",
|
||||
":worker_cc_grpc_proto",
|
||||
":worker_proto_cc",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core/platform:errors",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
@ -415,6 +417,7 @@ cc_library(
|
||||
hdrs = ["task_runner.h"],
|
||||
deps = [
|
||||
":common_proto_cc",
|
||||
":worker_proto_cc",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/data:compression_utils",
|
||||
@ -428,11 +431,14 @@ tf_cc_test(
|
||||
srcs = ["task_runner_test.cc"],
|
||||
deps = [
|
||||
":task_runner",
|
||||
":worker_proto_cc",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"//tensorflow/core/data:compression_utils",
|
||||
"//tensorflow/core/data:dataset_proto_cc",
|
||||
"@com_google_absl//absl/memory",
|
||||
],
|
||||
)
|
||||
|
@ -248,25 +248,14 @@ class GrpcDataTransferClient : public DataTransferClient {
|
||||
stub_ = WorkerService::NewStub(channel);
|
||||
}
|
||||
|
||||
Status GetElement(int64 task_id, absl::optional<int64> consumer_index,
|
||||
absl::optional<int64> round_index,
|
||||
CompressedElement& element,
|
||||
bool& end_of_sequence) override {
|
||||
Status GetElement(const GetElementRequest& req,
|
||||
GetElementResponse& resp) override {
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
if (cancelled_) {
|
||||
return errors::Cancelled("Client was cancelled.");
|
||||
}
|
||||
}
|
||||
GetElementRequest req;
|
||||
req.set_task_id(task_id);
|
||||
if (consumer_index.has_value()) {
|
||||
req.set_consumer_index(consumer_index.value());
|
||||
}
|
||||
if (round_index.has_value()) {
|
||||
req.set_round_index(round_index.value());
|
||||
}
|
||||
GetElementResponse resp;
|
||||
grpc::ClientContext ctx;
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
@ -280,10 +269,6 @@ class GrpcDataTransferClient : public DataTransferClient {
|
||||
if (!s.ok()) {
|
||||
return grpc_util::WrapError("Failed to get element", s);
|
||||
}
|
||||
end_of_sequence = resp.end_of_sequence();
|
||||
if (!end_of_sequence) {
|
||||
element = std::move(*resp.mutable_compressed_element());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -324,14 +309,10 @@ class GrpcTransferClientRegistrar {
|
||||
};
|
||||
static GrpcTransferClientRegistrar registrar;
|
||||
|
||||
Status DataServiceWorkerClient::GetElement(int64 task_id,
|
||||
absl::optional<int64> consumer_index,
|
||||
absl::optional<int64> round_index,
|
||||
CompressedElement& element,
|
||||
bool& end_of_sequence) {
|
||||
Status DataServiceWorkerClient::GetElement(const GetElementRequest& req,
|
||||
GetElementResponse& resp) {
|
||||
TF_RETURN_IF_ERROR(EnsureInitialized());
|
||||
return client_->GetElement(task_id, consumer_index, round_index, element,
|
||||
end_of_sequence);
|
||||
return client_->GetElement(req, resp);
|
||||
}
|
||||
|
||||
Status DataServiceWorkerClient::EnsureInitialized() {
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/data/service/data_transfer.h"
|
||||
#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
|
||||
#include "tensorflow/core/data/service/worker.grpc.pb.h"
|
||||
#include "tensorflow/core/data/service/worker.pb.h"
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
|
||||
@ -144,14 +145,8 @@ class DataServiceWorkerClient : public DataServiceClientBase {
|
||||
: DataServiceClientBase(address, protocol),
|
||||
transfer_protocol_(transfer_protocol) {}
|
||||
|
||||
// Fetches the next element for the specified task_id. The optional
|
||||
// `consumer_index` and `round_index` must be specified for tasks which use
|
||||
// round-robin ordering. The element's compressed tensors will be stored in
|
||||
// `element`. If no element is available, `end_of_sequence` will be `true`,
|
||||
// and `element` will be left unchanged.
|
||||
Status GetElement(int64 task_id, absl::optional<int64> consumer_index,
|
||||
absl::optional<int64> round_index,
|
||||
CompressedElement& element, bool& end_of_sequence);
|
||||
// Fetches an element from the worker.
|
||||
Status GetElement(const GetElementRequest& req, GetElementResponse& resp);
|
||||
|
||||
// Makes a best effort to cancel all outstanding calls in progress for the
|
||||
// client, and causes further calls to return Cancelled status.
|
||||
|
@ -38,13 +38,9 @@ class DataTransferClient {
|
||||
std::function<Status(Config, std::unique_ptr<DataTransferClient>*)>;
|
||||
virtual ~DataTransferClient() = default;
|
||||
|
||||
// Fetches the next element for the specified task_id. The element's
|
||||
// compressed tensors will be stored in `element`. If no element is available,
|
||||
// `end_of_sequence` will be `true`, and `element` will be left unchanged.
|
||||
virtual Status GetElement(int64 task_id, absl::optional<int64> consumer_index,
|
||||
absl::optional<int64> round_index,
|
||||
tensorflow::data::CompressedElement& element,
|
||||
bool& end_of_sequence) = 0;
|
||||
// Fetches the next element.
|
||||
virtual Status GetElement(const GetElementRequest& req,
|
||||
GetElementResponse& resp) = 0;
|
||||
|
||||
// Makes a best effort to cancel all outstanding calls in progress for the
|
||||
// client, and causes further calls to return Cancelled status.
|
||||
|
@ -29,6 +29,43 @@ namespace {
|
||||
// Unavailable error. This prevents the server from hanging on shutdown when
|
||||
// some round-robin consumers exit earlier than others.
|
||||
const int64 kTimeoutUs = 60 * 1000 * 1000; // 1 minute.
|
||||
// Time to wait before skipping a round if data still isn't available.
|
||||
const int64 kWaitBeforeSkipUs = 100 * 1000; // 100ms.
|
||||
|
||||
// Interprets `element` as a size-1 vector containing a CompressedElement, and
|
||||
// moves the element into `resp`. Returns an error if `element` is of unexpected
|
||||
// size, type, or shape.
|
||||
Status MoveCompressedElement(std::vector<Tensor>&& element,
|
||||
GetElementResponse& resp) {
|
||||
if (element.size() != 1) {
|
||||
return errors::FailedPrecondition(
|
||||
"Expected dataset to produce a single scalar variant tensor, but the "
|
||||
"dataset produced ",
|
||||
element.size(), " outputs");
|
||||
}
|
||||
if (element[0].dtype() != DT_VARIANT) {
|
||||
return errors::FailedPrecondition(
|
||||
"Expected dataset to produce a single scalar variant tensor, but "
|
||||
"the dataset produced a tensor with type ",
|
||||
DataTypeString(element[0].dtype()));
|
||||
}
|
||||
if (!TensorShapeUtils::IsScalar(element[0].shape())) {
|
||||
return errors::FailedPrecondition(
|
||||
"Expected dataset to produce a single scalar variant tensor, but "
|
||||
"the dataset produced a tensor with shape ",
|
||||
element[0].shape());
|
||||
}
|
||||
Variant& variant = element[0].scalar<Variant>()();
|
||||
CompressedElement* compressed = variant.get<CompressedElement>();
|
||||
if (compressed == nullptr) {
|
||||
return errors::FailedPrecondition(
|
||||
"Expected dataset to produce a CompressedElement variant tensor, but "
|
||||
"it produced ",
|
||||
variant.TypeName());
|
||||
}
|
||||
*resp.mutable_compressed_element() = *compressed;
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
StandaloneTaskIterator::StandaloneTaskIterator(
|
||||
@ -71,97 +108,211 @@ FirstComeFirstServedTaskRunner::FirstComeFirstServedTaskRunner(
|
||||
std::unique_ptr<TaskIterator> iterator)
|
||||
: iterator_(std::move(iterator)) {}
|
||||
|
||||
Status FirstComeFirstServedTaskRunner::GetNext(const Request& request,
|
||||
std::vector<Tensor>& element,
|
||||
bool& end_of_task) {
|
||||
return iterator_->GetNext(element, end_of_task);
|
||||
Status FirstComeFirstServedTaskRunner::GetNext(const GetElementRequest& req,
|
||||
GetElementResponse& resp) {
|
||||
std::vector<Tensor> element;
|
||||
bool end_of_task;
|
||||
resp.set_skip_task(false);
|
||||
TF_RETURN_IF_ERROR(iterator_->GetNext(element, end_of_task));
|
||||
resp.set_end_of_sequence(end_of_task);
|
||||
if (!end_of_task) {
|
||||
return MoveCompressedElement(std::move(element), resp);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
RoundRobinTaskRunner::RoundRobinTaskRunner(
|
||||
std::unique_ptr<TaskIterator> iterator, int64 num_consumers)
|
||||
: num_consumers_(num_consumers),
|
||||
iterator_(std::move(iterator)),
|
||||
buffer_(num_consumers_) {
|
||||
buffer_(num_consumers_),
|
||||
prefetch_thread_(std::move(iterator), num_consumers_) {
|
||||
VLOG(1) << "Creating task runner for distributing data round-robin to "
|
||||
<< num_consumers << " consumers";
|
||||
}
|
||||
|
||||
Status RoundRobinTaskRunner::GetNext(const Request& request,
|
||||
std::vector<Tensor>& element,
|
||||
bool& end_of_task) {
|
||||
if (request.consumer_index < 0 || request.round_index < 0) {
|
||||
Status RoundRobinTaskRunner::ValidateRequest(const GetElementRequest& req) {
|
||||
if (req.consumer_index() < 0 || req.round_index() < 0) {
|
||||
return errors::FailedPrecondition(
|
||||
"RoundRobinTaskRunner needs to know the consumer index and element "
|
||||
"index of each request.");
|
||||
}
|
||||
if (request.consumer_index >= num_consumers_) {
|
||||
if (req.consumer_index() >= num_consumers_) {
|
||||
return errors::FailedPrecondition(
|
||||
"Requesting data for consumer index ", request.consumer_index,
|
||||
"Requesting data for consumer index ", req.consumer_index(),
|
||||
", but the task is configured for only ", num_consumers_, " consumers");
|
||||
}
|
||||
VLOG(2) << "Received request from consumer index " << request.consumer_index
|
||||
<< " for round " << request.round_index;
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
absl::flat_hash_set<int64>& round = requests_[request.round_index];
|
||||
first_round_ = std::min(first_round_, request.round_index);
|
||||
round.insert(request.consumer_index);
|
||||
if (current_round_ < request.round_index &&
|
||||
round.size() == num_consumers_) {
|
||||
VLOG(1) << "Starting normal round with round index "
|
||||
<< request.round_index;
|
||||
// This was the last request to arrive, time to start a new round.
|
||||
TF_RETURN_IF_ERROR(FillBuffer());
|
||||
VLOG(1) << "Finished preparing data for round " << request.round_index;
|
||||
current_round_ = request.round_index;
|
||||
new_round_cv_.notify_all();
|
||||
}
|
||||
if (current_round_ < 0 &&
|
||||
requests_[first_round_].size() + requests_[first_round_ + 1].size() ==
|
||||
num_consumers_) {
|
||||
VLOG(1) << "Starting partial round for " << requests_[first_round_].size()
|
||||
<< " consumers";
|
||||
// Indicates that we need a partial round to get consumers back in sync.
|
||||
TF_RETURN_IF_ERROR(FillBuffer());
|
||||
current_round_ = first_round_;
|
||||
new_round_cv_.notify_all();
|
||||
}
|
||||
while (current_round_ < request.round_index) {
|
||||
std::cv_status s =
|
||||
new_round_cv_.wait_for(l, std::chrono::microseconds(kTimeoutUs));
|
||||
if (s == std::cv_status::timeout) {
|
||||
// Clients will retry Unavailable.
|
||||
return errors::Unavailable(
|
||||
"Timeout waiting for other round-robin consumers to be ready.");
|
||||
}
|
||||
}
|
||||
end_of_task = end_of_task_;
|
||||
}
|
||||
if (!end_of_task) {
|
||||
element.clear();
|
||||
tf_shared_lock l(mu_);
|
||||
for (auto& component : buffer_[request.consumer_index]) {
|
||||
element.push_back(tensor::DeepCopy(component));
|
||||
}
|
||||
}
|
||||
VLOG(2) << "Returning to consumer " << request.consumer_index << " for round "
|
||||
<< request.round_index;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RoundRobinTaskRunner::FillBuffer() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
for (int i = 0; i < num_consumers_; ++i) {
|
||||
buffer_[i].clear();
|
||||
bool end_of_sequence;
|
||||
TF_RETURN_IF_ERROR(iterator_->GetNext(buffer_[i], end_of_sequence));
|
||||
if (end_of_sequence) {
|
||||
return errors::FailedPrecondition(
|
||||
"Encountered end of sequence on a round-robin read iterator. Please "
|
||||
"ensure that the dataset used for round-robin reading has infinite "
|
||||
"cardinality, e.g. by adding a .repeat() transformation at the end.");
|
||||
Status RoundRobinTaskRunner::PrepareFullRound(int64 wait_us)
|
||||
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
VLOG(1) << "Preparing full round for index " << current_round_;
|
||||
// This was the last request to arrive, time to start a new round.
|
||||
TF_RETURN_IF_ERROR(prefetch_thread_.FillBuffer(wait_us, buffer_));
|
||||
round_skipped_ = buffer_.empty();
|
||||
new_round_cv_.notify_all();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RoundRobinTaskRunner::PreparePartialRound()
|
||||
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
VLOG(1) << "Starting partial round for " << requests_[first_round_].size()
|
||||
<< " consumers";
|
||||
current_round_ = first_round_;
|
||||
new_round_cv_.notify_all();
|
||||
// Indicates that we need a partial round to get consumers back in sync.
|
||||
auto next_round_request = *(requests_[first_round_ + 1].begin());
|
||||
if (next_round_request->skipped_previous_round()) {
|
||||
VLOG(1) << "Skipping partial round";
|
||||
round_skipped_ = true;
|
||||
return Status::OK();
|
||||
}
|
||||
TF_RETURN_IF_ERROR(prefetch_thread_.FillBuffer(/*wait_us=*/-1, buffer_));
|
||||
round_skipped_ = false;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RoundRobinTaskRunner::PrepareRound(const GetElementRequest& req) {
|
||||
mutex_lock l(mu_);
|
||||
absl::flat_hash_set<const GetElementRequest*>& round =
|
||||
requests_[req.round_index()];
|
||||
first_round_ = std::min(first_round_, req.round_index());
|
||||
round.insert(&req);
|
||||
if (current_round_ < req.round_index() && round.size() == num_consumers_) {
|
||||
current_round_ = req.round_index();
|
||||
int64 wait_us = kWaitBeforeSkipUs;
|
||||
if (!req.allow_skip()) {
|
||||
wait_us = -1;
|
||||
}
|
||||
TF_RETURN_IF_ERROR(PrepareFullRound(wait_us));
|
||||
}
|
||||
if (current_round_ < 0 &&
|
||||
requests_[first_round_].size() + requests_[first_round_ + 1].size() ==
|
||||
num_consumers_) {
|
||||
TF_RETURN_IF_ERROR(PreparePartialRound());
|
||||
}
|
||||
while (current_round_ < req.round_index()) {
|
||||
TF_RETURN_IF_ERROR(prefetch_thread_.GetStatus());
|
||||
std::cv_status s =
|
||||
new_round_cv_.wait_for(l, std::chrono::microseconds(kTimeoutUs));
|
||||
if (s == std::cv_status::timeout) {
|
||||
// Clients will retry Unavailable.
|
||||
return errors::Unavailable(
|
||||
"Timeout waiting for other round-robin consumers to be ready.");
|
||||
}
|
||||
}
|
||||
return prefetch_thread_.GetStatus();
|
||||
}
|
||||
|
||||
Status RoundRobinTaskRunner::GetNext(const GetElementRequest& req,
|
||||
GetElementResponse& resp) {
|
||||
TF_RETURN_IF_ERROR(ValidateRequest(req));
|
||||
resp.set_end_of_sequence(false);
|
||||
VLOG(2) << "Received request from consumer index " << req.consumer_index()
|
||||
<< " for round " << req.round_index();
|
||||
TF_RETURN_IF_ERROR(PrepareRound(req));
|
||||
tf_shared_lock l(mu_);
|
||||
resp.set_skip_task(round_skipped_);
|
||||
if (round_skipped_) {
|
||||
VLOG(1) << "Buffer not ready, skipping round " << current_round_
|
||||
<< " for consumer " << req.consumer_index();
|
||||
return Status::OK();
|
||||
}
|
||||
std::vector<Tensor> element;
|
||||
for (auto& component : buffer_[req.consumer_index()]) {
|
||||
element.push_back(tensor::DeepCopy(component));
|
||||
}
|
||||
if (VLOG_IS_ON(2)) {
|
||||
int64 size = 0;
|
||||
for (auto& component : element) {
|
||||
size += component.TotalBytes();
|
||||
}
|
||||
VLOG(2) << "Returning to consumer " << req.consumer_index() << " for round "
|
||||
<< req.round_index() << ". element size " << size;
|
||||
}
|
||||
return MoveCompressedElement(std::move(element), resp);
|
||||
}
|
||||
|
||||
PrefetchThread::PrefetchThread(std::unique_ptr<TaskIterator> iterator,
|
||||
int64 round_size)
|
||||
: iterator_(std::move(iterator)), round_size_(round_size) {
|
||||
thread_ = absl::WrapUnique(
|
||||
Env::Default()->StartThread({}, "round-robin-prefetch", [&] { Run(); }));
|
||||
}
|
||||
|
||||
PrefetchThread::~PrefetchThread() {
|
||||
mutex_lock l(mu_);
|
||||
cancelled_ = true;
|
||||
cv_.notify_all();
|
||||
}
|
||||
|
||||
void PrefetchThread::Run() {
|
||||
while (true) {
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
while (!cancelled_ && buffer_.size() >= round_size_) {
|
||||
cv_.wait(l);
|
||||
}
|
||||
if (cancelled_) {
|
||||
return;
|
||||
}
|
||||
}
|
||||
std::vector<Tensor> element;
|
||||
bool end_of_sequence;
|
||||
Status s = iterator_->GetNext(element, end_of_sequence);
|
||||
if (!s.ok()) {
|
||||
mutex_lock l(mu_);
|
||||
status_ = s;
|
||||
cv_.notify_all();
|
||||
return;
|
||||
}
|
||||
if (end_of_sequence) {
|
||||
mutex_lock l(mu_);
|
||||
status_ = errors::FailedPrecondition(
|
||||
"Encountered end of sequence on a round-robin read iterator. "
|
||||
"Please ensure that the dataset used for round-robin reading has "
|
||||
"infinite cardinality, e.g. by adding a .repeat() transformation "
|
||||
"at the end.");
|
||||
cv_.notify_all();
|
||||
return;
|
||||
}
|
||||
mutex_lock l(mu_);
|
||||
buffer_.push_back(std::move(element));
|
||||
cv_.notify_all();
|
||||
}
|
||||
}
|
||||
|
||||
Status PrefetchThread::FillBuffer(int64 wait_us,
|
||||
std::vector<std::vector<Tensor>>& out) {
|
||||
int64 start_us = Env::Default()->NowMicros();
|
||||
out.clear();
|
||||
mutex_lock l(mu_);
|
||||
while (buffer_.size() < round_size_ && !cancelled_ && status_.ok()) {
|
||||
int64 remaining_us = start_us + wait_us - Env::Default()->NowMicros();
|
||||
if (wait_us >= 0 && remaining_us <= 0) {
|
||||
break;
|
||||
}
|
||||
cv_.wait_for(l, std::chrono::microseconds(remaining_us));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(status_);
|
||||
if (cancelled_) {
|
||||
return errors::Cancelled("Prefetch thread cancelled");
|
||||
}
|
||||
if (buffer_.size() < round_size_) {
|
||||
DCHECK_GE(wait_us, 0);
|
||||
return Status::OK();
|
||||
}
|
||||
for (auto& elem : buffer_) {
|
||||
out.push_back(std::move(elem));
|
||||
}
|
||||
buffer_.clear();
|
||||
cv_.notify_all();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status PrefetchThread::GetStatus() {
|
||||
mutex_lock l(mu_);
|
||||
return status_;
|
||||
}
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_CORE_DATA_SERVICE_TASK_RUNNER_H_
|
||||
|
||||
#include "tensorflow/core/data/service/common.pb.h"
|
||||
#include "tensorflow/core/data/service/worker.pb.h"
|
||||
#include "tensorflow/core/data/standalone.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
|
||||
@ -54,28 +55,14 @@ class StandaloneTaskIterator : public TaskIterator {
|
||||
// Interface for providing elements to task consumers.
|
||||
class TaskRunner {
|
||||
public:
|
||||
struct Request {
|
||||
// Optional consumer index indicating which consumer is making the request.
|
||||
// Only needed for round-robin reads.
|
||||
int64 consumer_index = -1;
|
||||
// Optional round index indicating which round the consumer wants to read
|
||||
// from. Consumers are expected to read from consecutive rounds, starting
|
||||
// with round 0. The task runner will attempt to serve all consumer
|
||||
// requests for a round from the same block of `num_consumers` iterator
|
||||
// indices, where block `n` is defined as elements `n*num_consumers` to
|
||||
// `(n+1)*num_consumers`.
|
||||
int64 round_index = -1;
|
||||
};
|
||||
|
||||
// Creates a `TaskRunner` and stores it in `out`.
|
||||
static Status Create(const TaskDef& task_def,
|
||||
std::unique_ptr<TaskIterator> iterator,
|
||||
std::unique_ptr<TaskRunner>& out);
|
||||
virtual ~TaskRunner() = default;
|
||||
// Gets the next element for the given request, storing the results in
|
||||
// `element` and `end_of_task`.
|
||||
virtual Status GetNext(const Request& request, std::vector<Tensor>& element,
|
||||
bool& end_of_task) = 0;
|
||||
// Gets the next element for the given request.
|
||||
virtual Status GetNext(const GetElementRequest& req,
|
||||
GetElementResponse& resp) = 0;
|
||||
};
|
||||
|
||||
// A task runner which provides elements on a first-come first-served basis.
|
||||
@ -84,20 +71,52 @@ class FirstComeFirstServedTaskRunner : public TaskRunner {
|
||||
public:
|
||||
explicit FirstComeFirstServedTaskRunner(
|
||||
std::unique_ptr<TaskIterator> iterator);
|
||||
Status GetNext(const Request& request, std::vector<Tensor>& element,
|
||||
bool& end_of_task) override;
|
||||
Status GetNext(const GetElementRequest& req,
|
||||
GetElementResponse& resp) override;
|
||||
|
||||
private:
|
||||
std::unique_ptr<TaskIterator> iterator_;
|
||||
};
|
||||
|
||||
// Thread for prefetching a round worth of elements.
|
||||
class PrefetchThread {
|
||||
public:
|
||||
explicit PrefetchThread(std::unique_ptr<TaskIterator> iterator,
|
||||
int64 round_size);
|
||||
~PrefetchThread();
|
||||
// Runs the prefetch thread. It runs until an error is encountered or the
|
||||
// destructor is called.
|
||||
void Run();
|
||||
// Fills `out` with a round of data. Waits for up to `wait_us` micoseconds
|
||||
// before giving up and returning with `out` empty. A negative `wait_us`
|
||||
// signals to wait indefinitely.
|
||||
Status FillBuffer(int64 wait_us, std::vector<std::vector<Tensor>>& out);
|
||||
// Returns the status for any failures encountered by the prefetch thread.
|
||||
Status GetStatus();
|
||||
|
||||
private:
|
||||
const std::unique_ptr<TaskIterator> iterator_;
|
||||
const int64 round_size_;
|
||||
mutex mu_;
|
||||
// Buffered results for the next round.
|
||||
std::vector<std::vector<Tensor>> buffer_ TF_GUARDED_BY(mu_);
|
||||
// The status if the prefetch thread fails.
|
||||
Status status_ TF_GUARDED_BY(mu_) = Status::OK();
|
||||
// Thread which constantly tries to fill `buffer_` up with
|
||||
// `num_consumers` elements.
|
||||
std::unique_ptr<Thread> thread_;
|
||||
// Condition variable notified when elements are added to or removed from
|
||||
// `buffer_`, or when `status_` is changed.
|
||||
condition_variable cv_;
|
||||
bool cancelled_ TF_GUARDED_BY(mu_) = false;
|
||||
};
|
||||
|
||||
// A task runner which enforces round-robin order for consuming a task's
|
||||
// elements. Requests must provide a consumer index and element index.
|
||||
// `RoundRobinTaskRunner` provides elements in a series of "rounds". In each
|
||||
// successive round, the runner waits to receive requests from all consumers.
|
||||
// These requests are blocked until all requests arrive. Once all requests
|
||||
// arrive, the runner hands out elements to consumers in order of their consumer
|
||||
// indices.
|
||||
// elements. `RoundRobinTaskRunner` provides elements in a series of "rounds".
|
||||
// In each successive round, the runner waits to receive requests from all
|
||||
// consumers. These requests are blocked until all requests arrive. Once all
|
||||
// requests arrive, the runner hands out elements to consumers in order of their
|
||||
// consumer indices.
|
||||
//
|
||||
// Consumers are expected to successively request consecutive element indices,
|
||||
// starting at 0. The same element can be requested multiple times by the same
|
||||
@ -113,28 +132,37 @@ class RoundRobinTaskRunner : public TaskRunner {
|
||||
public:
|
||||
RoundRobinTaskRunner(std::unique_ptr<TaskIterator> iterator,
|
||||
int64 num_consumers);
|
||||
Status GetNext(const Request& request, std::vector<Tensor>& element,
|
||||
bool& end_of_task) override;
|
||||
|
||||
Status GetNext(const GetElementRequest& req,
|
||||
GetElementResponse& resp) override;
|
||||
|
||||
private:
|
||||
// Fills `buffer_` with `num_consumers_` elements.
|
||||
Status FillBuffer();
|
||||
|
||||
// Prepares a full round of data. `wait_us` indicates how long to wait before
|
||||
// skipping if a full round of data is not yet ready.
|
||||
Status PrepareFullRound(int64 wait_us) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
// Prepares a partial round to get consumers back in sync.
|
||||
Status PreparePartialRound() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
|
||||
Status ValidateRequest(const GetElementRequest& req);
|
||||
// Prepares data for the next round, blocking until the round is ready to
|
||||
// start.
|
||||
Status PrepareRound(const GetElementRequest& req);
|
||||
const int64 num_consumers_;
|
||||
std::unique_ptr<TaskIterator> iterator_;
|
||||
mutex mu_;
|
||||
// Condition variable notified whenever we start a new round of round-robin.
|
||||
condition_variable new_round_cv_;
|
||||
// Map from round number to consumers waiting for data from that round.
|
||||
absl::flat_hash_map<int64, absl::flat_hash_set<int64>> requests_
|
||||
TF_GUARDED_BY(mu_);
|
||||
// Map from round number to requests waiting for data from that round.
|
||||
absl::flat_hash_map<int64, absl::flat_hash_set<const GetElementRequest*>>
|
||||
requests_ TF_GUARDED_BY(mu_);
|
||||
// Index of the first round we plan to serve. At startup, this is the minimum
|
||||
// of all requested element indices.
|
||||
int64 first_round_ TF_GUARDED_BY(mu_) = kint64max;
|
||||
int64 current_round_ TF_GUARDED_BY(mu_) = -1;
|
||||
bool round_skipped_ TF_GUARDED_BY(mu_) = false;
|
||||
// Buffered results for the current round.
|
||||
std::vector<std::vector<Tensor>> buffer_ TF_GUARDED_BY(mu_);
|
||||
bool end_of_task_ TF_GUARDED_BY(mu_) = false;
|
||||
// Thread which constantly tries to prepare `num_consumers` elements for the
|
||||
// next round.
|
||||
PrefetchThread prefetch_thread_;
|
||||
};
|
||||
|
||||
} // namespace data
|
||||
|
@ -13,6 +13,9 @@ limitations under the License.
|
||||
#include "tensorflow/core/data/service/task_runner.h"
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/core/data/compression_utils.h"
|
||||
#include "tensorflow/core/data/dataset.pb.h"
|
||||
#include "tensorflow/core/data/service/worker.pb.h"
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
@ -31,7 +34,10 @@ class TestTaskIterator : public TaskIterator {
|
||||
Status GetNext(std::vector<Tensor>& element, bool& end_of_sequence) override {
|
||||
end_of_sequence = index_ >= elements_.size();
|
||||
if (!end_of_sequence) {
|
||||
element = elements_[index_];
|
||||
CompressedElement compressed;
|
||||
TF_RETURN_IF_ERROR(CompressElement(elements_[index_], &compressed));
|
||||
element.emplace_back(DT_VARIANT, TensorShape({}));
|
||||
element[0].scalar<Variant>()() = std::move(compressed);
|
||||
index_ = (index_ + 1) % elements_.size();
|
||||
}
|
||||
return Status::OK();
|
||||
@ -47,16 +53,22 @@ class TestTaskIterator : public TaskIterator {
|
||||
// Reads from the task runner, storing results in `*output`.
|
||||
Status RunConsumer(int64 consumer_index, int64 start_index, int64 end_index,
|
||||
TaskRunner& task_runner, std::vector<int64>& output) {
|
||||
bool end_of_sequence = false;
|
||||
for (int64 next_index = start_index; next_index < end_index; ++next_index) {
|
||||
TaskRunner::Request request;
|
||||
request.round_index = next_index;
|
||||
request.consumer_index = consumer_index;
|
||||
std::vector<Tensor> element;
|
||||
TF_RETURN_IF_ERROR(task_runner.GetNext(request, element, end_of_sequence));
|
||||
if (!end_of_sequence) {
|
||||
output.push_back(element[0].flat<int64>()(0));
|
||||
}
|
||||
GetElementRequest request;
|
||||
request.set_round_index(next_index);
|
||||
request.set_consumer_index(consumer_index);
|
||||
request.set_skipped_previous_round(false);
|
||||
request.set_allow_skip(false);
|
||||
GetElementResponse response;
|
||||
do {
|
||||
TF_RETURN_IF_ERROR(task_runner.GetNext(request, response));
|
||||
if (!response.end_of_sequence()) {
|
||||
std::vector<Tensor> uncompressed;
|
||||
TF_RETURN_IF_ERROR(
|
||||
UncompressElement(response.compressed_element(), &uncompressed));
|
||||
output.push_back(uncompressed[0].flat<int64>()(0));
|
||||
}
|
||||
} while (response.skip_task());
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
@ -71,12 +83,13 @@ TEST(FirstComeFirstServedTaskRunner, GetNext) {
|
||||
}
|
||||
FirstComeFirstServedTaskRunner runner(
|
||||
absl::make_unique<TestTaskIterator>(elements));
|
||||
TaskRunner::Request request;
|
||||
GetElementRequest request;
|
||||
GetElementResponse response;
|
||||
for (auto& expected_element : elements) {
|
||||
TF_ASSERT_OK(runner.GetNext(request, response));
|
||||
ASSERT_FALSE(response.end_of_sequence());
|
||||
std::vector<Tensor> element;
|
||||
bool end_of_sequence;
|
||||
TF_ASSERT_OK(runner.GetNext(request, element, end_of_sequence));
|
||||
ASSERT_FALSE(end_of_sequence);
|
||||
TF_ASSERT_OK(UncompressElement(response.compressed_element(), &element));
|
||||
ASSERT_EQ(element.size(), 1);
|
||||
test::ExpectEqual(element[0], expected_element[0]);
|
||||
}
|
||||
|
@ -23,6 +23,11 @@ message GetElementRequest {
|
||||
oneof optional_round_index {
|
||||
int64 round_index = 3;
|
||||
}
|
||||
// Whether the previous round was skipped. This information is needed by the
|
||||
// worker to recover after restarts.
|
||||
bool skipped_previous_round = 4;
|
||||
// Whether to skip the round if data isn't ready fast enough.
|
||||
bool allow_skip = 5;
|
||||
}
|
||||
|
||||
message GetElementResponse {
|
||||
@ -30,6 +35,8 @@ message GetElementResponse {
|
||||
CompressedElement compressed_element = 3;
|
||||
// Boolean to indicate whether the iterator has been exhausted.
|
||||
bool end_of_sequence = 2;
|
||||
// Indicates whether the round was skipped.
|
||||
bool skip_task = 4;
|
||||
}
|
||||
|
||||
// Named GetWorkerTasks to avoid conflicting with GetTasks in dispatcher.proto
|
||||
|
@ -180,8 +180,6 @@ Status DataServiceWorkerImpl::EnsureTaskInitialized(
|
||||
Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request,
|
||||
GetElementResponse* response) {
|
||||
VLOG(3) << "Received GetElement request for task " << request->task_id();
|
||||
bool end_of_sequence = false;
|
||||
std::vector<tensorflow::Tensor> outputs;
|
||||
Task* task;
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
@ -207,54 +205,15 @@ Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request,
|
||||
task = it->second.get();
|
||||
TF_RETURN_IF_ERROR(EnsureTaskInitialized(*task));
|
||||
}
|
||||
TaskRunner::Request get_next_request;
|
||||
if (request->optional_consumer_index_case() ==
|
||||
GetElementRequest::kConsumerIndex) {
|
||||
get_next_request.consumer_index = request->consumer_index();
|
||||
}
|
||||
if (request->optional_round_index_case() == GetElementRequest::kRoundIndex) {
|
||||
get_next_request.round_index = request->round_index();
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
task->task_runner->GetNext(get_next_request, outputs, end_of_sequence));
|
||||
if (end_of_sequence) {
|
||||
TF_RETURN_IF_ERROR(task->task_runner->GetNext(*request, *response));
|
||||
if (response->end_of_sequence()) {
|
||||
mutex_lock l(mu_);
|
||||
VLOG(3) << "Reached end_of_sequence for task " << request->task_id();
|
||||
pending_completed_tasks_.insert(request->task_id());
|
||||
task_completion_cv_.notify_one();
|
||||
}
|
||||
|
||||
if (!end_of_sequence) {
|
||||
} else if (!response->skip_task()) {
|
||||
VLOG(3) << "Producing an element for task " << request->task_id();
|
||||
if (outputs.size() != 1) {
|
||||
return errors::FailedPrecondition(
|
||||
"Expected dataset to produce a single scalar variant tensor, but the "
|
||||
"dataset produced ",
|
||||
outputs.size(), " outputs");
|
||||
}
|
||||
if (outputs[0].dtype() != DT_VARIANT) {
|
||||
return errors::FailedPrecondition(
|
||||
"Expected dataset to produce a single scalar variant tensor, but "
|
||||
"the dataset produced a tensor with type ",
|
||||
DataTypeString(outputs[0].dtype()));
|
||||
}
|
||||
if (!TensorShapeUtils::IsScalar(outputs[0].shape())) {
|
||||
return errors::FailedPrecondition(
|
||||
"Expected dataset to produce a single scalar variant tensor, but "
|
||||
"the dataset produced a tensor with shape ",
|
||||
outputs[0].shape());
|
||||
}
|
||||
Variant& variant = outputs[0].scalar<Variant>()();
|
||||
CompressedElement* compressed = variant.get<CompressedElement>();
|
||||
if (compressed == nullptr) {
|
||||
return errors::FailedPrecondition(
|
||||
"Expected dataset to produce a CompressedElement variant tensor, but "
|
||||
"it produced ",
|
||||
variant.TypeName());
|
||||
}
|
||||
*response->mutable_compressed_element() = *compressed;
|
||||
}
|
||||
response->set_end_of_sequence(end_of_sequence);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -168,6 +168,7 @@ tf_kernel_library(
|
||||
"//tensorflow/core/data/service:data_service",
|
||||
"//tensorflow/core/data/service:dispatcher_proto_cc",
|
||||
"//tensorflow/core/data/service:grpc_util",
|
||||
"//tensorflow/core/data/service:worker_proto_cc",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
|
||||
"//tensorflow/core/kernels/data:dataset_utils",
|
||||
"//tensorflow/core/kernels/data:name_utils",
|
||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/data/service/data_service.h"
|
||||
#include "tensorflow/core/data/service/dispatcher.pb.h"
|
||||
#include "tensorflow/core/data/service/grpc_util.h"
|
||||
#include "tensorflow/core/data/service/worker.pb.h"
|
||||
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
#include "tensorflow/core/framework/model.h"
|
||||
@ -295,29 +296,37 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
|
||||
});
|
||||
}
|
||||
|
||||
while ((results_.empty() || !results_.front().ready) &&
|
||||
!(job_finished_ && num_running_worker_threads_ == 0) &&
|
||||
!cancelled_ && status_.ok()) {
|
||||
VLOG(3) << "Blocking in GetNext. results_.size():" << results_.size()
|
||||
<< " results_.front().ready:"
|
||||
<< (!results_.empty() && results_.front().ready)
|
||||
<< " job_finished_:" << job_finished_
|
||||
<< " num_running_worker_threads_:"
|
||||
<< num_running_worker_threads_;
|
||||
get_next_cv_.wait(l);
|
||||
}
|
||||
if (cancelled_) {
|
||||
VLOG(3) << "Returning from GetNext due to cancellation";
|
||||
return errors::Cancelled("Data service iterator was cancelled");
|
||||
}
|
||||
if (!status_.ok()) {
|
||||
VLOG(3) << "Returning from GetNext with error " << status_;
|
||||
return status_;
|
||||
}
|
||||
if (results_.empty()) {
|
||||
*end_of_sequence = true;
|
||||
VLOG(3) << "Returning from GetNext with end_of_sequence";
|
||||
return Status::OK();
|
||||
bool skip = true;
|
||||
while (skip) {
|
||||
while ((results_.empty() || !results_.front().ready) &&
|
||||
!(job_finished_ && num_running_worker_threads_ == 0) &&
|
||||
!cancelled_ && status_.ok()) {
|
||||
VLOG(3) << "Blocking in GetNext. results_.size():" << results_.size()
|
||||
<< " results_.front().ready:"
|
||||
<< (!results_.empty() && results_.front().ready)
|
||||
<< " job_finished_:" << job_finished_
|
||||
<< " num_running_worker_threads_:"
|
||||
<< num_running_worker_threads_;
|
||||
get_next_cv_.wait(l);
|
||||
}
|
||||
if (cancelled_) {
|
||||
VLOG(3) << "Returning from GetNext due to cancellation";
|
||||
return errors::Cancelled("Data service iterator was cancelled");
|
||||
}
|
||||
if (!status_.ok()) {
|
||||
VLOG(3) << "Returning from GetNext with error " << status_;
|
||||
return status_;
|
||||
}
|
||||
if (results_.empty()) {
|
||||
*end_of_sequence = true;
|
||||
VLOG(3) << "Returning from GetNext with end_of_sequence";
|
||||
return Status::OK();
|
||||
}
|
||||
skip = results_.front().skip;
|
||||
if (skip) {
|
||||
results_.pop();
|
||||
worker_thread_cv_.notify_one();
|
||||
}
|
||||
}
|
||||
*end_of_sequence = results_.front().end_of_sequence;
|
||||
if (!*end_of_sequence) {
|
||||
@ -378,6 +387,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
|
||||
const std::unique_ptr<DataServiceWorkerClient> worker;
|
||||
// The next round to read from the task.
|
||||
int64 round = 0;
|
||||
bool skipped_previous_round = false;
|
||||
// Indicates whether a worker thread is currently processing the task.
|
||||
bool in_use TF_GUARDED_BY(&Iterator::mu_) = false;
|
||||
// Indicates whether the worker has returned end_of_sequence for the task.
|
||||
@ -390,6 +400,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
|
||||
bool ready TF_GUARDED_BY(&Iterator::mu_) = false;
|
||||
std::vector<Tensor> element TF_GUARDED_BY(&Iterator::mu_);
|
||||
bool end_of_sequence TF_GUARDED_BY(&Iterator::mu_) = false;
|
||||
bool skip TF_GUARDED_BY(&Iterator::mu_) = false;
|
||||
};
|
||||
|
||||
// Periodically refresh the task list.
|
||||
@ -677,25 +688,24 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
|
||||
return profiler::TraceMeEncode(
|
||||
{{"address", task->info.worker_address()}});
|
||||
});
|
||||
CompressedElement compressed;
|
||||
bool end_of_sequence;
|
||||
GetElementResponse resp;
|
||||
for (int num_retries = 0;; ++num_retries) {
|
||||
absl::optional<int64> consumer_index = dataset()->consumer_index_;
|
||||
absl::optional<int64> round_index;
|
||||
GetElementRequest req;
|
||||
if (StrictRoundRobin()) {
|
||||
round_index = task->round;
|
||||
req.set_consumer_index(dataset()->consumer_index_.value());
|
||||
req.set_round_index(task->round);
|
||||
req.set_allow_skip(true);
|
||||
VLOG(3) << "Requesting element from consumer index "
|
||||
<< consumer_index.value() << ", round "
|
||||
<< round_index.value();
|
||||
<< req.consumer_index() << ", round " << req.round_index();
|
||||
activity.AppendMetadata([&]() {
|
||||
return profiler::TraceMeEncode(
|
||||
{{"consumer_index", consumer_index.value()},
|
||||
{"round_index", round_index.value()}});
|
||||
{{"consumer_index", req.consumer_index()},
|
||||
{"round_index", req.round_index()}});
|
||||
});
|
||||
}
|
||||
Status s =
|
||||
task->worker->GetElement(task->info.task_id(), consumer_index,
|
||||
round_index, compressed, end_of_sequence);
|
||||
req.set_task_id(task->info.task_id());
|
||||
req.set_skipped_previous_round(task->skipped_previous_round);
|
||||
Status s = task->worker->GetElement(req, resp);
|
||||
if (s.ok()) {
|
||||
break;
|
||||
}
|
||||
@ -709,7 +719,6 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
|
||||
// If `UpdateTaskThreads` finds that the task has been cancelled, it
|
||||
// will set end_of_sequence to `true`.
|
||||
if (task->end_of_sequence || cancelled_) {
|
||||
end_of_sequence = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
@ -734,21 +743,27 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
|
||||
}
|
||||
|
||||
std::vector<Tensor> element;
|
||||
if (!end_of_sequence) {
|
||||
if (resp.has_compressed_element()) {
|
||||
Tensor tensor(DT_VARIANT, TensorShape{});
|
||||
tensor.scalar<Variant>()() = std::move(compressed);
|
||||
tensor.scalar<Variant>()() = std::move(resp.compressed_element());
|
||||
element.push_back(tensor);
|
||||
}
|
||||
mutex_lock l(mu_);
|
||||
result.ready = true;
|
||||
result.end_of_sequence = end_of_sequence;
|
||||
if (end_of_sequence) {
|
||||
result.end_of_sequence = resp.end_of_sequence();
|
||||
if (resp.has_compressed_element()) {
|
||||
task->skipped_previous_round = false;
|
||||
task->round++;
|
||||
result.element = std::move(element);
|
||||
} else if (resp.skip_task()) {
|
||||
task->skipped_previous_round = true;
|
||||
task->round++;
|
||||
result.skip = true;
|
||||
} else {
|
||||
task->end_of_sequence = true;
|
||||
finished_tasks_++;
|
||||
return Status::OK();
|
||||
}
|
||||
result.element = std::move(element);
|
||||
if (enqueue_result) {
|
||||
if (enqueue_result && !resp.end_of_sequence()) {
|
||||
results_.push(std::move(result));
|
||||
}
|
||||
get_next_cv_.notify_all();
|
||||
|
@ -291,8 +291,7 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
|
||||
cluster = self.create_cluster(num_workers=num_workers)
|
||||
# Round robin reads can cause slow cluster shutdown.
|
||||
data_service_test_base.GLOBAL_CLUSTERS.add(cluster)
|
||||
num_elements = 100
|
||||
ds = dataset_ops.Dataset.range(num_elements)
|
||||
ds = dataset_ops.Dataset.range(10000000)
|
||||
ds = ds.repeat()
|
||||
consumers = []
|
||||
for consumer_index in range(num_consumers):
|
||||
@ -309,18 +308,15 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
|
||||
lambda x: x,
|
||||
cycle_length=num_consumers,
|
||||
num_parallel_calls=num_consumers)
|
||||
ds = ds.take(2 * num_elements * num_workers)
|
||||
ds = ds.take(1000)
|
||||
results = self.getDatasetOutput(ds, requires_initialization=True)
|
||||
|
||||
expected = []
|
||||
round_index = 0
|
||||
while len(expected) < len(results):
|
||||
for _ in range(num_workers):
|
||||
for consumer in range(num_consumers):
|
||||
expected.append(
|
||||
(round_index * num_consumers + consumer) % num_elements)
|
||||
round_index += 1
|
||||
self.assertEqual(results, expected)
|
||||
for i in range(0, len(results), num_consumers):
|
||||
self.assertEqual(0, results[i] % num_consumers)
|
||||
# Check that each group of `num_consumers` results are consecutive.
|
||||
for offset in range(1, num_consumers):
|
||||
if i + offset < len(results):
|
||||
self.assertEqual(results[i] + offset, results[i + offset])
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testRoundRobinBucketizing(self):
|
||||
|
Loading…
Reference in New Issue
Block a user