[tf.data service] Implement round-robin reads.
This enables a new mode of reading from the tf.data service, where consumers read from tasks in a coordinated fashion, instead of the normal first-come first-served. The main use case for this is coordinated bucketization for synchronous training, where we want to ensure that at each step consumers get batches with elements of similar sizes. This mitigates the inefficiency of some consumers slowly training on large examples while others quickly train on small examples, then block waiting for the slower examples to be processed. When `consumer_index` and `num_consumers` are specified to `distribute`, each task will enforce a strict round-robin order, where its first element goes to consumer 0, second element to consumer 1, and so on. This requires that all consumers consume the same number of elements. PiperOrigin-RevId: 351625063 Change-Id: I9b400f55ad61406cb125af8225096e7ff5dc4b0c
This commit is contained in:
parent
dabf3d8057
commit
706350f023
@ -25,6 +25,11 @@
|
||||
gathered at runtime to be used in embedding layer partitioning decisions.
|
||||
* `tf.keras.metrics.AUC` now support logit predictions.
|
||||
* Creating `tf.random.Generator` under `tf.distribute.Strategy` scopes is now allowed (except for `tf.distribute.experimental.CentralStorageStrategy` and `tf.distribute.experimental.ParameterServerStrategy`). Different replicas will get different random-number streams.
|
||||
* `tf.data`:
|
||||
* tf.data service now supports strict round-robin reads, which is useful
|
||||
for synchronous training workloads where example sizes vary. With strict
|
||||
round robin reads, users can guarantee that consumers get similar-sized
|
||||
examples in the same step.
|
||||
|
||||
## Bug Fixes and Other Changes
|
||||
|
||||
|
@ -1,3 +1,5 @@
|
||||
op {
|
||||
graph_op_name: "DataServiceDataset"
|
||||
visibility: HIDDEN
|
||||
summary: "Creates a dataset that reads data from the tf.data service."
|
||||
}
|
||||
|
@ -1,5 +1,5 @@
|
||||
op {
|
||||
graph_op_name: "DataServiceDataset"
|
||||
graph_op_name: "DataServiceDatasetV2"
|
||||
visibility: HIDDEN
|
||||
summary: "Creates a dataset that reads data from the tf.data service."
|
||||
}
|
@ -87,14 +87,10 @@ cc_library(
|
||||
deps = [
|
||||
":credentials_factory",
|
||||
":dispatcher_cc_grpc_proto",
|
||||
":dispatcher_proto_cc",
|
||||
":grpc_util",
|
||||
":worker_cc_grpc_proto",
|
||||
":worker_proto_cc",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
tf_grpc_cc_dependency(),
|
||||
],
|
||||
)
|
||||
@ -379,8 +375,11 @@ cc_library(
|
||||
hdrs = ["task_runner.h"],
|
||||
deps = [
|
||||
":common_proto_cc",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/data:compression_utils",
|
||||
"//tensorflow/core/data:standalone",
|
||||
"//tensorflow/core/kernels/data:dataset_utils",
|
||||
],
|
||||
)
|
||||
|
||||
@ -389,11 +388,11 @@ tf_cc_test(
|
||||
srcs = ["task_runner_test.cc"],
|
||||
deps = [
|
||||
":task_runner",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
"//tensorflow/core/data:standalone",
|
||||
"@com_google_absl//absl/memory",
|
||||
],
|
||||
)
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
|
||||
#include "grpcpp/create_channel.h"
|
||||
#include "grpcpp/security/credentials.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "tensorflow/core/data/service/credentials_factory.h"
|
||||
#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
|
||||
#include "tensorflow/core/data/service/grpc_util.h"
|
||||
@ -150,7 +151,8 @@ Status DataServiceDispatcherClient::RegisterDataset(GraphDef dataset,
|
||||
|
||||
Status DataServiceDispatcherClient::GetOrCreateJob(
|
||||
int64 dataset_id, ProcessingMode processing_mode,
|
||||
const absl::optional<JobKey>& job_key, int64& job_client_id) {
|
||||
const absl::optional<JobKey>& job_key, absl::optional<int64> num_consumers,
|
||||
int64& job_client_id) {
|
||||
TF_RETURN_IF_ERROR(EnsureInitialized());
|
||||
GetOrCreateJobRequest req;
|
||||
req.set_dataset_id(dataset_id);
|
||||
@ -158,6 +160,9 @@ Status DataServiceDispatcherClient::GetOrCreateJob(
|
||||
if (job_key.has_value()) {
|
||||
*req.mutable_job_key() = job_key.value();
|
||||
}
|
||||
if (num_consumers.has_value()) {
|
||||
req.set_num_consumers(num_consumers.value());
|
||||
}
|
||||
GetOrCreateJobResponse resp;
|
||||
grpc::ClientContext client_ctx;
|
||||
grpc::Status status = stub_->GetOrCreateJob(&client_ctx, req, &resp);
|
||||
@ -239,11 +244,19 @@ Status DataServiceDispatcherClient::EnsureInitialized() {
|
||||
}
|
||||
|
||||
Status DataServiceWorkerClient::GetElement(int64 task_id,
|
||||
absl::optional<int64> consumer_index,
|
||||
absl::optional<int64> round_index,
|
||||
CompressedElement& element,
|
||||
bool& end_of_sequence) {
|
||||
TF_RETURN_IF_ERROR(EnsureInitialized());
|
||||
GetElementRequest req;
|
||||
req.set_task_id(task_id);
|
||||
if (consumer_index.has_value()) {
|
||||
req.set_consumer_index(consumer_index.value());
|
||||
}
|
||||
if (round_index.has_value()) {
|
||||
req.set_round_index(round_index.value());
|
||||
}
|
||||
GetElementResponse resp;
|
||||
grpc::ClientContext ctx;
|
||||
grpc::Status s = stub_->GetElement(&ctx, req, &resp);
|
||||
|
@ -100,11 +100,12 @@ class DataServiceDispatcherClient : public DataServiceClientBase {
|
||||
// dataset id in `dataset_id`.
|
||||
Status RegisterDataset(GraphDef dataset, int64& dataset_id);
|
||||
|
||||
// Gets the job id for the job represented by the tuple
|
||||
// (job_name, job_name_index), and stores the id in `job_client_id`. If the
|
||||
// job doesn't exist yet, it will be created.
|
||||
// If `job_key` is set, looks up a job matching `job_key`. If `job_key` is
|
||||
// absent or no matching job is found, creates a new job. The resulting job
|
||||
// id is stored in `job_client_id`.
|
||||
Status GetOrCreateJob(int64 dataset_id, ProcessingMode processing_mode,
|
||||
const absl::optional<JobKey>& job_key,
|
||||
absl::optional<int64> num_consumers,
|
||||
int64& job_client_id);
|
||||
|
||||
// Releases a job client id, indicating that the id will no longer be used to
|
||||
@ -138,11 +139,14 @@ class DataServiceWorkerClient : public DataServiceClientBase {
|
||||
const std::string& protocol)
|
||||
: DataServiceClientBase(address, protocol) {}
|
||||
|
||||
// Fetches the next element for the specified task_id. The element's
|
||||
// compressed tensors will be stored in `element`. If no element is available,
|
||||
// `end_of_sequence` will be `true`, and `element` will be left unchanged.
|
||||
Status GetElement(int64 task_id, CompressedElement& element,
|
||||
bool& end_of_sequence);
|
||||
// Fetches the next element for the specified task_id. The optional
|
||||
// `consumer_index` and `round_index` must be specified for tasks which use
|
||||
// round-robin ordering. The element's compressed tensors will be stored in
|
||||
// `element`. If no element is available, `end_of_sequence` will be `true`,
|
||||
// and `element` will be left unchanged.
|
||||
Status GetElement(int64 task_id, absl::optional<int64> consumer_index,
|
||||
absl::optional<int64> round_index,
|
||||
CompressedElement& element, bool& end_of_sequence);
|
||||
|
||||
protected:
|
||||
Status EnsureInitialized() override;
|
||||
|
@ -533,7 +533,11 @@ Status DataServiceDispatcherImpl::CreateTasksForWorker(
|
||||
const std::string& worker_address) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
std::vector<std::shared_ptr<const Job>> jobs = state_.ListJobs();
|
||||
for (const auto& job : jobs) {
|
||||
if (job->finished) {
|
||||
if (job->finished || job->num_consumers.has_value()) {
|
||||
// Don't add new tasks for late-joining workers when doing round-robin
|
||||
// reads. It would create synchronization issues where some clients might
|
||||
// learn about the new tasks earlier than others, potentially causing
|
||||
// deadlock.
|
||||
continue;
|
||||
}
|
||||
std::shared_ptr<const Task> task;
|
||||
@ -646,6 +650,9 @@ Status DataServiceDispatcherImpl::AssignTask(std::shared_ptr<const Task> task)
|
||||
}
|
||||
task_def->set_task_id(task->task_id);
|
||||
task_def->set_processing_mode(ProcessingModeDef(task->job->processing_mode));
|
||||
if (task->job->num_consumers.has_value()) {
|
||||
task_def->set_num_consumers(task->job->num_consumers.value());
|
||||
}
|
||||
ProcessTaskResponse resp;
|
||||
WorkerService::Stub* stub;
|
||||
TF_RETURN_IF_ERROR(GetOrCreateWorkerStub(task->worker_address, stub));
|
||||
|
@ -15,12 +15,21 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/data/service/task_runner.h"
|
||||
|
||||
#include "tensorflow/core/data/compression_utils.h"
|
||||
#include "tensorflow/core/data/standalone.h"
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
#include "tensorflow/core/framework/tensor_util.h"
|
||||
#include "tensorflow/core/lib/gtl/cleanup.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
namespace {
|
||||
// How long to wait for other round-robin consumers before returning with an
|
||||
// Unavailable error. The unavailable error gives the client an opportunity to
|
||||
// either give up or retry to continue waiting.
|
||||
const int64 kDefaultTimeoutUs = 2 * 1000 * 1000; // 2 seconds.
|
||||
} // namespace
|
||||
|
||||
StandaloneTaskIterator::StandaloneTaskIterator(
|
||||
std::unique_ptr<standalone::Dataset> dataset,
|
||||
@ -32,12 +41,25 @@ Status StandaloneTaskIterator::GetNext(std::vector<Tensor>& element,
|
||||
return iterator_->GetNext(&element, &end_of_sequence);
|
||||
}
|
||||
|
||||
int64 StandaloneTaskIterator::Cardinality() const {
|
||||
return dataset_->Get()->Cardinality();
|
||||
}
|
||||
|
||||
Status TaskRunner::Create(const TaskDef& task_def,
|
||||
std::unique_ptr<TaskIterator> iterator,
|
||||
std::unique_ptr<TaskRunner>& out) {
|
||||
if (task_def.optional_num_consumers_case() == TaskDef::kNumConsumers) {
|
||||
out = absl::make_unique<RoundRobinTaskRunner>(std::move(iterator),
|
||||
task_def.num_consumers());
|
||||
int64 cardinality = iterator->Cardinality();
|
||||
if (cardinality != kInfiniteCardinality &&
|
||||
cardinality != kUnknownCardinality) {
|
||||
return errors::FailedPrecondition(
|
||||
"Round robin reads require that the input dataset has infinite "
|
||||
"cardinality, but the dataset has cardinality ",
|
||||
cardinality,
|
||||
". Consider adding a `.repeat()` transformation to the dataset.");
|
||||
}
|
||||
out = absl::make_unique<RoundRobinTaskRunner>(
|
||||
std::move(iterator), task_def.num_consumers(), kDefaultTimeoutUs);
|
||||
} else {
|
||||
out =
|
||||
absl::make_unique<FirstComeFirstServedTaskRunner>(std::move(iterator));
|
||||
@ -56,10 +78,15 @@ Status FirstComeFirstServedTaskRunner::GetNext(const Request& request,
|
||||
}
|
||||
|
||||
RoundRobinTaskRunner::RoundRobinTaskRunner(
|
||||
std::unique_ptr<TaskIterator> iterator, int64 num_consumers)
|
||||
std::unique_ptr<TaskIterator> iterator, int64 num_consumers,
|
||||
int64 timeout_us)
|
||||
: num_consumers_(num_consumers),
|
||||
timeout_us_(timeout_us),
|
||||
iterator_(std::move(iterator)),
|
||||
buffer_(num_consumers_) {}
|
||||
buffer_(num_consumers_) {
|
||||
VLOG(1) << "Creating task runner for distributing data round-robin to "
|
||||
<< num_consumers << " consumers";
|
||||
}
|
||||
|
||||
Status RoundRobinTaskRunner::GetNext(const Request& request,
|
||||
std::vector<Tensor>& element,
|
||||
@ -74,12 +101,15 @@ Status RoundRobinTaskRunner::GetNext(const Request& request,
|
||||
"Requesting data for consumer index ", request.consumer_index,
|
||||
", but the task is configured for only ", num_consumers_, " consumers");
|
||||
}
|
||||
VLOG(2) << "Received request from consumer index " << request.consumer_index
|
||||
<< " for round " << request.round_index;
|
||||
|
||||
mutex_lock l(mu_);
|
||||
absl::flat_hash_set<int64>& round = requests_[request.round_index];
|
||||
first_round_ = std::min(first_round_, request.round_index);
|
||||
round.insert(request.consumer_index);
|
||||
if (current_round_ < request.round_index && round.size() == num_consumers_) {
|
||||
VLOG(1) << "Starting normal round with round index " << request.round_index;
|
||||
// This was the last request to arrive, time to start a new round.
|
||||
TF_RETURN_IF_ERROR(FillBuffer());
|
||||
current_round_ = request.round_index;
|
||||
@ -88,39 +118,44 @@ Status RoundRobinTaskRunner::GetNext(const Request& request,
|
||||
if (current_round_ < 0 &&
|
||||
requests_[first_round_].size() + requests_[first_round_ + 1].size() ==
|
||||
num_consumers_) {
|
||||
VLOG(1) << "Starting partial round for " << requests_[first_round_].size()
|
||||
<< " consumers";
|
||||
// Indicates that we need a partial round to get consumers back in sync.
|
||||
TF_RETURN_IF_ERROR(FillBuffer());
|
||||
current_round_ = first_round_;
|
||||
new_round_cv_.notify_all();
|
||||
}
|
||||
while (current_round_ < request.round_index) {
|
||||
new_round_cv_.wait(l);
|
||||
std::cv_status s =
|
||||
new_round_cv_.wait_for(l, std::chrono::microseconds(timeout_us_));
|
||||
if (s == std::cv_status::timeout) {
|
||||
// Clients will retry Unavailable.
|
||||
return errors::Unavailable(
|
||||
"Timeout waiting for other round-robin consumers to be ready.");
|
||||
}
|
||||
}
|
||||
Result& result = buffer_[request.consumer_index];
|
||||
end_of_task = result.end_of_task;
|
||||
end_of_task = end_of_task_;
|
||||
if (!end_of_task) {
|
||||
element = std::move(result.element);
|
||||
element.clear();
|
||||
for (auto& component : buffer_[request.consumer_index]) {
|
||||
element.push_back(tensor::DeepCopy(component));
|
||||
}
|
||||
}
|
||||
VLOG(2) << "Returning to consumer " << request.consumer_index << " for round "
|
||||
<< request.round_index;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RoundRobinTaskRunner::FillBuffer() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
for (int i = 0; i < num_consumers_; ++i) {
|
||||
Result& result = buffer_[i];
|
||||
result.element.clear();
|
||||
TF_RETURN_IF_ERROR(iterator_->GetNext(result.element, result.end_of_task));
|
||||
if (buffer_[i].end_of_task && !buffer_[0].end_of_task) {
|
||||
std::vector<Tensor>& first_element = buffer_[0].element;
|
||||
// Pad out the round with empty elements.
|
||||
buffer_[i].element.clear();
|
||||
for (int c = 0; c < first_element.size(); ++c) {
|
||||
TensorShape shape = first_element[c].shape();
|
||||
if (shape.dims() > 0) {
|
||||
shape.set_dim(0, 0);
|
||||
}
|
||||
buffer_[i].element.push_back(Tensor(first_element[c].dtype(), shape));
|
||||
}
|
||||
buffer_[i].end_of_task = false;
|
||||
buffer_[i].clear();
|
||||
bool end_of_sequence;
|
||||
TF_RETURN_IF_ERROR(iterator_->GetNext(buffer_[i], end_of_sequence));
|
||||
if (end_of_sequence) {
|
||||
return errors::FailedPrecondition(
|
||||
"Encountered end of sequence on a round-robin read iterator. Please "
|
||||
"ensure that the dataset used for round-robin reading has infinite "
|
||||
"cardinality, e.g. by adding a .repeat() transformation at the end.");
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
|
@ -31,6 +31,8 @@ class TaskIterator {
|
||||
// `end_of_sequence to `true`.
|
||||
virtual Status GetNext(std::vector<Tensor>& element,
|
||||
bool& end_of_sequence) = 0;
|
||||
// Reports the cardinality of the dataset that created this iterator.
|
||||
virtual int64 Cardinality() const = 0;
|
||||
};
|
||||
|
||||
// Implementation of TaskIterator wrapping a standalone iterator.
|
||||
@ -42,6 +44,7 @@ class StandaloneTaskIterator : public TaskIterator {
|
||||
StandaloneTaskIterator(std::unique_ptr<standalone::Dataset> dataset,
|
||||
std::unique_ptr<standalone::Iterator> iterator);
|
||||
Status GetNext(std::vector<Tensor>& element, bool& end_of_sequence) override;
|
||||
int64 Cardinality() const override;
|
||||
|
||||
private:
|
||||
std::unique_ptr<standalone::Dataset> dataset_;
|
||||
@ -109,19 +112,16 @@ class FirstComeFirstServedTaskRunner : public TaskRunner {
|
||||
class RoundRobinTaskRunner : public TaskRunner {
|
||||
public:
|
||||
RoundRobinTaskRunner(std::unique_ptr<TaskIterator> iterator,
|
||||
int64 num_consumers);
|
||||
int64 num_consumers, int64 timeout_us);
|
||||
Status GetNext(const Request& request, std::vector<Tensor>& element,
|
||||
bool& end_of_task) override;
|
||||
|
||||
private:
|
||||
struct Result {
|
||||
std::vector<Tensor> element;
|
||||
bool end_of_task = false;
|
||||
};
|
||||
// Fills `buffer_` with `num_consumers_` elements.
|
||||
Status FillBuffer();
|
||||
|
||||
const int64 num_consumers_;
|
||||
const int64 timeout_us_;
|
||||
std::unique_ptr<TaskIterator> iterator_;
|
||||
mutex mu_;
|
||||
// Condition variable notified whenever we start a new round of round-robin.
|
||||
@ -134,7 +134,8 @@ class RoundRobinTaskRunner : public TaskRunner {
|
||||
int64 first_round_ TF_GUARDED_BY(mu_) = kint64max;
|
||||
int64 current_round_ TF_GUARDED_BY(mu_) = -1;
|
||||
// Buffered results for the current round.
|
||||
std::vector<Result> buffer_ TF_GUARDED_BY(mu_);
|
||||
std::vector<std::vector<Tensor>> buffer_ TF_GUARDED_BY(mu_);
|
||||
bool end_of_task_ TF_GUARDED_BY(mu_) = false;
|
||||
};
|
||||
|
||||
} // namespace data
|
||||
|
@ -13,13 +13,17 @@ limitations under the License.
|
||||
#include "tensorflow/core/data/service/task_runner.h"
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
namespace {
|
||||
const int64 kNoTimeoutUs = 60ull * 60 * 1000 * 1000; // 60 minutes.
|
||||
|
||||
class TestTaskIterator : public TaskIterator {
|
||||
public:
|
||||
explicit TestTaskIterator(const std::vector<std::vector<Tensor>>& elements)
|
||||
@ -28,30 +32,31 @@ class TestTaskIterator : public TaskIterator {
|
||||
Status GetNext(std::vector<Tensor>& element, bool& end_of_sequence) override {
|
||||
end_of_sequence = index_ >= elements_.size();
|
||||
if (!end_of_sequence) {
|
||||
element = elements_[index_++];
|
||||
element = elements_[index_];
|
||||
index_ = (index_ + 1) % elements_.size();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
int64 Cardinality() const override { return kInfiniteCardinality; }
|
||||
|
||||
private:
|
||||
std::vector<std::vector<Tensor>> elements_;
|
||||
int64 index_;
|
||||
};
|
||||
|
||||
// Reads from the task runner, storing results in `*output`.
|
||||
Status RunConsumer(int64 consumer_index, int64 start_index,
|
||||
TaskRunner& task_runner,
|
||||
std::vector<std::vector<Tensor>>& output) {
|
||||
Status RunConsumer(int64 consumer_index, int64 start_index, int64 end_index,
|
||||
TaskRunner& task_runner, std::vector<int64>& output) {
|
||||
bool end_of_sequence = false;
|
||||
int64 next_index = start_index;
|
||||
while (!end_of_sequence) {
|
||||
for (int64 next_index = start_index; next_index < end_index; ++next_index) {
|
||||
TaskRunner::Request request;
|
||||
request.round_index = next_index++;
|
||||
request.round_index = next_index;
|
||||
request.consumer_index = consumer_index;
|
||||
std::vector<Tensor> element;
|
||||
TF_RETURN_IF_ERROR(task_runner.GetNext(request, element, end_of_sequence));
|
||||
if (!end_of_sequence) {
|
||||
output.push_back(element);
|
||||
output.push_back(element[0].flat<int64>()(0));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
@ -76,12 +81,6 @@ TEST(FirstComeFirstServedTaskRunner, GetNext) {
|
||||
ASSERT_EQ(element.size(), 1);
|
||||
test::ExpectEqual(element[0], expected_element[0]);
|
||||
}
|
||||
for (int i = 0; i < 2; ++i) {
|
||||
std::vector<Tensor> element;
|
||||
bool end_of_sequence;
|
||||
TF_ASSERT_OK(runner.GetNext(request, element, end_of_sequence));
|
||||
ASSERT_TRUE(end_of_sequence);
|
||||
}
|
||||
}
|
||||
|
||||
class ConsumeParallelTest
|
||||
@ -98,8 +97,8 @@ TEST_P(ConsumeParallelTest, ConsumeParallel) {
|
||||
elements.push_back(element);
|
||||
}
|
||||
RoundRobinTaskRunner runner(absl::make_unique<TestTaskIterator>(elements),
|
||||
num_consumers);
|
||||
std::vector<std::vector<std::vector<Tensor>>> per_consumer_results;
|
||||
num_consumers, kNoTimeoutUs);
|
||||
std::vector<std::vector<int64>> per_consumer_results;
|
||||
std::vector<std::unique_ptr<Thread>> consumers;
|
||||
mutex mu;
|
||||
Status error;
|
||||
@ -108,8 +107,9 @@ TEST_P(ConsumeParallelTest, ConsumeParallel) {
|
||||
per_consumer_results.emplace_back();
|
||||
consumers.push_back(absl::WrapUnique(Env::Default()->StartThread(
|
||||
{}, absl::StrCat("consumer_", consumer), [&, consumer] {
|
||||
std::vector<std::vector<Tensor>> results;
|
||||
Status s = RunConsumer(consumer, /*start_index=*/0, runner, results);
|
||||
std::vector<int64> results;
|
||||
Status s = RunConsumer(consumer, /*start_index=*/0,
|
||||
/*end_index=*/num_elements, runner, results);
|
||||
mutex_lock l(mu);
|
||||
if (!s.ok()) {
|
||||
error = s;
|
||||
@ -125,8 +125,7 @@ TEST_P(ConsumeParallelTest, ConsumeParallel) {
|
||||
for (int i = 0; i < num_elements; ++i) {
|
||||
int consumer = i % num_consumers;
|
||||
int round = i / num_consumers;
|
||||
Tensor expected = elements[i][0];
|
||||
test::ExpectEqual(per_consumer_results[consumer][round][0], expected);
|
||||
EXPECT_EQ(per_consumer_results[consumer][round], i);
|
||||
}
|
||||
}
|
||||
|
||||
@ -139,19 +138,20 @@ INSTANTIATE_TEST_SUITE_P(ConsumeParallelTests, ConsumeParallelTest,
|
||||
std::make_tuple(0, 20)));
|
||||
|
||||
TEST(RoundRobinTaskRunner, ConsumeParallelPartialRound) {
|
||||
int64 num_elements = 20;
|
||||
int64 num_consumers = 5;
|
||||
std::vector<int64> starting_rounds = {12, 11, 11, 12, 12};
|
||||
int64 min_starting_round = 11;
|
||||
int64 end_index = 15;
|
||||
std::vector<std::vector<int64>> expected_consumer_results = {
|
||||
{5, 10, 15}, {1, 6, 11, 16}, {2, 7, 12, 17}, {8, 13, 18}, {9, 14, 19}};
|
||||
std::vector<std::vector<Tensor>> elements;
|
||||
for (int64 i = 0; i < num_elements; ++i) {
|
||||
for (int64 i = 0; i < 30; ++i) {
|
||||
std::vector<Tensor> element;
|
||||
element.push_back(Tensor(i));
|
||||
elements.push_back(element);
|
||||
}
|
||||
RoundRobinTaskRunner runner(absl::make_unique<TestTaskIterator>(elements),
|
||||
num_consumers);
|
||||
std::vector<std::vector<std::vector<Tensor>>> per_consumer_results;
|
||||
num_consumers, kNoTimeoutUs);
|
||||
std::vector<std::vector<int64>> per_consumer_results;
|
||||
std::vector<std::unique_ptr<Thread>> consumers;
|
||||
mutex mu;
|
||||
Status error;
|
||||
@ -160,9 +160,9 @@ TEST(RoundRobinTaskRunner, ConsumeParallelPartialRound) {
|
||||
per_consumer_results.emplace_back();
|
||||
consumers.push_back(absl::WrapUnique(Env::Default()->StartThread(
|
||||
{}, absl::StrCat("consumer_", consumer), [&, consumer] {
|
||||
std::vector<std::vector<Tensor>> results;
|
||||
Status s =
|
||||
RunConsumer(consumer, starting_rounds[consumer], runner, results);
|
||||
std::vector<int64> results;
|
||||
Status s = RunConsumer(consumer, starting_rounds[consumer], end_index,
|
||||
runner, results);
|
||||
mutex_lock l(mu);
|
||||
if (!s.ok()) {
|
||||
error = s;
|
||||
@ -176,19 +176,8 @@ TEST(RoundRobinTaskRunner, ConsumeParallelPartialRound) {
|
||||
mutex_lock l(mu);
|
||||
TF_ASSERT_OK(error);
|
||||
for (int consumer = 0; consumer < num_consumers; ++consumer) {
|
||||
auto& results = per_consumer_results[consumer];
|
||||
int start = consumer;
|
||||
int expected_elements = num_elements / num_consumers;
|
||||
if (starting_rounds[consumer] != min_starting_round) {
|
||||
start += num_consumers;
|
||||
expected_elements--;
|
||||
}
|
||||
ASSERT_EQ(results.size(), expected_elements);
|
||||
int index = 0;
|
||||
for (int i = start; i < num_elements; i += num_consumers) {
|
||||
Tensor expected = elements[i][0];
|
||||
test::ExpectEqual(results[index++][0], expected);
|
||||
}
|
||||
EXPECT_EQ(per_consumer_results[consumer],
|
||||
expected_consumer_results[consumer]);
|
||||
}
|
||||
}
|
||||
} // namespace data
|
||||
|
@ -180,6 +180,7 @@ Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request,
|
||||
VLOG(3) << "Received GetElement request for task " << request->task_id();
|
||||
bool end_of_sequence = false;
|
||||
std::vector<tensorflow::Tensor> outputs;
|
||||
Task* task;
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
if (!registered_) {
|
||||
@ -201,24 +202,24 @@ Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request,
|
||||
return errors::Unavailable("Task ", request->task_id(), " not found");
|
||||
}
|
||||
}
|
||||
auto& task = it->second;
|
||||
task = it->second.get();
|
||||
TF_RETURN_IF_ERROR(EnsureTaskInitialized(*task));
|
||||
TaskRunner::Request get_next_request;
|
||||
if (request->optional_consumer_index_case() ==
|
||||
GetElementRequest::kConsumerIndex) {
|
||||
get_next_request.consumer_index = request->consumer_index();
|
||||
}
|
||||
if (request->optional_round_index_case() ==
|
||||
GetElementRequest::kRoundIndex) {
|
||||
get_next_request.round_index = request->round_index();
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
task->task_runner->GetNext(get_next_request, outputs, end_of_sequence));
|
||||
if (end_of_sequence) {
|
||||
VLOG(3) << "Reached end_of_sequence for task " << request->task_id();
|
||||
pending_completed_tasks_.insert(request->task_id());
|
||||
task_completion_cv_.notify_one();
|
||||
}
|
||||
}
|
||||
TaskRunner::Request get_next_request;
|
||||
if (request->optional_consumer_index_case() ==
|
||||
GetElementRequest::kConsumerIndex) {
|
||||
get_next_request.consumer_index = request->consumer_index();
|
||||
}
|
||||
if (request->optional_round_index_case() == GetElementRequest::kRoundIndex) {
|
||||
get_next_request.round_index = request->round_index();
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
task->task_runner->GetNext(get_next_request, outputs, end_of_sequence));
|
||||
if (end_of_sequence) {
|
||||
mutex_lock l(mu_);
|
||||
VLOG(3) << "Reached end_of_sequence for task " << request->task_id();
|
||||
pending_completed_tasks_.insert(request->task_id());
|
||||
task_completion_cv_.notify_one();
|
||||
}
|
||||
|
||||
if (!end_of_sequence) {
|
||||
@ -249,7 +250,7 @@ Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request,
|
||||
"it produced ",
|
||||
variant.TypeName());
|
||||
}
|
||||
compressed->Swap(response->mutable_compressed_element());
|
||||
*response->mutable_compressed_element() = *compressed;
|
||||
}
|
||||
response->set_end_of_sequence(end_of_sequence);
|
||||
|
||||
|
@ -139,6 +139,8 @@ Status Dataset::MakeSplitProvider(std::unique_ptr<SplitProvider>* result) {
|
||||
return dataset_->MakeSplitProvider(result);
|
||||
}
|
||||
|
||||
const DatasetBase* Dataset::Get() const { return dataset_; }
|
||||
|
||||
Dataset::Dataset(DatasetBase* dataset, DeviceMgr* device_mgr,
|
||||
ProcessFunctionLibraryRuntime* pflr,
|
||||
FunctionLibraryDefinition* flib_def, thread::ThreadPool* pool)
|
||||
|
@ -104,6 +104,8 @@ class Dataset {
|
||||
|
||||
// Creates a split provider for this dataset.
|
||||
Status MakeSplitProvider(std::unique_ptr<SplitProvider>* result);
|
||||
// Returns a pointer to the underlying dataset.
|
||||
const DatasetBase* Get() const;
|
||||
|
||||
private:
|
||||
Dataset(DatasetBase* dataset, DeviceMgr* device_mgr,
|
||||
|
@ -51,6 +51,8 @@ namespace data {
|
||||
/* static */ constexpr const char* const DataServiceDatasetOp::kAddress;
|
||||
/* static */ constexpr const char* const DataServiceDatasetOp::kProtocol;
|
||||
/* static */ constexpr const char* const DataServiceDatasetOp::kJobName;
|
||||
/* static */ constexpr const char* const DataServiceDatasetOp::kConsumerIndex;
|
||||
/* static */ constexpr const char* const DataServiceDatasetOp::kNumConsumers;
|
||||
/* static */ constexpr const char* const
|
||||
DataServiceDatasetOp::kMaxOutstandingRequests;
|
||||
/* static */ constexpr const char* const
|
||||
@ -63,6 +65,9 @@ namespace data {
|
||||
namespace {
|
||||
// Default interval between task list refreshes.
|
||||
const int64 kDefaultTaskRefreshIntervalMs = 1000; // 1 second.
|
||||
|
||||
constexpr char kDataServiceDatasetV1[] = "DataServiceDataset";
|
||||
constexpr char kDataServiceDatasetV2[] = "DataServiceDatasetV2";
|
||||
} // namespace
|
||||
|
||||
// Dataset for reading data from the tf.data service non-deterministically.
|
||||
@ -72,20 +77,24 @@ const int64 kDefaultTaskRefreshIntervalMs = 1000; // 1 second.
|
||||
// to read from (in case workers are added or removed).
|
||||
class DataServiceDatasetOp::Dataset : public DatasetBase {
|
||||
public:
|
||||
Dataset(OpKernelContext* ctx, int64 dataset_id,
|
||||
Dataset(OpKernelContext* ctx, int op_version, int64 dataset_id,
|
||||
ProcessingMode processing_mode, const std::string& address,
|
||||
const std::string& protocol, const std::string& job_name,
|
||||
int64 max_outstanding_requests, int64 task_refresh_interval_ms,
|
||||
IterationCounter* iteration_counter, bool owns_resource,
|
||||
ResourceHandle iteration_counter_handle,
|
||||
absl::optional<int64> consumer_index,
|
||||
absl::optional<int64> num_consumers, int64 max_outstanding_requests,
|
||||
int64 task_refresh_interval_ms, IterationCounter* iteration_counter,
|
||||
bool owns_resource, ResourceHandle iteration_counter_handle,
|
||||
const DataTypeVector& output_types,
|
||||
const std::vector<PartialTensorShape>& output_shapes)
|
||||
: DatasetBase(DatasetContext(ctx)),
|
||||
op_version_(op_version),
|
||||
dataset_id_(dataset_id),
|
||||
processing_mode_(processing_mode),
|
||||
address_(address),
|
||||
protocol_(protocol),
|
||||
job_name_(job_name),
|
||||
consumer_index_(consumer_index),
|
||||
num_consumers_(num_consumers),
|
||||
max_outstanding_requests_(max_outstanding_requests),
|
||||
task_refresh_interval_ms_(task_refresh_interval_ms),
|
||||
iteration_counter_(iteration_counter),
|
||||
@ -135,39 +144,58 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
|
||||
Status AsGraphDefInternal(SerializationContext* ctx,
|
||||
DatasetGraphDefBuilder* b,
|
||||
Node** output) const override {
|
||||
std::vector<Node*> inputs;
|
||||
|
||||
Node* dataset_id;
|
||||
TF_RETURN_IF_ERROR(b->AddScalar(dataset_id_, &dataset_id));
|
||||
inputs.push_back(dataset_id);
|
||||
|
||||
Node* processing_mode;
|
||||
tstring processing_mode_str = ProcessingModeToString(processing_mode_);
|
||||
TF_RETURN_IF_ERROR(b->AddScalar(processing_mode_str, &processing_mode));
|
||||
inputs.push_back(processing_mode);
|
||||
|
||||
Node* address;
|
||||
TF_RETURN_IF_ERROR(b->AddScalar(address_, &address));
|
||||
inputs.push_back(address);
|
||||
|
||||
Node* protocol;
|
||||
TF_RETURN_IF_ERROR(b->AddScalar(protocol_, &protocol));
|
||||
inputs.push_back(protocol);
|
||||
|
||||
Node* job_name;
|
||||
TF_RETURN_IF_ERROR(b->AddScalar(job_name_, &job_name));
|
||||
inputs.push_back(job_name);
|
||||
|
||||
if (op_version_ == 2) {
|
||||
Node* consumer_index;
|
||||
TF_RETURN_IF_ERROR(
|
||||
b->AddScalar(consumer_index_.value_or(-1), &consumer_index));
|
||||
inputs.push_back(consumer_index);
|
||||
|
||||
Node* num_consumers;
|
||||
TF_RETURN_IF_ERROR(
|
||||
b->AddScalar(num_consumers_.value_or(-1), &num_consumers));
|
||||
inputs.push_back(num_consumers);
|
||||
}
|
||||
|
||||
Node* max_outstanding_requests;
|
||||
TF_RETURN_IF_ERROR(
|
||||
b->AddScalar(max_outstanding_requests_, &max_outstanding_requests));
|
||||
inputs.push_back(max_outstanding_requests);
|
||||
|
||||
Node* iteration_counter_handle = nullptr;
|
||||
Tensor handle(DT_RESOURCE, TensorShape({}));
|
||||
handle.scalar<ResourceHandle>()() = iteration_counter_handle_;
|
||||
TF_RETURN_IF_ERROR(b->AddTensor(handle, &iteration_counter_handle));
|
||||
inputs.push_back(iteration_counter_handle);
|
||||
|
||||
AttrValue task_refresh_interval_hint_ms;
|
||||
b->BuildAttrValue(task_refresh_interval_ms_,
|
||||
&task_refresh_interval_hint_ms);
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
b->AddDataset(this,
|
||||
{dataset_id, processing_mode, address, protocol, job_name,
|
||||
max_outstanding_requests, iteration_counter_handle},
|
||||
b->AddDataset(this, inputs,
|
||||
{std::make_pair(kTaskRefreshIntervalHintMs,
|
||||
task_refresh_interval_hint_ms)},
|
||||
output));
|
||||
@ -205,7 +233,6 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
|
||||
|
||||
void CancelThreads() TF_LOCKS_EXCLUDED(mu_) {
|
||||
mutex_lock l(mu_);
|
||||
VLOG(1) << "Cancelling threads in DataServiceDataset::Iterator";
|
||||
cancelled_ = true;
|
||||
worker_thread_cv_.notify_all();
|
||||
manager_thread_cv_.notify_all();
|
||||
@ -229,9 +256,9 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
|
||||
}
|
||||
TF_RETURN_IF_ERROR(grpc_util::Retry(
|
||||
[&]() {
|
||||
return dispatcher_->GetOrCreateJob(dataset()->dataset_id_,
|
||||
dataset()->processing_mode_, key,
|
||||
job_client_id_);
|
||||
return dispatcher_->GetOrCreateJob(
|
||||
dataset()->dataset_id_, dataset()->processing_mode_, key,
|
||||
dataset()->num_consumers_, job_client_id_);
|
||||
},
|
||||
/*description=*/
|
||||
strings::StrCat("get or create job with dispatcher at ",
|
||||
@ -254,27 +281,38 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
|
||||
});
|
||||
}
|
||||
|
||||
while (results_.empty() &&
|
||||
while ((results_.empty() || !results_.front().ready) &&
|
||||
!(job_finished_ && num_running_worker_threads_ == 0) &&
|
||||
!cancelled_ && status_.ok()) {
|
||||
VLOG(3) << "Blocking in GetNext. results_.size():" << results_.size()
|
||||
<< " results_.front().ready:"
|
||||
<< (!results_.empty() && results_.front().ready)
|
||||
<< " job_finished_:" << job_finished_
|
||||
<< " num_running_worker_threads_:"
|
||||
<< num_running_worker_threads_;
|
||||
get_next_cv_.wait(l);
|
||||
}
|
||||
if (cancelled_) {
|
||||
VLOG(3) << "Returning from GetNext due to cancellation";
|
||||
return errors::Cancelled("Data service iterator was cancelled");
|
||||
}
|
||||
if (!status_.ok()) {
|
||||
VLOG(3) << "Returning from GetNext with error " << status_;
|
||||
return status_;
|
||||
}
|
||||
if (results_.empty()) {
|
||||
*end_of_sequence = true;
|
||||
VLOG(3) << "Returning from GetNext with end_of_sequence";
|
||||
return Status::OK();
|
||||
}
|
||||
DCHECK(!results_.empty());
|
||||
*end_of_sequence = false;
|
||||
out_tensors->swap(results_.front());
|
||||
*end_of_sequence = results_.front().end_of_sequence;
|
||||
if (!*end_of_sequence) {
|
||||
out_tensors->swap(results_.front().element);
|
||||
}
|
||||
results_.pop();
|
||||
worker_thread_cv_.notify_one();
|
||||
|
||||
VLOG(3) << "Returning from GetNext with an element";
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -326,12 +364,22 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
|
||||
const std::string address;
|
||||
// Client for fetching task elements from the tf.data service worker.
|
||||
const std::unique_ptr<DataServiceWorkerClient> worker;
|
||||
// Number of elements read by the task.
|
||||
int64 elements_read = 0;
|
||||
// Indicates whether a worker thread is currently processing the task.
|
||||
bool in_use TF_GUARDED_BY(&Iterator::mu_) = false;
|
||||
// Indicates whether the worker has returned end_of_sequence for the task.
|
||||
bool end_of_sequence TF_GUARDED_BY(&Iterator::mu_) = false;
|
||||
};
|
||||
|
||||
struct Result {
|
||||
// Whether the result has been computed yet. GetNext needs to block
|
||||
// until the next result is ready.
|
||||
bool ready TF_GUARDED_BY(&Iterator::mu_) = false;
|
||||
std::vector<Tensor> element TF_GUARDED_BY(&Iterator::mu_);
|
||||
bool end_of_sequence TF_GUARDED_BY(&Iterator::mu_) = false;
|
||||
};
|
||||
|
||||
// Periodically refresh the task list.
|
||||
// Maintain one thread fetching elements for each task.
|
||||
// TODO(aaudibert): Instead of polling, have dispatcher send updates when
|
||||
@ -347,7 +395,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
|
||||
// All units are microseconds.
|
||||
while (!cancelled_ && Env::Default()->NowMicros() < next_check) {
|
||||
int64 remaining_time = next_check - Env::Default()->NowMicros();
|
||||
VLOG(3) << "Task thread manager waiting for " << remaining_time
|
||||
VLOG(4) << "Task thread manager waiting for " << remaining_time
|
||||
<< "us";
|
||||
manager_thread_cv_.wait_for(
|
||||
l, std::chrono::microseconds(remaining_time));
|
||||
@ -365,7 +413,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
|
||||
}
|
||||
|
||||
void UpdateTasks() TF_LOCKS_EXCLUDED(mu_) {
|
||||
VLOG(3) << "Updating tasks";
|
||||
VLOG(4) << "Updating tasks";
|
||||
std::vector<TaskInfo> tasks;
|
||||
bool job_finished;
|
||||
Status s = dispatcher_->GetTasks(job_client_id_, tasks, job_finished);
|
||||
@ -400,8 +448,12 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
|
||||
tasks_.pop_back();
|
||||
}
|
||||
}
|
||||
for (auto& new_task_entry : task_id_to_task) {
|
||||
TaskInfo& task_info = new_task_entry.second;
|
||||
for (auto& task : tasks) {
|
||||
auto it = task_id_to_task.find(task.task_id());
|
||||
if (it == task_id_to_task.end()) {
|
||||
continue;
|
||||
}
|
||||
TaskInfo& task_info = it->second;
|
||||
std::unique_ptr<DataServiceWorkerClient> worker;
|
||||
Status s = CreateDataServiceWorkerClient(task_info.worker_address(),
|
||||
dataset()->protocol_, worker);
|
||||
@ -446,6 +498,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
|
||||
VLOG(1) << "Starting worker thread";
|
||||
std::shared_ptr<Task> task_to_process;
|
||||
while (true) {
|
||||
Result* result;
|
||||
{
|
||||
mutex_lock l(mu_);
|
||||
if (task_to_process) {
|
||||
@ -454,7 +507,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
|
||||
worker_thread_cv_.notify_one();
|
||||
}
|
||||
outstanding_requests_--;
|
||||
while (!cancelled_ && !(SpaceInBuffer() && TaskAvailable()) &&
|
||||
while (!cancelled_ && !(ElementSpaceAvailable() && TaskAvailable()) &&
|
||||
!job_finished_) {
|
||||
if (VLOG_IS_ON(3)) {
|
||||
VLOG(3) << "Sleeping with results_.size=" << results_.size()
|
||||
@ -470,23 +523,40 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
|
||||
if (cancelled_ || job_finished_) {
|
||||
return;
|
||||
}
|
||||
// Search for a task to update.
|
||||
int num_tasks = tasks_.size();
|
||||
for (int i = 0; i < num_tasks; ++i) {
|
||||
int index = (next_task_index_ + i) % num_tasks;
|
||||
std::shared_ptr<Task>& task = tasks_[index];
|
||||
if (!task->in_use && !task->end_of_sequence) {
|
||||
task->in_use = true;
|
||||
task_to_process = task;
|
||||
next_task_index_ = (index + 1) % num_tasks;
|
||||
break;
|
||||
if (StrictRoundRobin()) {
|
||||
task_to_process = tasks_[next_task_index_];
|
||||
// Reserve a spot in the results_ queue.
|
||||
results_.emplace();
|
||||
result = &results_.back();
|
||||
next_task_index_ = (next_task_index_ + 1) % tasks_.size();
|
||||
DCHECK(!task_to_process->in_use);
|
||||
} else {
|
||||
// Search for a task to update.
|
||||
int num_tasks = tasks_.size();
|
||||
for (int i = 0; i < num_tasks; ++i) {
|
||||
int index = (next_task_index_ + i) % num_tasks;
|
||||
std::shared_ptr<Task>& task = tasks_[index];
|
||||
if (!task->in_use && !task->end_of_sequence) {
|
||||
task_to_process = task;
|
||||
next_task_index_ = (index + 1) % num_tasks;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
DCHECK(task_to_process != nullptr);
|
||||
task_to_process->in_use = true;
|
||||
VLOG(3) << "Processing task " << task_to_process->task_id;
|
||||
}
|
||||
int64 deadline_micros = kint64max;
|
||||
Status s = GetElement(task_to_process.get(), deadline_micros);
|
||||
Status s;
|
||||
if (StrictRoundRobin()) {
|
||||
s = GetElement(task_to_process.get(), deadline_micros,
|
||||
/*enqueue_result=*/false, *result);
|
||||
} else {
|
||||
Result r;
|
||||
s = GetElement(task_to_process.get(), deadline_micros,
|
||||
/*enqueue_result=*/true, r);
|
||||
}
|
||||
if (!s.ok()) {
|
||||
mutex_lock l(mu_);
|
||||
VLOG(1) << "Failed to get element from worker "
|
||||
@ -502,21 +572,28 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
|
||||
}
|
||||
}
|
||||
|
||||
// Gets an element from a task and adds the element to `results_`.
|
||||
//
|
||||
// If the task reaches end_of_sequence or is cancelled (e.g. due to a
|
||||
// worker dying), GetElement returns Status::OK() without adding to
|
||||
// `results_`.
|
||||
Status GetElement(Task* task, int64 deadline_micros)
|
||||
TF_LOCKS_EXCLUDED(mu_) {
|
||||
// Gets an element from a task and stores the element in `result`. If
|
||||
// `enqueue_result` is true, `GetElement` also enqueues (via std::move) any
|
||||
// element-producing result in the `results_` queue.
|
||||
Status GetElement(Task* task, int64 deadline_micros, bool enqueue_result,
|
||||
Result& result) TF_LOCKS_EXCLUDED(mu_) {
|
||||
VLOG(3) << "Getting an element for task id " << task->task_id;
|
||||
tensorflow::profiler::TraceMe activity(
|
||||
"GetDataServiceElement", tensorflow::profiler::TraceMeLevel::kInfo);
|
||||
CompressedElement compressed;
|
||||
bool end_of_sequence;
|
||||
for (int num_retries = 0;; ++num_retries) {
|
||||
Status s = task->worker->GetElement(task->task_id, compressed,
|
||||
end_of_sequence);
|
||||
absl::optional<int64> consumer_index = dataset()->consumer_index_;
|
||||
absl::optional<int64> round_index;
|
||||
if (StrictRoundRobin()) {
|
||||
round_index = task->elements_read;
|
||||
VLOG(3) << "Requesting element from consumer index "
|
||||
<< consumer_index.value() << ", round "
|
||||
<< round_index.value();
|
||||
}
|
||||
Status s =
|
||||
task->worker->GetElement(task->task_id, consumer_index, round_index,
|
||||
compressed, end_of_sequence);
|
||||
if (s.ok()) {
|
||||
break;
|
||||
}
|
||||
@ -530,7 +607,8 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
|
||||
// If `UpdateTaskThreads` finds that the task has been cancelled, it
|
||||
// will set end_of_sequence to `true`.
|
||||
if (task->end_of_sequence || cancelled_) {
|
||||
return Status::OK();
|
||||
end_of_sequence = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
const int64 now_micros = EnvTime::NowMicros();
|
||||
@ -559,26 +637,47 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
|
||||
element.push_back(tensor);
|
||||
}
|
||||
mutex_lock l(mu_);
|
||||
result.ready = true;
|
||||
result.end_of_sequence = end_of_sequence;
|
||||
if (end_of_sequence) {
|
||||
task->end_of_sequence = true;
|
||||
finished_tasks_++;
|
||||
return Status::OK();
|
||||
}
|
||||
results_.push(std::move(element));
|
||||
task->elements_read++;
|
||||
result.element = std::move(element);
|
||||
if (enqueue_result) {
|
||||
results_.push(std::move(result));
|
||||
}
|
||||
get_next_cv_.notify_all();
|
||||
VLOG(3) << "Got an element for task id " << task->task_id;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
bool SpaceInBuffer() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
// Reports whether we can request another element without violating
|
||||
// max_outstanding_requests.
|
||||
bool ElementSpaceAvailable() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
// When doing round-robin reads, outstanding requests pre-allocate a
|
||||
// result in `results_`, so we only need to check the size of `results_`.
|
||||
if (StrictRoundRobin()) {
|
||||
return results_.size() < max_outstanding_requests_;
|
||||
}
|
||||
// Otherwise, results aren't added to `results_` until the data has been
|
||||
// successfully retrieved. We need to count requests already added to
|
||||
// `results_` as well as in-progress requests.
|
||||
return results_.size() + outstanding_requests_ <
|
||||
max_outstanding_requests_;
|
||||
}
|
||||
|
||||
bool TaskAvailable() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
if (StrictRoundRobin()) {
|
||||
return !tasks_[next_task_index_]->in_use;
|
||||
}
|
||||
return finished_tasks_ + outstanding_requests_ < tasks_.size();
|
||||
}
|
||||
|
||||
bool StrictRoundRobin() { return dataset()->num_consumers_.has_value(); }
|
||||
|
||||
const int64 iterator_index_;
|
||||
|
||||
mutable mutex mu_;
|
||||
@ -610,7 +709,12 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
|
||||
// A status to be returned from the next call to `GetNext`. This is set by
|
||||
// asynchronous threads when they encounter errors.
|
||||
Status status_ TF_GUARDED_BY(mu_) = Status::OK();
|
||||
std::queue<std::vector<Tensor>> results_ TF_GUARDED_BY(mu_);
|
||||
// A queue of results for `GetElement` requests to read from. When doing
|
||||
// strict round robin reads, the queue will contain placeholder results with
|
||||
// their `Result::ready` field false until their data has been retrieved
|
||||
// from a worker. When not doing round-robin reads, results are only added
|
||||
// to the queue after they are ready, to avoid head-of-line blocking.
|
||||
std::queue<Result> results_ TF_GUARDED_BY(mu_);
|
||||
|
||||
bool initialized_ = false;
|
||||
// Set once in Initialize().
|
||||
@ -622,11 +726,14 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
|
||||
std::unique_ptr<Thread> task_thread_manager_ TF_GUARDED_BY(mu_);
|
||||
};
|
||||
|
||||
const int op_version_;
|
||||
const int64 dataset_id_;
|
||||
const ProcessingMode processing_mode_;
|
||||
const tstring address_;
|
||||
const tstring protocol_;
|
||||
const tstring job_name_;
|
||||
const absl::optional<int64> consumer_index_;
|
||||
const absl::optional<int64> num_consumers_;
|
||||
const int64 max_outstanding_requests_;
|
||||
const int64 task_refresh_interval_ms_;
|
||||
IterationCounter* const iteration_counter_; // Owned
|
||||
@ -646,6 +753,16 @@ DataServiceDatasetOp::DataServiceDatasetOp(OpKernelConstruction* ctx)
|
||||
}
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
|
||||
OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputShapes, &output_shapes_));
|
||||
auto& op_name = ctx->def().op();
|
||||
if (op_name == kDataServiceDatasetV1) {
|
||||
op_version_ = 1;
|
||||
} else if (op_name == kDataServiceDatasetV2) {
|
||||
op_version_ = 2;
|
||||
} else {
|
||||
ctx->CtxFailure(errors::FailedPrecondition(
|
||||
"Unrecognized data service dataset op name: ", op_name));
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
void DataServiceDatasetOp::MakeDataset(OpKernelContext* ctx,
|
||||
@ -673,6 +790,24 @@ void DataServiceDatasetOp::MakeDataset(OpKernelContext* ctx,
|
||||
tstring job_name;
|
||||
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kJobName, &job_name));
|
||||
|
||||
absl::optional<int64> consumer_index;
|
||||
absl::optional<int64> num_consumers;
|
||||
if (op_version_ >= 2) {
|
||||
int64 consumer_index_int;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, ParseScalarArgument(ctx, kConsumerIndex, &consumer_index_int));
|
||||
if (consumer_index_int >= 0) {
|
||||
consumer_index = consumer_index_int;
|
||||
}
|
||||
|
||||
int64 num_consumers_int;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
ParseScalarArgument(ctx, kNumConsumers, &num_consumers_int));
|
||||
if (num_consumers_int >= 0) {
|
||||
num_consumers = num_consumers_int;
|
||||
}
|
||||
}
|
||||
|
||||
int64 max_outstanding_requests;
|
||||
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kMaxOutstandingRequests,
|
||||
&max_outstanding_requests));
|
||||
@ -712,15 +847,17 @@ void DataServiceDatasetOp::MakeDataset(OpKernelContext* ctx,
|
||||
errors::InvalidArgument(kMaxOutstandingRequests, " must be positive or ",
|
||||
model::kAutotune));
|
||||
|
||||
*output =
|
||||
new Dataset(ctx, dataset_id, processing_mode, address, protocol, job_name,
|
||||
max_outstanding_requests, task_refresh_interval_hint_ms_,
|
||||
iteration_counter, owns_resource, iteration_counter_handle,
|
||||
output_types_, output_shapes_);
|
||||
*output = new Dataset(
|
||||
ctx, op_version_, dataset_id, processing_mode, address, protocol,
|
||||
job_name, consumer_index, num_consumers, max_outstanding_requests,
|
||||
task_refresh_interval_hint_ms_, iteration_counter, owns_resource,
|
||||
iteration_counter_handle, output_types_, output_shapes_);
|
||||
}
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("DataServiceDataset").Device(DEVICE_CPU),
|
||||
DataServiceDatasetOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("DataServiceDatasetV2").Device(DEVICE_CPU),
|
||||
DataServiceDatasetOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("DummyIterationCounter").Device(DEVICE_CPU),
|
||||
DummyResourceOp<IterationCounter>);
|
||||
|
||||
|
@ -52,6 +52,8 @@ class DataServiceDatasetOp : public DatasetOpKernel {
|
||||
static constexpr const char* const kAddress = "address";
|
||||
static constexpr const char* const kProtocol = "protocol";
|
||||
static constexpr const char* const kJobName = "job_name";
|
||||
static constexpr const char* const kConsumerIndex = "consumer_index";
|
||||
static constexpr const char* const kNumConsumers = "num_consumers";
|
||||
static constexpr const char* const kMaxOutstandingRequests =
|
||||
"max_outstanding_requests";
|
||||
static constexpr const char* const kTaskRefreshIntervalHintMs =
|
||||
@ -70,6 +72,7 @@ class DataServiceDatasetOp : public DatasetOpKernel {
|
||||
private:
|
||||
class Dataset;
|
||||
|
||||
int op_version_;
|
||||
int64 task_refresh_interval_hint_ms_;
|
||||
DataTypeVector output_types_;
|
||||
std::vector<PartialTensorShape> output_shapes_;
|
||||
|
@ -212,7 +212,9 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
|
||||
autotune_(params.dataset->num_parallel_calls_ == model::kAutotune) {}
|
||||
|
||||
~Iterator() override {
|
||||
cancellation_manager_->StartCancel();
|
||||
CancelThreads(/*wait=*/true);
|
||||
input_impl_.reset();
|
||||
if (deregister_fn_) deregister_fn_();
|
||||
}
|
||||
|
||||
@ -221,11 +223,15 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
|
||||
if (num_parallel_calls_->value == model::kAutotune) {
|
||||
num_parallel_calls_->value = ctx->runner_threadpool_size();
|
||||
}
|
||||
cancellation_manager_ =
|
||||
absl::make_unique<CancellationManager>(ctx->cancellation_manager());
|
||||
IteratorContext::Params params(ctx);
|
||||
params.cancellation_manager = cancellation_manager_.get();
|
||||
TF_RETURN_IF_ERROR(RegisterCancellationCallback(
|
||||
ctx->cancellation_manager(),
|
||||
[this]() { CancelThreads(/*wait=*/false); }, &deregister_fn_));
|
||||
TF_RETURN_IF_ERROR(
|
||||
dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_));
|
||||
TF_RETURN_IF_ERROR(dataset()->input_->MakeIterator(
|
||||
IteratorContext(params), this, prefix(), &input_impl_));
|
||||
return dataset()->captured_func_->Instantiate(
|
||||
ctx, &instantiated_captured_func_);
|
||||
}
|
||||
@ -640,12 +646,17 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
|
||||
const bool autotune_;
|
||||
// Counts the number of outstanding calls.
|
||||
int64 num_calls_ TF_GUARDED_BY(*mu_) = 0;
|
||||
// Controls cancellation of `input_impl_`.
|
||||
// Must be ordered before `input_impl_` so that `input_impl_` is destroyed
|
||||
// first.
|
||||
std::unique_ptr<CancellationManager> cancellation_manager_;
|
||||
std::unique_ptr<InstantiatedCapturedFunction> instantiated_captured_func_;
|
||||
// Must be ordered after `cancellation_manager_` so that `input_impl_` is
|
||||
// destroyed first.
|
||||
std::unique_ptr<IteratorBase> input_impl_;
|
||||
// Buffer for storing the invocation results.
|
||||
std::deque<std::shared_ptr<InvocationResult>> invocation_results_
|
||||
TF_GUARDED_BY(*mu_);
|
||||
|
||||
std::unique_ptr<Thread> runner_thread_ TF_GUARDED_BY(*mu_);
|
||||
std::unique_ptr<Thread> stats_thread_ TF_GUARDED_BY(*mu_);
|
||||
bool cancelled_ TF_GUARDED_BY(*mu_) = false;
|
||||
|
@ -1187,6 +1187,25 @@ REGISTER_OP("DataServiceDataset")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
// Adds `consumer_index` and `num_consumers` arguments to support round-robin
|
||||
// reads.
|
||||
REGISTER_OP("DataServiceDatasetV2")
|
||||
.Input("dataset_id: int64")
|
||||
.Input("processing_mode: string")
|
||||
.Input("address: string")
|
||||
.Input("protocol: string")
|
||||
.Input("job_name: string")
|
||||
.Input("consumer_index: int64")
|
||||
.Input("num_consumers: int64")
|
||||
.Input("max_outstanding_requests: int64")
|
||||
.Input("iteration_counter: resource")
|
||||
.Output("handle: variant")
|
||||
.Attr("task_refresh_interval_hint_ms: int = -1")
|
||||
.Attr("output_types: list(type) >= 1")
|
||||
.Attr("output_shapes: list(shape) >= 1")
|
||||
.SetIsStateful()
|
||||
.SetShapeFn(shape_inference::ScalarShape);
|
||||
|
||||
REGISTER_OP("RegisterDataset")
|
||||
.Input("dataset: variant")
|
||||
.Input("address: string")
|
||||
|
@ -283,6 +283,133 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
|
||||
results.append(elem.numpy())
|
||||
self.assertCountEqual(num_repetitions * list(range(num_elements)), results)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(
|
||||
test_base.default_test_combinations(),
|
||||
combinations.combine(num_workers=[1, 3], num_consumers=[1, 2, 5])))
|
||||
def testRoundRobin(self, num_workers, num_consumers):
|
||||
cluster = self.create_cluster(num_workers=num_workers)
|
||||
num_elements = 100
|
||||
ds = dataset_ops.Dataset.range(num_elements)
|
||||
ds = ds.repeat()
|
||||
consumers = []
|
||||
for consumer_index in range(num_consumers):
|
||||
consumers.append(
|
||||
self.make_distributed_dataset(
|
||||
ds,
|
||||
cluster,
|
||||
job_name="test",
|
||||
consumer_index=consumer_index,
|
||||
num_consumers=num_consumers))
|
||||
# Use parallel interleave to read from consumers in parallel.
|
||||
ds = dataset_ops.Dataset.from_tensor_slices(consumers)
|
||||
ds = ds.interleave(
|
||||
lambda x: x,
|
||||
cycle_length=num_consumers,
|
||||
num_parallel_calls=num_consumers)
|
||||
ds = ds.take(2 * num_elements * num_workers)
|
||||
results = self.getDatasetOutput(ds, requires_initialization=True)
|
||||
|
||||
expected = []
|
||||
round_index = 0
|
||||
while len(expected) < len(results):
|
||||
for _ in range(num_workers):
|
||||
for consumer in range(num_consumers):
|
||||
expected.append(
|
||||
(round_index * num_consumers + consumer) % num_elements)
|
||||
round_index += 1
|
||||
self.assertEqual(results, expected)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testRoundRobinBucketizing(self):
|
||||
# Tests a common use case for round robin reads. At each step, all
|
||||
# consumers should get batches with the same bucket size.
|
||||
cluster = self.create_cluster(num_workers=4)
|
||||
num_elements = 100
|
||||
ds = dataset_ops.Dataset.range(num_elements, output_type=dtypes.int32)
|
||||
ds = ds.shuffle(num_elements)
|
||||
low_bucket_max = 30
|
||||
mid_bucket_max = 60
|
||||
bucket_boundaries = [low_bucket_max, mid_bucket_max]
|
||||
batch_size = 10
|
||||
num_consumers = 3
|
||||
bucket_batch_sizes = [batch_size] * (len(bucket_boundaries) + 1)
|
||||
ds = ds.apply(
|
||||
grouping.bucket_by_sequence_length(
|
||||
lambda x: x,
|
||||
bucket_boundaries,
|
||||
bucket_batch_sizes,
|
||||
drop_remainder=True))
|
||||
ds = ds.apply(
|
||||
grouping.group_by_window(
|
||||
lambda x: math_ops.cast(x[1], dtypes.int64),
|
||||
lambda _, x: dataset_ops.Dataset.from_tensors(x),
|
||||
window_size=num_consumers))
|
||||
ds = ds.flat_map(lambda x: x)
|
||||
ds = ds.repeat()
|
||||
|
||||
consumers = []
|
||||
for consumer_index in range(num_consumers):
|
||||
consumers.append(
|
||||
self.make_distributed_dataset(
|
||||
ds,
|
||||
cluster,
|
||||
job_name="test",
|
||||
consumer_index=consumer_index,
|
||||
num_consumers=num_consumers))
|
||||
# Use parallel interleave to read from consumers in parallel.
|
||||
ds = dataset_ops.Dataset.from_tensor_slices(consumers)
|
||||
ds = ds.interleave(
|
||||
lambda x: x.prefetch(num_elements),
|
||||
cycle_length=num_consumers,
|
||||
num_parallel_calls=num_consumers)
|
||||
|
||||
num_rounds = 10
|
||||
get_next = self.getNext(ds, requires_initialization=True)
|
||||
results = []
|
||||
for _ in range(num_rounds):
|
||||
results.append(self.evaluate(get_next()))
|
||||
|
||||
def get_bucket(elem):
|
||||
bucket_ind = 0
|
||||
while bucket_ind < len(
|
||||
bucket_boundaries) and elem >= bucket_boundaries[bucket_ind]:
|
||||
bucket_ind += 1
|
||||
return bucket_ind
|
||||
|
||||
for i in range(0, len(results), num_consumers):
|
||||
batches = results[num_consumers * i:num_consumers * i + num_consumers]
|
||||
bucket_inds = [get_bucket(batch[0]) for batch in batches]
|
||||
for bucket_ind in bucket_inds[1:]:
|
||||
self.assertEqual(bucket_inds[0], bucket_ind)
|
||||
|
||||
@combinations.generate(test_base.v1_only_combinations())
|
||||
def testRoundRobinFiniteV1(self):
|
||||
cluster = self.create_cluster(num_workers=1)
|
||||
num_elements = 100
|
||||
ds = dataset_ops.Dataset.range(num_elements)
|
||||
ds = self.make_distributed_dataset(
|
||||
ds, cluster, job_name="test", consumer_index=0, num_consumers=1)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
errors.FailedPreconditionError, "Encountered end of sequence on a "
|
||||
"round-robin read iterator"):
|
||||
self.getDatasetOutput(ds, requires_initialization=True)
|
||||
|
||||
@combinations.generate(test_base.v2_only_combinations())
|
||||
def testRoundRobinFiniteV2(self):
|
||||
cluster = self.create_cluster(num_workers=1)
|
||||
num_elements = 100
|
||||
ds = dataset_ops.Dataset.range(num_elements)
|
||||
ds = self.make_distributed_dataset(
|
||||
ds, cluster, job_name="test", consumer_index=0, num_consumers=1)
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
errors.FailedPreconditionError, "Round robin reads "
|
||||
"require that the input dataset has infinite "
|
||||
"cardinality, but the dataset has cardinality " + str(num_elements)):
|
||||
self.getDatasetOutput(ds, requires_initialization=True)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(test_base.eager_only_combinations(),
|
||||
combinations.combine(job_name=[None, "test"])))
|
||||
|
@ -205,6 +205,8 @@ class TestBase(test_base.DatasetTestBase):
|
||||
cluster,
|
||||
processing_mode="parallel_epochs",
|
||||
job_name=None,
|
||||
consumer_index=None,
|
||||
num_consumers=None,
|
||||
max_outstanding_requests=None):
|
||||
# pylint: disable=protected-access
|
||||
return dataset.apply(
|
||||
@ -212,6 +214,8 @@ class TestBase(test_base.DatasetTestBase):
|
||||
processing_mode,
|
||||
cluster.target,
|
||||
job_name=job_name,
|
||||
consumer_index=consumer_index,
|
||||
num_consumers=num_consumers,
|
||||
max_outstanding_requests=max_outstanding_requests,
|
||||
task_refresh_interval_hint_ms=20))
|
||||
|
||||
|
@ -60,6 +60,8 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource):
|
||||
address,
|
||||
protocol,
|
||||
job_name=None,
|
||||
consumer_index=None,
|
||||
num_consumers=None,
|
||||
max_outstanding_requests=None,
|
||||
task_refresh_interval_hint_ms=None):
|
||||
"""Constructs a _DataServiceDatasetV2.
|
||||
@ -77,6 +79,17 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource):
|
||||
job_name: (Optional.) The name of the job. This argument makes it possible
|
||||
for multiple datasets to share the same job. The default behavior is
|
||||
that the dataset creates anonymous, exclusively owned jobs.
|
||||
consumer_index: (Optional.) The index of the consumer in the range from
|
||||
`0` to `num_consumers`. Must be specified alongside `num_consumers`.
|
||||
When specified, consumers will read from the job in a strict round-robin
|
||||
order, instead of the default first-come-first-served order.
|
||||
num_consumers: (Optional.) The number of consumers which will consume from
|
||||
the job. Must be specified alongside `consumer_index`. When specified,
|
||||
consumers will read from the job in a strict round-robin order, instead
|
||||
of the default first-come-first-served order. When `num_consumers` is
|
||||
specified, the dataset must have infinite cardinality to prevent a
|
||||
producer from running out of data early and causing consumers to go out
|
||||
of sync.
|
||||
max_outstanding_requests: (Optional.) A limit on how many elements may be
|
||||
requested at the same time. You can use this option to control the
|
||||
amount of memory used, since `distribute` won't use more than
|
||||
@ -84,6 +97,13 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource):
|
||||
task_refresh_interval_hint_ms: (Optional.) A hint for how often to query
|
||||
the dispatcher for task changes.
|
||||
"""
|
||||
if consumer_index is None != num_consumers is None:
|
||||
raise ValueError(
|
||||
"Must either set both consumer_index and num_consumers, or neither. ",
|
||||
"consumer_index: ", consumer_index, ", num_consumers: ",
|
||||
num_consumers)
|
||||
if num_consumers is not None and job_name is None:
|
||||
raise ValueError("job_name must be set when setting num_consumers")
|
||||
|
||||
if job_name is None:
|
||||
job_name = ""
|
||||
@ -91,6 +111,10 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource):
|
||||
max_outstanding_requests = dataset_ops.AUTOTUNE
|
||||
if task_refresh_interval_hint_ms is None:
|
||||
task_refresh_interval_hint_ms = dataset_ops.AUTOTUNE
|
||||
if consumer_index is None:
|
||||
consumer_index = -1
|
||||
if num_consumers is None:
|
||||
num_consumers = -1
|
||||
|
||||
self._dataset_id = ops.convert_to_tensor(
|
||||
dataset_id, dtype=dtypes.int64, name="dataset_id")
|
||||
@ -102,6 +126,10 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource):
|
||||
protocol, dtype=dtypes.string, name="protocol")
|
||||
self._job_name = ops.convert_to_tensor(
|
||||
job_name, dtype=dtypes.string, name="job_name")
|
||||
self._consumer_index = ops.convert_to_tensor(
|
||||
consumer_index, dtype=dtypes.int64, name="consumer_index")
|
||||
self._num_consumers = ops.convert_to_tensor(
|
||||
num_consumers, dtype=dtypes.int64, name="num_consumers")
|
||||
self._max_outstanding_requests = ops.convert_to_tensor(
|
||||
max_outstanding_requests,
|
||||
dtype=dtypes.int64,
|
||||
@ -110,17 +138,32 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource):
|
||||
# represented by scalar DT_VARIANTs.
|
||||
self._element_spec = tensor_spec.TensorSpec(shape=(), dtype=dtypes.variant)
|
||||
|
||||
variant_tensor = gen_experimental_dataset_ops.data_service_dataset(
|
||||
dataset_id=self._dataset_id,
|
||||
processing_mode=self._processing_mode,
|
||||
address=self._address,
|
||||
protocol=self._protocol,
|
||||
job_name=self._job_name,
|
||||
max_outstanding_requests=self._max_outstanding_requests,
|
||||
task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
|
||||
iteration_counter=gen_experimental_dataset_ops.dummy_iteration_counter(
|
||||
),
|
||||
**self._flat_structure)
|
||||
if num_consumers >= 0:
|
||||
variant_tensor = gen_experimental_dataset_ops.data_service_dataset_v2(
|
||||
dataset_id=self._dataset_id,
|
||||
processing_mode=self._processing_mode,
|
||||
address=self._address,
|
||||
protocol=self._protocol,
|
||||
job_name=self._job_name,
|
||||
consumer_index=self._consumer_index,
|
||||
num_consumers=self._num_consumers,
|
||||
max_outstanding_requests=self._max_outstanding_requests,
|
||||
task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
|
||||
iteration_counter=gen_experimental_dataset_ops
|
||||
.dummy_iteration_counter(),
|
||||
**self._flat_structure)
|
||||
else:
|
||||
variant_tensor = gen_experimental_dataset_ops.data_service_dataset(
|
||||
dataset_id=self._dataset_id,
|
||||
processing_mode=self._processing_mode,
|
||||
address=self._address,
|
||||
protocol=self._protocol,
|
||||
job_name=self._job_name,
|
||||
max_outstanding_requests=self._max_outstanding_requests,
|
||||
task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
|
||||
iteration_counter=gen_experimental_dataset_ops
|
||||
.dummy_iteration_counter(),
|
||||
**self._flat_structure)
|
||||
super(_DataServiceDatasetV2, self).__init__(variant_tensor)
|
||||
|
||||
@property
|
||||
@ -133,7 +176,8 @@ class _DataServiceDatasetV1(dataset_ops.DatasetV1Adapter):
|
||||
|
||||
@functools.wraps(_DataServiceDatasetV2.__init__)
|
||||
def __init__(self, dataset_id, processing_mode, address, protocol, job_name,
|
||||
max_outstanding_requests, task_refresh_interval_hint_ms):
|
||||
consumer_index, num_consumers, max_outstanding_requests,
|
||||
task_refresh_interval_hint_ms):
|
||||
|
||||
self._wrapped = _DataServiceDatasetV2(
|
||||
dataset_id=dataset_id,
|
||||
@ -141,6 +185,8 @@ class _DataServiceDatasetV1(dataset_ops.DatasetV1Adapter):
|
||||
address=address,
|
||||
protocol=protocol,
|
||||
job_name=job_name,
|
||||
consumer_index=consumer_index,
|
||||
num_consumers=num_consumers,
|
||||
max_outstanding_requests=max_outstanding_requests,
|
||||
task_refresh_interval_hint_ms=task_refresh_interval_hint_ms)
|
||||
super(_DataServiceDatasetV1, self).__init__(self._wrapped)
|
||||
@ -184,6 +230,8 @@ def _from_dataset_id(processing_mode,
|
||||
dataset_id,
|
||||
element_spec,
|
||||
job_name=None,
|
||||
consumer_index=None,
|
||||
num_consumers=None,
|
||||
max_outstanding_requests=None,
|
||||
task_refresh_interval_hint_ms=None):
|
||||
"""Creates a dataset which reads data from the tf.data service.
|
||||
@ -209,6 +257,17 @@ def _from_dataset_id(processing_mode,
|
||||
job_name: (Optional.) The name of the job. This argument makes it possible
|
||||
for multiple datasets to share the same job. The default behavior is that
|
||||
the dataset creates anonymous, exclusively owned jobs.
|
||||
consumer_index: (Optional.) The index of the consumer in the range from
|
||||
`0` to `num_consumers`. Must be specified alongside `num_consumers`.
|
||||
When specified, consumers will read from the job in a strict round-robin
|
||||
order, instead of the default first-come-first-served order.
|
||||
num_consumers: (Optional.) The number of consumers which will consume from
|
||||
the job. Must be specified alongside `consumer_index`. When specified,
|
||||
consumers will read from the job in a strict round-robin order, instead
|
||||
of the default first-come-first-served order. When `num_consumers` is
|
||||
specified, the dataset must have infinite cardinality to prevent a
|
||||
producer from running out of data early and causing consumers to go out of
|
||||
sync.
|
||||
max_outstanding_requests: (Optional.) A limit on how many elements may be
|
||||
requested at the same time. You can use this option to control the amount
|
||||
of memory used, since `distribute` won't use more than `element_size` *
|
||||
@ -236,6 +295,8 @@ def _from_dataset_id(processing_mode,
|
||||
address=address,
|
||||
protocol=protocol,
|
||||
job_name=job_name,
|
||||
consumer_index=consumer_index,
|
||||
num_consumers=num_consumers,
|
||||
max_outstanding_requests=max_outstanding_requests,
|
||||
task_refresh_interval_hint_ms=task_refresh_interval_hint_ms)
|
||||
dataset = dataset.map(
|
||||
@ -253,6 +314,8 @@ def _from_dataset_id(processing_mode,
|
||||
def _distribute(processing_mode,
|
||||
service,
|
||||
job_name=None,
|
||||
consumer_index=None,
|
||||
num_consumers=None,
|
||||
max_outstanding_requests=None,
|
||||
task_refresh_interval_hint_ms=None):
|
||||
"""A transformation that moves dataset processing to the tf.data service.
|
||||
@ -272,6 +335,17 @@ def _distribute(processing_mode,
|
||||
job_name: (Optional.) The name of the job. This argument makes it possible
|
||||
for multiple datasets to share the same job. The default behavior is that
|
||||
the dataset creates anonymous, exclusively owned jobs.
|
||||
consumer_index: (Optional.) The index of the consumer in the range from
|
||||
`0` to `num_consumers`. Must be specified alongside `num_consumers`.
|
||||
When specified, consumers will read from the job in a strict round-robin
|
||||
order, instead of the default first-come-first-served order.
|
||||
num_consumers: (Optional.) The number of consumers which will consume from
|
||||
the job. Must be specified alongside `consumer_index`. When specified,
|
||||
consumers will read from the job in a strict round-robin order, instead
|
||||
of the default first-come-first-served order. When `num_consumers` is
|
||||
specified, the dataset must have infinite cardinality to prevent a
|
||||
producer from running out of data early and causing consumers to go out of
|
||||
sync.
|
||||
max_outstanding_requests: (Optional.) A limit on how many elements may be
|
||||
requested at the same time. You can use this option to control the amount
|
||||
of memory used, since `distribute` won't use more than `element_size` *
|
||||
@ -292,6 +366,8 @@ def _distribute(processing_mode,
|
||||
dataset_id,
|
||||
dataset.element_spec,
|
||||
job_name=job_name,
|
||||
consumer_index=consumer_index,
|
||||
num_consumers=num_consumers,
|
||||
max_outstanding_requests=max_outstanding_requests,
|
||||
task_refresh_interval_hint_ms=task_refresh_interval_hint_ms)
|
||||
|
||||
@ -302,6 +378,8 @@ def _distribute(processing_mode,
|
||||
def distribute(processing_mode,
|
||||
service,
|
||||
job_name=None,
|
||||
consumer_index=None,
|
||||
num_consumers=None,
|
||||
max_outstanding_requests=None):
|
||||
"""A transformation that moves dataset processing to the tf.data service.
|
||||
|
||||
@ -440,6 +518,25 @@ def distribute(processing_mode,
|
||||
from `job_name="job"`, it will immediately receive end of input, without
|
||||
getting any data.
|
||||
|
||||
**Round Robin data consumption**
|
||||
|
||||
By default, when multiple consumers read from the same job, they receive data
|
||||
on a first-come first-served basis. In some use cases, it works better to use
|
||||
a strict round-robin order. For example, the tf.data service can be used to
|
||||
coordinate example sizes across a cluster during sychronous training, so that
|
||||
during each step all replicas train on similar-sized elements. To achieve
|
||||
this, define a dataset which generates rounds of `num_consumers` consecutive
|
||||
similar-sized batches, then enable round-robin reads by setting
|
||||
`consumer_index` and `num_consumers`.
|
||||
|
||||
Consumers read data by cycling through all workers, reading one element from
|
||||
each. First, each consumer will read an element from the first worker, then
|
||||
each consumer will read an element from the second worker, and so on.
|
||||
|
||||
NOTE: To keep consumers in sync, round robin data consumption requires that
|
||||
the dataset have infinite cardinality. You can get this by adding `.repeat()`
|
||||
at the end of the dataset definition.
|
||||
|
||||
**Keras and Distribution Strategies**
|
||||
|
||||
The dataset produced by the `distribute` transformation can be passed to
|
||||
@ -468,6 +565,17 @@ def distribute(processing_mode,
|
||||
job_name: (Optional.) The name of the job. This argument makes it possible
|
||||
for multiple datasets to share the same job. The default behavior is that
|
||||
the dataset creates anonymous, exclusively owned jobs.
|
||||
consumer_index: (Optional.) The index of the consumer in the range from
|
||||
`0` to `num_consumers`. Must be specified alongside `num_consumers`.
|
||||
When specified, consumers will read from the job in a strict round-robin
|
||||
order, instead of the default first-come-first-served order.
|
||||
num_consumers: (Optional.) The number of consumers which will consume from
|
||||
the job. Must be specified alongside `consumer_index`. When specified,
|
||||
consumers will read from the job in a strict round-robin order, instead
|
||||
of the default first-come-first-served order. When `num_consumers` is
|
||||
specified, the dataset must have infinite cardinality to prevent a
|
||||
producer from running out of data early and causing consumers to go out of
|
||||
sync.
|
||||
max_outstanding_requests: (Optional.) A limit on how many elements may be
|
||||
requested at the same time. You can use this option to control the amount
|
||||
of memory used, since `distribute` won't use more than `element_size` *
|
||||
@ -480,6 +588,8 @@ def distribute(processing_mode,
|
||||
processing_mode=processing_mode,
|
||||
service=service,
|
||||
job_name=job_name,
|
||||
consumer_index=consumer_index,
|
||||
num_consumers=num_consumers,
|
||||
max_outstanding_requests=max_outstanding_requests)
|
||||
|
||||
|
||||
@ -553,6 +663,8 @@ def from_dataset_id(processing_mode,
|
||||
dataset_id,
|
||||
element_spec=None,
|
||||
job_name=None,
|
||||
consumer_index=None,
|
||||
num_consumers=None,
|
||||
max_outstanding_requests=None):
|
||||
"""Creates a dataset which reads data from the tf.data service.
|
||||
|
||||
@ -612,6 +724,17 @@ def from_dataset_id(processing_mode,
|
||||
job_name: (Optional.) The name of the job. This argument makes it possible
|
||||
for multiple datasets to share the same job. The default behavior is that
|
||||
the dataset creates anonymous, exclusively owned jobs.
|
||||
consumer_index: (Optional.) The index of the consumer in the range from
|
||||
`0` to `num_consumers`. Must be specified alongside `num_consumers`.
|
||||
When specified, consumers will read from the job in a strict round-robin
|
||||
order, instead of the default first-come-first-served order.
|
||||
num_consumers: (Optional.) The number of consumers which will consume from
|
||||
the job. Must be specified alongside `consumer_index`. When specified,
|
||||
consumers will read from the job in a strict round-robin order, instead
|
||||
of the default first-come-first-served order. When `num_consumers` is
|
||||
specified, the dataset must have infinite cardinality to prevent a
|
||||
producer from running out of data early and causing consumers to go out of
|
||||
sync.
|
||||
max_outstanding_requests: (Optional.) A limit on how many elements may be
|
||||
requested at the same time. You can use this option to control the amount
|
||||
of memory used, since `distribute` won't use more than `element_size` *
|
||||
@ -626,4 +749,6 @@ def from_dataset_id(processing_mode,
|
||||
dataset_id=dataset_id,
|
||||
element_spec=element_spec,
|
||||
job_name=job_name,
|
||||
consumer_index=consumer_index,
|
||||
num_consumers=num_consumers,
|
||||
max_outstanding_requests=max_outstanding_requests)
|
||||
|
@ -10,11 +10,11 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "distribute"
|
||||
argspec: "args=[\'processing_mode\', \'service\', \'job_name\', \'max_outstanding_requests\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
argspec: "args=[\'processing_mode\', \'service\', \'job_name\', \'consumer_index\', \'num_consumers\', \'max_outstanding_requests\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "from_dataset_id"
|
||||
argspec: "args=[\'processing_mode\', \'service\', \'dataset_id\', \'element_spec\', \'job_name\', \'max_outstanding_requests\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'processing_mode\', \'service\', \'dataset_id\', \'element_spec\', \'job_name\', \'consumer_index\', \'num_consumers\', \'max_outstanding_requests\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "register_dataset"
|
||||
|
@ -1008,6 +1008,10 @@ tf_module {
|
||||
name: "DataServiceDataset"
|
||||
argspec: "args=[\'dataset_id\', \'processing_mode\', \'address\', \'protocol\', \'job_name\', \'max_outstanding_requests\', \'iteration_counter\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "DataServiceDatasetV2"
|
||||
argspec: "args=[\'dataset_id\', \'processing_mode\', \'address\', \'protocol\', \'job_name\', \'consumer_index\', \'num_consumers\', \'max_outstanding_requests\', \'iteration_counter\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "DatasetCardinality"
|
||||
argspec: "args=[\'input_dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
@ -18,11 +18,11 @@ tf_module {
|
||||
}
|
||||
member_method {
|
||||
name: "distribute"
|
||||
argspec: "args=[\'processing_mode\', \'service\', \'job_name\', \'max_outstanding_requests\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||
argspec: "args=[\'processing_mode\', \'service\', \'job_name\', \'consumer_index\', \'num_consumers\', \'max_outstanding_requests\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "from_dataset_id"
|
||||
argspec: "args=[\'processing_mode\', \'service\', \'dataset_id\', \'element_spec\', \'job_name\', \'max_outstanding_requests\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
|
||||
argspec: "args=[\'processing_mode\', \'service\', \'dataset_id\', \'element_spec\', \'job_name\', \'consumer_index\', \'num_consumers\', \'max_outstanding_requests\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "register_dataset"
|
||||
|
@ -1008,6 +1008,10 @@ tf_module {
|
||||
name: "DataServiceDataset"
|
||||
argspec: "args=[\'dataset_id\', \'processing_mode\', \'address\', \'protocol\', \'job_name\', \'max_outstanding_requests\', \'iteration_counter\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "DataServiceDatasetV2"
|
||||
argspec: "args=[\'dataset_id\', \'processing_mode\', \'address\', \'protocol\', \'job_name\', \'consumer_index\', \'num_consumers\', \'max_outstanding_requests\', \'iteration_counter\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
|
||||
}
|
||||
member_method {
|
||||
name: "DatasetCardinality"
|
||||
argspec: "args=[\'input_dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
|
||||
|
Loading…
Reference in New Issue
Block a user