[tf.data service] Add support for shared job names.
Shared job names give users a way to share tf.data service job output across multiple datasets.

PiperOrigin-RevId: 310037762
Change-Id: I5a8f9130806a361ada46b3f0b62ea71b0f9b2e8e

parent 48baba71cf
commit 22546b562d
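Shared job names are exposed through the new `job_name` argument on `tf.data.experimental.service.distribute` (see the Python changes at the end of this diff). A minimal usage sketch, assuming a tf.data service is already running at the placeholder address `grpc://dataservice:5000`; the name `"shared_job"` is an arbitrary illustration:

```
import tensorflow as tf

# Datasets that pass the same job_name consume from one shared tf.data
# service job instead of each creating its own anonymous job.
base = tf.data.Dataset.range(5)
dataset1 = base.apply(
    tf.data.experimental.service.distribute(
        processing_mode="parallel_epochs",
        service="grpc://dataservice:5000",
        job_name="shared_job"))
dataset2 = base.apply(
    tf.data.experimental.service.distribute(
        processing_mode="parallel_epochs",
        service="grpc://dataservice:5000",
        job_name="shared_job"))
# Iterators over dataset1 and dataset2 then read from the same job rather
# than from two independent jobs.
```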
@@ -0,0 +1,4 @@
+op {
+  graph_op_name: "DummyIterationCounter"
+  visibility: HIDDEN
+}
@@ -56,6 +56,7 @@ cc_library(
     deps = [
         ":common_proto_cc",
         ":credentials_factory",
+        ":data_service",
         ":grpc_util",
         ":master_proto_cc",
         ":worker_cc_grpc_proto",
@@ -31,7 +31,7 @@ constexpr const char kParallelEpochs[] = "parallel_epochs";
 constexpr const char kOneEpoch[] = "one_epoch";
 }  // namespace
 
-Status ParseProcessingMode(absl::string_view s, ProcessingMode* mode) {
+Status ParseProcessingMode(const std::string& s, ProcessingMode* mode) {
   if (s == kParallelEpochs) {
     *mode = ProcessingMode::PARALLEL_EPOCHS;
   } else if (s == kOneEpoch) {
@@ -54,6 +54,21 @@ std::string ProcessingModeToString(ProcessingMode mode) {
   }
 }
 
+Status DataServiceMasterClient::RegisterDataset(GraphDef dataset,
+                                                int64* dataset_id) {
+  TF_RETURN_IF_ERROR(EnsureInitialized());
+  GetOrRegisterDatasetRequest req;
+  *req.mutable_dataset()->mutable_graph() = dataset;
+  GetOrRegisterDatasetResponse resp;
+  grpc::ClientContext client_ctx;
+  grpc::Status status = stub_->GetOrRegisterDataset(&client_ctx, req, &resp);
+  if (!status.ok()) {
+    return grpc_util::WrapError("Failed to register dataset", status);
+  }
+  *dataset_id = resp.dataset_id();
+  return Status::OK();
+}
+
 Status DataServiceMasterClient::CreateJob(int64 dataset_id,
                                           ProcessingMode processing_mode,
                                           int64* job_id) {
@@ -73,18 +88,27 @@ Status DataServiceMasterClient::CreateJob(int64 dataset_id,
   return Status::OK();
 }
 
-Status DataServiceMasterClient::RegisterDataset(GraphDef dataset,
-                                                int64* dataset_id) {
+Status DataServiceMasterClient::GetOrCreateJob(int64 dataset_id,
+                                               ProcessingMode processing_mode,
+                                               const std::string& job_name,
+                                               int job_name_index,
+                                               int64* job_id) {
   TF_RETURN_IF_ERROR(EnsureInitialized());
-  GetOrRegisterDatasetRequest req;
-  *req.mutable_dataset()->mutable_graph() = dataset;
-  GetOrRegisterDatasetResponse resp;
+  GetOrCreateJobRequest req;
+  req.set_dataset_id(dataset_id);
+  req.set_processing_mode(ProcessingModeDef(processing_mode));
+  req.set_job_name(job_name);
+  req.set_job_name_index(job_name_index);
+  GetOrCreateJobResponse resp;
   grpc::ClientContext client_ctx;
-  grpc::Status status = stub_->GetOrRegisterDataset(&client_ctx, req, &resp);
+  grpc::Status status = stub_->GetOrCreateJob(&client_ctx, req, &resp);
   if (!status.ok()) {
-    return grpc_util::WrapError("Failed to register dataset", status);
+    return grpc_util::WrapError(
+        absl::StrCat("Failed to get or create job for dataset with id ",
+                     dataset_id),
+        status);
   }
-  *dataset_id = resp.dataset_id();
+  *job_id = resp.job_id();
   return Status::OK();
 }
 
@@ -148,7 +172,7 @@ Status DataServiceWorkerClient::EnsureInitialized() {
 }
 
 Status CreateDataServiceMasterClient(
-    absl::string_view address, absl::string_view protocol,
+    const std::string& address, const std::string& protocol,
     std::unique_ptr<DataServiceMasterClient>* out) {
   auto client = absl::make_unique<DataServiceMasterClient>(address, protocol);
   TF_RETURN_IF_ERROR(client->Initialize());
@@ -157,7 +181,7 @@ Status CreateDataServiceMasterClient(
 }
 
 Status CreateDataServiceWorkerClient(
-    absl::string_view address, absl::string_view protocol,
+    const std::string& address, const std::string& protocol,
     std::unique_ptr<DataServiceWorkerClient>* out) {
   auto client = absl::make_unique<DataServiceWorkerClient>(address, protocol);
   TF_RETURN_IF_ERROR(client->Initialize());
@@ -35,7 +35,7 @@ enum class ProcessingMode : int64 {
 
 // Parses a string representing a processing mode and stores the result in
 // *mode. Returns an InvalidArgument status if the string is not recognized.
-Status ParseProcessingMode(absl::string_view s, ProcessingMode* mode);
+Status ParseProcessingMode(const std::string& s, ProcessingMode* mode);
 
 // Converts a processing mode to its corresponding string.
 std::string ProcessingModeToString(ProcessingMode mode);
@@ -45,7 +45,7 @@ std::string ProcessingModeToString(ProcessingMode mode);
 // threads.
 class DataServiceClientBase {
  public:
-  DataServiceClientBase(absl::string_view address, absl::string_view protocol)
+  DataServiceClientBase(const std::string& address, const std::string& protocol)
       : address_(address), protocol_(protocol) {}
 
   virtual ~DataServiceClientBase() = default;
@@ -70,7 +70,8 @@ class DataServiceClientBase {
 // Client for communicating with the tf.data service master.
 class DataServiceMasterClient : public DataServiceClientBase {
  public:
-  DataServiceMasterClient(absl::string_view address, absl::string_view protocol)
+  DataServiceMasterClient(const std::string& address,
+                          const std::string& protocol)
       : DataServiceClientBase(address, protocol) {}
 
   // Registers a dataset with the tf.data service, and stores the generated
@@ -82,6 +83,13 @@ class DataServiceMasterClient : public DataServiceClientBase {
   Status CreateJob(int64 dataset_id, ProcessingMode processing_mode,
                    int64* job_id);
 
+  // Gets the job id for the job represented by the tuple
+  // (job_name, job_name_index), and stores the id in *job_id. If the
+  // job doesn't exist yet, it will be created.
+  Status GetOrCreateJob(int64 dataset_id, ProcessingMode processing_mode,
+                        const std::string& job_name, int job_name_index,
+                        int64* job_id);
+
   // Queries the master for the tasks associated with the specified job.
   // The tasks will be stored in *tasks, and whether the job is finished will
   // be stored in `*job_finished`.
@@ -98,7 +106,8 @@ class DataServiceMasterClient : public DataServiceClientBase {
 // Client for communicating with the tf.data service worker.
 class DataServiceWorkerClient : public DataServiceClientBase {
  public:
-  DataServiceWorkerClient(absl::string_view address, absl::string_view protocol)
+  DataServiceWorkerClient(const std::string& address,
+                          const std::string& protocol)
      : DataServiceClientBase(address, protocol) {}
 
   // Fetches the next element for the specified task_id. The element's
@@ -116,12 +125,12 @@ class DataServiceWorkerClient : public DataServiceClientBase {
 
 // Creates and initializes a new tf.data service master client.
 Status CreateDataServiceMasterClient(
-    absl::string_view address, absl::string_view protocol,
+    const std::string& address, const std::string& protocol,
     std::unique_ptr<DataServiceMasterClient>* out);
 
 // Creates and initializes a new tf.data service worker client.
 Status CreateDataServiceWorkerClient(
-    absl::string_view address, absl::string_view protocol,
+    const std::string& address, const std::string& protocol,
    std::unique_ptr<DataServiceWorkerClient>* out);
 
 }  // namespace data
@@ -42,6 +42,7 @@ HANDLER(RegisterWorker);
 HANDLER(WorkerUpdate);
 HANDLER(GetOrRegisterDataset);
 HANDLER(CreateJob);
+HANDLER(GetOrCreateJob);
 HANDLER(GetTasks);
 #undef HANDLER
 
@@ -46,6 +46,7 @@ class GrpcMasterImpl : public MasterService::Service {
   HANDLER(WorkerUpdate);
   HANDLER(GetOrRegisterDataset);
   HANDLER(CreateJob);
+  HANDLER(GetOrCreateJob);
   HANDLER(GetTasks);
 #undef HANDLER
 
@@ -56,7 +56,24 @@ message CreateJobRequest {
 }
 
 message CreateJobResponse {
-  // An id for the begun job.
+  // An id for the created job.
+  int64 job_id = 1;
+}
+
+message GetOrCreateJobRequest {
+  // The id of the dataset to create a job for.
+  int64 dataset_id = 1;
+  // A mode controlling how the tf.data service produces data for the job.
+  ProcessingModeDef processing_mode = 2;
+  // A name for the job.
+  string job_name = 3;
+  // An index for the job. Multiple jobs can be created for the same name, if
+  // they have different indices.
+  int64 job_name_index = 4;
+}
+
+message GetOrCreateJobResponse {
+  // The id of the (potentially newly created) job.
   int64 job_id = 1;
 }
 
@@ -96,6 +113,9 @@ service MasterService {
   rpc GetOrRegisterDataset(GetOrRegisterDatasetRequest)
       returns (GetOrRegisterDatasetResponse);
 
+  // Gets a job if it already exists, otherwise creates it.
+  rpc GetOrCreateJob(GetOrCreateJobRequest) returns (GetOrCreateJobResponse);
+
   // Creates a job for reading from the tf.data service.
   rpc CreateJob(CreateJobRequest) returns (CreateJobResponse);
 
@@ -25,6 +25,7 @@ limitations under the License.
 #include "absl/memory/memory.h"
 #include "tensorflow/core/data/service/common.pb.h"
 #include "tensorflow/core/data/service/credentials_factory.h"
+#include "tensorflow/core/data/service/data_service.h"
 #include "tensorflow/core/data/service/grpc_util.h"
 #include "tensorflow/core/data/service/master.pb.h"
 #include "tensorflow/core/data/service/worker.grpc.pb.h"
@@ -65,17 +66,17 @@ Status DataServiceMasterImpl::RegisterWorker(
 
   // Allocate tasks to the worker.
   for (auto& entry : jobs_) {
-    Job& job = entry.second;
-    if (job.finished()) {
+    std::shared_ptr<Job> job = entry.second;
+    if (job->finished()) {
       continue;
     }
-    int64 task_id = CreateTask(&job, request->worker_address());
+    int64 task_id = CreateTask(job.get(), request->worker_address());
 
     TaskDef* task_def = response->add_tasks();
     *task_def->mutable_dataset() =
-        datasets_by_id_[job.dataset_id()]->dataset_def();
-    task_def->set_dataset_id(job.dataset_id());
-    task_def->set_job_id(job.job_id());
+        datasets_by_id_[job->dataset_id()]->dataset_def();
+    task_def->set_dataset_id(job->dataset_id());
+    task_def->set_job_id(job->job_id());
     task_def->set_task_id(task_id);
   }
 
@@ -96,7 +97,7 @@ Status DataServiceMasterImpl::WorkerUpdate(const WorkerUpdateRequest* request,
     if (update.completed()) {
       int64 job_id = tasks_.at(task_id).job_id();
       DCHECK(jobs_.contains(job_id));
-      jobs_.at(job_id).task_finished(task_id);
+      jobs_.at(job_id)->task_finished(task_id);
       VLOG(3) << "Task " << task_id << " from job " << job_id << " completed";
     }
   }
@@ -135,49 +136,109 @@ int64 DataServiceMasterImpl::RegisterDataset(uint64 fingerprint,
   DCHECK(!datasets_by_id_.contains(dataset_id));
   datasets_by_id_[dataset_id] = new_dataset;
   DCHECK(!datasets_by_fingerprint_.contains(fingerprint));
-  datasets_by_fingerprint_[dataset_id] = new_dataset;
+  datasets_by_fingerprint_[fingerprint] = new_dataset;
   return dataset_id;
 }
 
 Status DataServiceMasterImpl::CreateJob(const CreateJobRequest* request,
                                         CreateJobResponse* response) {
-  VLOG(3) << "Received begin job request for dataset id "
+  VLOG(3) << "Received create job request for dataset id "
          << request->dataset_id();
-  switch (request->processing_mode()) {
-    case PARALLEL_EPOCHS:
+  ProcessingMode processing_mode = ProcessingMode(request->processing_mode());
+  mutex_lock l(mu_);
+  int64 job_id;
+  TF_RETURN_IF_ERROR(CreateJob(request->dataset_id(), processing_mode,
+                               absl::optional<std::string>(), &job_id));
+  response->set_job_id(job_id);
+
+  VLOG(3) << "Creating job " << job_id << " for dataset "
+          << request->dataset_id();
+  return Status::OK();
+}
+
+Status DataServiceMasterImpl::GetOrCreateJob(
+    const GetOrCreateJobRequest* request, GetOrCreateJobResponse* response) {
+  VLOG(3) << "Received get or create job request for dataset id "
+          << request->dataset_id() << " with name " << request->job_name()
+          << " and index " << request->job_name_index();
+  mutex_lock l(mu_);
+  NamedJobKey key(request->job_name(), request->job_name_index());
+  ProcessingMode requested_processing_mode =
+      ProcessingMode(request->processing_mode());
+  std::shared_ptr<Job>* job = gtl::FindOrNull(named_jobs_, key);
+  if (job != nullptr) {
+    TF_RETURN_IF_ERROR(ValidateMatchingJob(**job, requested_processing_mode,
+                                           request->dataset_id()));
+    response->set_job_id((*job)->job_id());
+    return Status::OK();
+  }
+  int64 job_id;
+  TF_RETURN_IF_ERROR(CreateJob(request->dataset_id(), requested_processing_mode,
+                               request->job_name(), &job_id));
+  named_jobs_[key] = jobs_[job_id];
+  response->set_job_id(job_id);
+  return Status::OK();
+}
+
+// Validates that the job matches the given processing_mode and dataset_id.
+Status DataServiceMasterImpl::ValidateMatchingJob(
+    const Job& job, ProcessingMode processing_mode, int64 dataset_id) {
+  DCHECK(job.name().has_value());
+  std::string job_name = job.name().value();
+  if (job.processing_mode() != processing_mode) {
+    std::string requested = ProcessingModeToString(processing_mode);
+    std::string actual = ProcessingModeToString(job.processing_mode());
+    return errors::FailedPrecondition(
+        "Found a job with name ", job_name, ", but the processing mode <",
+        actual, "> doesn't match the requested processing mode <", requested,
+        ">.");
+  }
+  if (job.dataset_id() != dataset_id) {
+    return errors::FailedPrecondition(
+        "Found a job with name ", job_name, ", but the dataset id <",
+        job.dataset_id(), "> doesn't match the requested dataset id <",
+        dataset_id, ">.");
+  }
+  return Status::OK();
+}
+
+Status DataServiceMasterImpl::CreateJob(int64 dataset_id,
+                                        ProcessingMode processing_mode,
+                                        absl::optional<std::string> job_name,
+                                        int64* out_job_id)
+    EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+  switch (processing_mode) {
+    case ProcessingMode::PARALLEL_EPOCHS:
       break;
-    case ONE_EPOCH:
+    case ProcessingMode::ONE_EPOCH:
       return errors::Unimplemented(
           "CreateJob only supports the PARALLEL_EPOCHS job mode. "
           "ONE_EPOCH is not currently supported.");
     default:
-      return errors::Unimplemented(
-          "ProcessingMode ", request->processing_mode(), " not recognized");
+      return errors::Unimplemented("ProcessingMode ",
                                   ProcessingModeToString(processing_mode),
+                                   " not recognized");
   }
-  mutex_lock l(mu_);
-  if (!datasets_by_id_.contains(request->dataset_id())) {
-    return errors::NotFound("CreateJob failed. Dataset id: <",
-                            request->dataset_id(), "> not found.");
+  if (!datasets_by_id_.contains(dataset_id)) {
+    return errors::NotFound("Dataset id: <", dataset_id, "> not found.");
   }
 
   int64 job_id = next_job_id_++;
   DCHECK(!jobs_.contains(job_id));
-  auto result =
-      jobs_.emplace(std::piecewise_construct, std::forward_as_tuple(job_id),
-                    std::forward_as_tuple(job_id, request->dataset_id()));
-  DCHECK(result.second);
-  Job& job = result.first->second;
-  response->set_job_id(job_id);
+  auto job =
+      std::make_shared<Job>(job_id, dataset_id, processing_mode, job_name);
+  jobs_[job_id] = job;
 
   for (auto& worker : workers_) {
-    int64 task_id = CreateTask(&job, worker.address());
+    int64 task_id = CreateTask(job.get(), worker.address());
 
     // TODO(aaudibert): perform these calls asynchronously.
+    // TODO(aaudibert): clean up in case some calls succeed, but later calls
+    // fail
     TF_RETURN_IF_ERROR(AllocateTaskToWorker(tasks_.at(task_id), &worker));
   }
 
-  VLOG(3) << "Beginning job " << job_id << " for dataset "
-          << request->dataset_id();
+  *out_job_id = job_id;
   return Status::OK();
 }
 
@@ -233,8 +294,8 @@ Status DataServiceMasterImpl::GetTasks(const GetTasksRequest* request,
     return errors::NotFound("GetTasks failed. Job id <", request->job_id(),
                             "> not found.");
   }
-  Job& job = it->second;
-  for (const auto& task_id : job.task_ids()) {
+  std::shared_ptr<Job> job = it->second;
+  for (const auto& task_id : job->task_ids()) {
     auto task_iter = tasks_.find(task_id);
     DCHECK(task_iter != tasks_.end());
     Task& task = task_iter->second;
@@ -242,7 +303,7 @@ Status DataServiceMasterImpl::GetTasks(const GetTasksRequest* request,
     task_info->set_worker_address(task.worker_address());
     task_info->set_id(task.task_id());
   }
-  response->set_job_finished(job.finished());
+  response->set_job_finished(job->finished());
   VLOG(3) << "Found " << response->task_info_size() << " tasks for job id "
           << request->job_id();
   return Status::OK();
@@ -18,6 +18,7 @@ limitations under the License.
 
 #include "absl/container/flat_hash_map.h"
 #include "tensorflow/core/data/service/common.pb.h"
+#include "tensorflow/core/data/service/data_service.h"
 #include "tensorflow/core/data/service/master.pb.h"
 #include "tensorflow/core/data/service/worker.grpc.pb.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -56,6 +57,8 @@ class DataServiceMasterImpl {
                               GetOrRegisterDatasetResponse* response);
   Status CreateJob(const CreateJobRequest* request,
                    CreateJobResponse* response);
+  Status GetOrCreateJob(const GetOrCreateJobRequest* request,
+                        GetOrCreateJobResponse* response);
   Status GetTasks(const GetTasksRequest* request, GetTasksResponse* response);
 
  private:
@@ -100,11 +103,17 @@ class DataServiceMasterImpl {
 
   class Job {
    public:
-    Job(int64 job_id, int64 dataset_id)
-        : job_id_(job_id), dataset_id_(dataset_id) {}
+    Job(int64 job_id, int64 dataset_id, ProcessingMode processing_mode,
+        absl::optional<absl::string_view> job_name)
+        : job_id_(job_id),
+          dataset_id_(dataset_id),
+          processing_mode_(processing_mode),
+          job_name_(job_name) {}
 
     int64 job_id() const { return job_id_; }
     int64 dataset_id() const { return dataset_id_; }
+    ProcessingMode processing_mode() const { return processing_mode_; }
+    absl::optional<std::string> name() const { return job_name_; }
     const std::vector<int64>& task_ids() const { return task_ids_; }
     void add_task_id(int64 task_id) { task_ids_.push_back(task_id); }
     void task_finished(int64 task_id) {
@@ -118,11 +127,32 @@ class DataServiceMasterImpl {
    private:
     const int64 job_id_;
     const int64 dataset_id_;
+    const ProcessingMode processing_mode_;
+    const absl::optional<std::string> job_name_;
    std::vector<int64> task_ids_;
    std::vector<int64> finished_tasks_;
    bool finished_ = false;
  };
 
+  class NamedJobKey {
+   public:
+    NamedJobKey(absl::string_view name, int64 index)
+        : name_(name), index_(index) {}
+
+    friend bool operator==(const NamedJobKey& lhs, const NamedJobKey& rhs) {
+      return lhs.name_ == rhs.name_ && lhs.index_ == rhs.index_;
+    }
+
+    template <typename H>
+    friend H AbslHashValue(H h, const NamedJobKey& k) {
+      return H::combine(std::move(h), k.name_, k.index_);
+    }
+
+   private:
+    const std::string name_;
+    const int64 index_;
+  };
+
   class Task {
    public:
     Task(int64 task_id, int64 job_id, int64 dataset_id,
@@ -150,9 +180,15 @@ class DataServiceMasterImpl {
   Status EnsureWorkerStubInitialized(Worker* worker);
   // Instructs a worker to begin processing a task.
   Status AllocateTaskToWorker(const Task& task_id, Worker* worker);
+  // Creates a job and stores its job_id in `*job_id`.
+  Status CreateJob(int64 dataset_id, ProcessingMode processing_mode,
+                   absl::optional<std::string> job_name, int64* out_job_id);
   // Creates a new task for a job, returning the new task's id.
   int64 CreateTask(Job* job, const std::string& worker_address);
+  // Validates that an existing job matches the given processing_mode and
+  // dataset_id, returning an error status describing any difference.
+  Status ValidateMatchingJob(const Job& job, ProcessingMode processing_mode,
+                             int64 dataset_id);
   // Protocol to use for communicating with workers.
   const std::string protocol_;
 
@@ -172,9 +208,13 @@ class DataServiceMasterImpl {
   absl::flat_hash_map<uint64, std::shared_ptr<Dataset>> datasets_by_fingerprint_
       TF_GUARDED_BY(mu_);
   // Information about jobs, keyed by job ids.
-  absl::flat_hash_map<int64, Job> jobs_ TF_GUARDED_BY(mu_);
+  absl::flat_hash_map<int64, std::shared_ptr<Job>> jobs_ TF_GUARDED_BY(mu_);
   // Information about tasks, keyed by task ids.
   absl::flat_hash_map<int64, Task> tasks_ TF_GUARDED_BY(mu_);
+  // Named jobs, keyed by their names and indices. Not all jobs have names, so
+  // this is a subset of the jobs stored in `jobs_`.
+  absl::flat_hash_map<NamedJobKey, std::shared_ptr<Job>> named_jobs_
+      TF_GUARDED_BY(mu_);
 
   TF_DISALLOW_COPY_AND_ASSIGN(DataServiceMasterImpl);
 };
@@ -47,8 +47,11 @@ namespace data {
 /* static */ constexpr const char* const DataServiceDatasetOp::kProcessingMode;
 /* static */ constexpr const char* const DataServiceDatasetOp::kAddress;
 /* static */ constexpr const char* const DataServiceDatasetOp::kProtocol;
+/* static */ constexpr const char* const DataServiceDatasetOp::kJobName;
 /* static */ constexpr const char* const
     DataServiceDatasetOp::kMaxOutstandingRequests;
+/* static */ constexpr const char* const
+    DataServiceDatasetOp::kIterationCounter;
 /* static */ constexpr const char* const DataServiceDatasetOp::kOutputTypes;
 /* static */ constexpr const char* const DataServiceDatasetOp::kOutputShapes;
 
@@ -71,23 +74,45 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
  public:
  Dataset(OpKernelContext* ctx, int64 dataset_id,
          ProcessingMode processing_mode, const std::string& address,
-          const std::string& protocol, int64 max_outstanding_requests,
-          int64 task_refresh_interval_ms, const DataTypeVector& output_types,
+          const std::string& protocol, const std::string& job_name,
+          int64 max_outstanding_requests, int64 task_refresh_interval_ms,
+          IterationCounter* iteration_counter, bool owns_resource,
+          ResourceHandle iteration_counter_handle,
+          const DataTypeVector& output_types,
          const std::vector<PartialTensorShape>& output_shapes)
       : DatasetBase(DatasetContext(ctx)),
        dataset_id_(dataset_id),
        processing_mode_(processing_mode),
        address_(address),
        protocol_(protocol),
+        job_name_(job_name),
        max_outstanding_requests_(max_outstanding_requests),
        task_refresh_interval_ms_(task_refresh_interval_ms),
+        iteration_counter_(iteration_counter),
+        owns_resource_(owns_resource),
+        iteration_counter_handle_(iteration_counter_handle),
+        resource_mgr_(ctx->resource_manager()),
        output_types_(output_types),
        output_shapes_(output_shapes) {}
 
+  ~Dataset() override {
+    iteration_counter_->Unref();
+    if (owns_resource_) {
+      Status s = resource_mgr_->Delete<IterationCounter>(
+          iteration_counter_handle_.container(),
+          iteration_counter_handle_.name());
+      if (!s.ok()) {
+        LOG(WARNING) << "Failed to delete iteration counter resource: " << s;
+      }
+    }
+  }
+
   std::unique_ptr<IteratorBase> MakeIteratorInternal(
       const string& prefix) const override {
-    return absl::make_unique<Iterator>(Iterator::Params{
-        this, name_utils::IteratorPrefix(kDatasetType, prefix)});
+    return absl::make_unique<Iterator>(
+        Iterator::Params{this,
+                         name_utils::IteratorPrefix(kDatasetType, prefix)},
+        iteration_counter_->GetAndIncrement());
   }
 
   const DataTypeVector& output_dtypes() const override { return output_types_; }
@@ -123,18 +148,26 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
     Node* protocol;
     TF_RETURN_IF_ERROR(b->AddScalar(protocol_, &protocol));
 
+    Node* job_name;
+    TF_RETURN_IF_ERROR(b->AddScalar(job_name_, &job_name));
+
     Node* max_outstanding_requests;
     TF_RETURN_IF_ERROR(
         b->AddScalar(max_outstanding_requests_, &max_outstanding_requests));
 
+    Node* iteration_counter_handle = nullptr;
+    Tensor handle(DT_RESOURCE, TensorShape({}));
+    handle.scalar<ResourceHandle>()() = iteration_counter_handle_;
+    TF_RETURN_IF_ERROR(b->AddTensor(handle, &iteration_counter_handle));
+
     AttrValue task_refresh_interval_hint_ms;
     b->BuildAttrValue(task_refresh_interval_ms_,
                       &task_refresh_interval_hint_ms);
 
     TF_RETURN_IF_ERROR(
         b->AddDataset(this,
-                      {dataset_id, processing_mode, address, protocol,
-                       max_outstanding_requests},
+                      {dataset_id, processing_mode, address, protocol, job_name,
+                       max_outstanding_requests, iteration_counter_handle},
                      {std::make_pair(kTaskRefreshIntervalHintMs,
                                      task_refresh_interval_hint_ms)},
                      output));
@@ -144,8 +177,8 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
 private:
  class Iterator : public DatasetIterator<Dataset> {
   public:
-    explicit Iterator(const Params& params)
-        : DatasetIterator<Dataset>(params) {}
+    explicit Iterator(const Params& params, int64 iterator_index)
+        : DatasetIterator<Dataset>(params), iterator_index_(iterator_index) {}
 
    ~Iterator() override {
      mutex_lock l(mu_);
@@ -161,8 +194,14 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
      VLOG(3) << "Connecting to " << dataset()->address_
              << " in data service dataset op";
      DataServiceMasterClient master(dataset()->address_, dataset()->protocol_);
-      TF_RETURN_IF_ERROR(master.CreateJob(
-          dataset()->dataset_id_, dataset()->processing_mode_, &job_id_));
+      if (dataset()->job_name_.empty()) {
+        TF_RETURN_IF_ERROR(master.CreateJob(
+            dataset()->dataset_id_, dataset()->processing_mode_, &job_id_));
+      } else {
+        TF_RETURN_IF_ERROR(master.GetOrCreateJob(
+            dataset()->dataset_id_, dataset()->processing_mode_,
+            dataset()->job_name_, iterator_index_, &job_id_));
+      }
      VLOG(1) << "Created data service job with id " << job_id_;
      return Status::OK();
    }
@@ -418,6 +457,8 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
      return Status::OK();
    }
 
+    const int64 iterator_index_;
+
    mutex mu_;
    // TODO(aaudibert): split this into a couple cvs for different conditions
    // so that we can use notify_one and avoid unnecessary wakeups.
@@ -449,8 +490,13 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
  const ProcessingMode processing_mode_;
  const tstring address_;
  const tstring protocol_;
+  const tstring job_name_;
  const int64 max_outstanding_requests_;
  const int64 task_refresh_interval_ms_;
+  IterationCounter* const iteration_counter_;  // Owned
+  const bool owns_resource_;
+  const ResourceHandle iteration_counter_handle_;
+  ResourceMgr* const resource_mgr_;  // Not owned
  const DataTypeVector output_types_;
  const std::vector<PartialTensorShape> output_shapes_;
 };
@@ -488,9 +534,41 @@ void DataServiceDatasetOp::MakeDataset(OpKernelContext* ctx,
  OP_REQUIRES(ctx, !protocol.empty(),
              errors::InvalidArgument(kProtocol, " must be non-empty."));
 
+  tstring job_name;
+  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kJobName, &job_name));
+
  int64 max_outstanding_requests;
  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kMaxOutstandingRequests,
                                          &max_outstanding_requests));
 
+  ResourceHandle iteration_counter_handle;
+  OP_REQUIRES_OK(
+      ctx, HandleFromInput(ctx, kIterationCounter, &iteration_counter_handle));
+  IterationCounter* iteration_counter = nullptr;
+  Status s = ctx->resource_manager()->Lookup<IterationCounter>(
+      iteration_counter_handle.container(), iteration_counter_handle.name(),
+      &iteration_counter);
+  bool owns_resource = false;
+  if (errors::IsNotFound(s)) {
+    owns_resource = true;
+    static std::atomic<int64> resource_id_counter(0);
+    const std::string& container = ctx->resource_manager()->default_container();
+    std::string name =
+        strings::StrCat(ctx->op_kernel().name(), "/", kIterationCounter, "_",
+                        resource_id_counter.fetch_add(1));
+    OP_REQUIRES_OK(ctx,
+                   ctx->resource_manager()->LookupOrCreate<IterationCounter>(
+                       container, name, &iteration_counter,
+                       [](IterationCounter** counter) {
+                         *counter = new IterationCounter();
+                         return Status::OK();
+                       }));
+    iteration_counter_handle =
+        MakeResourceHandle<IterationCounter>(ctx, container, name);
+  } else {
+    OP_REQUIRES_OK(ctx, s);
+  }
+
  OP_REQUIRES(
      ctx,
      max_outstanding_requests == model::kAutotune ||
@@ -499,13 +577,16 @@ void DataServiceDatasetOp::MakeDataset(OpKernelContext* ctx,
          model::kAutotune));
 
  *output =
-      new Dataset(ctx, dataset_id, processing_mode, address, protocol,
+      new Dataset(ctx, dataset_id, processing_mode, address, protocol, job_name,
                  max_outstanding_requests, task_refresh_interval_hint_ms_,
+                  iteration_counter, owns_resource, iteration_counter_handle,
                  output_types_, output_shapes_);
 }
 
 REGISTER_KERNEL_BUILDER(Name("DataServiceDataset").Device(DEVICE_CPU),
                         DataServiceDatasetOp);
+REGISTER_KERNEL_BUILDER(Name("DummyIterationCounter").Device(DEVICE_CPU),
+                        DummyResourceOp<IterationCounter>);
 
 }  // namespace data
 }  // namespace tensorflow
@@ -15,11 +15,34 @@ limitations under the License.
 #ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_DATA_SERVICE_DATASET_OP_H_
 #define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_DATA_SERVICE_DATASET_OP_H_
 
+#include "absl/strings/str_cat.h"
 #include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/resource_mgr.h"
 
 namespace tensorflow {
 namespace data {
 
+// A resource which counts how many iterators have been created. This is used
+// by the DataServiceDataset to coordinate jobs across multiple iterations.
+class IterationCounter : public ResourceBase {
+ public:
+  IterationCounter() : counter_(0) {}
+
+  std::string DebugString() const override {
+    mutex_lock l(mu_);
+    return absl::StrCat(counter_);
+  }
+
+  int64 GetAndIncrement() {
+    mutex_lock l(mu_);
+    return ++counter_;
+  }
+
+ private:
+  mutable mutex mu_;
+  int64 counter_ TF_GUARDED_BY(mu_) = 0;
+};
+
 // Creates a dataset for reading from the tf.data service.
 class DataServiceDatasetOp : public DatasetOpKernel {
  public:
@@ -28,10 +51,12 @@ class DataServiceDatasetOp : public DatasetOpKernel {
   static constexpr const char* const kProcessingMode = "processing_mode";
   static constexpr const char* const kAddress = "address";
   static constexpr const char* const kProtocol = "protocol";
+  static constexpr const char* const kJobName = "job_name";
   static constexpr const char* const kMaxOutstandingRequests =
       "max_outstanding_requests";
   static constexpr const char* const kTaskRefreshIntervalHintMs =
       "task_refresh_interval_hint_ms";
+  static constexpr const char* const kIterationCounter = "iteration_counter";
   static constexpr const char* const kOutputTypes = "output_types";
   static constexpr const char* const kOutputShapes = "output_shapes";
 
@@ -1,47 +0,0 @@
-op {
-  name: "DataServiceDataset"
-  input_arg {
-    name: "dataset_id"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "processing_mode"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "address"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "protocol"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "max_outstanding_requests"
-    type: DT_INT64
-  }
-  output_arg {
-    name: "handle"
-    type: DT_VARIANT
-  }
-  attr {
-    name: "task_refresh_interval_hint_ms"
-    type: "int"
-    default_value {
-      i: -1
-    }
-  }
-  attr {
-    name: "output_types"
-    type: "list(type)"
-    has_minimum: true
-    minimum: 1
-  }
-  attr {
-    name: "output_shapes"
-    type: "list(shape)"
-    has_minimum: true
-    minimum: 1
-  }
-  is_stateful: true
-}
@@ -1037,12 +1037,21 @@ REGISTER_OP("ExperimentalUniqueDataset")
     .Attr("output_shapes: list(shape) >= 1")
     .SetShapeFn(shape_inference::ScalarShape);
 
+REGISTER_OP("DummyIterationCounter")
+    .Output("handle: resource")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      c->set_output(0, c->Scalar());
+      return Status::OK();
+    });
+
 REGISTER_OP("DataServiceDataset")
     .Input("dataset_id: int64")
     .Input("processing_mode: string")
     .Input("address: string")
     .Input("protocol: string")
+    .Input("job_name: string")
     .Input("max_outstanding_requests: int64")
+    .Input("iteration_counter: resource")
     .Output("handle: variant")
     .Attr("task_refresh_interval_hint_ms: int = -1")
     .Attr("output_types: list(type) >= 1")
@ -51,6 +51,7 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource):
|
|||||||
processing_mode,
|
processing_mode,
|
||||||
address,
|
address,
|
||||||
protocol,
|
protocol,
|
||||||
|
job_name=None,
|
||||||
max_outstanding_requests=None,
|
max_outstanding_requests=None,
|
||||||
task_refresh_interval_hint_ms=None):
|
task_refresh_interval_hint_ms=None):
|
||||||
"""Constructs a _DataServiceDatasetV2.
|
"""Constructs a _DataServiceDatasetV2.
|
||||||
@ -65,6 +66,9 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource):
|
|||||||
address: The tf.data service address, e.g. "localhost:5000".
|
address: The tf.data service address, e.g. "localhost:5000".
|
||||||
protocol: The protocol to use for communicating with the tf.data service,
|
protocol: The protocol to use for communicating with the tf.data service,
|
||||||
e.g. "grpc".
|
e.g. "grpc".
|
||||||
|
job_name: (Optional.) The name of the job. This argument makes it
|
||||||
|
possible for multiple datasets to share the same job. The default
|
||||||
|
behavior is that the dataset creates anonymous, exclusively owned jobs.
|
||||||
max_outstanding_requests: (Optional.) A limit on how many elements may be
|
max_outstanding_requests: (Optional.) A limit on how many elements may be
|
||||||
requested at the same time. You can use this option to control the
|
requested at the same time. You can use this option to control the
|
||||||
amount of memory used, since `distribute` won't use more than
|
amount of memory used, since `distribute` won't use more than
|
||||||
@ -73,6 +77,8 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource):
|
|||||||
the master for task changes.
|
the master for task changes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
if job_name is None:
|
||||||
|
job_name = ""
|
||||||
if max_outstanding_requests is None:
|
if max_outstanding_requests is None:
|
||||||
max_outstanding_requests = dataset_ops.AUTOTUNE
|
max_outstanding_requests = dataset_ops.AUTOTUNE
|
||||||
if task_refresh_interval_hint_ms is None:
|
if task_refresh_interval_hint_ms is None:
|
||||||
@ -85,8 +91,11 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource):
|
|||||||
processing_mode=processing_mode,
|
processing_mode=processing_mode,
|
||||||
address=address,
|
address=address,
|
||||||
protocol=protocol,
|
protocol=protocol,
|
||||||
|
job_name=job_name,
|
||||||
max_outstanding_requests=max_outstanding_requests,
|
max_outstanding_requests=max_outstanding_requests,
|
||||||
task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
|
task_refresh_interval_hint_ms=task_refresh_interval_hint_ms,
|
||||||
|
iteration_counter=gen_experimental_dataset_ops.dummy_iteration_counter(
|
||||||
|
),
|
||||||
**self._flat_structure)
|
**self._flat_structure)
|
||||||
super(_DataServiceDatasetV2, self).__init__(variant_tensor)
|
super(_DataServiceDatasetV2, self).__init__(variant_tensor)
|
||||||
|
|
||||||
@ -100,7 +109,7 @@ class _DataServiceDatasetV1(dataset_ops.DatasetV1Adapter):
|
|||||||
|
|
||||||
@functools.wraps(_DataServiceDatasetV2.__init__)
|
@functools.wraps(_DataServiceDatasetV2.__init__)
|
||||||
def __init__(self, input_dataset, dataset_id, processing_mode, address,
|
def __init__(self, input_dataset, dataset_id, processing_mode, address,
|
||||||
protocol, max_outstanding_requests,
|
protocol, job_name, max_outstanding_requests,
|
||||||
task_refresh_interval_hint_ms):
|
task_refresh_interval_hint_ms):
|
||||||
|
|
||||||
self._wrapped = _DataServiceDatasetV2(
|
self._wrapped = _DataServiceDatasetV2(
|
||||||
@ -109,6 +118,7 @@ class _DataServiceDatasetV1(dataset_ops.DatasetV1Adapter):
|
|||||||
processing_mode=processing_mode,
|
processing_mode=processing_mode,
|
||||||
address=address,
|
address=address,
|
||||||
protocol=protocol,
|
protocol=protocol,
|
||||||
|
job_name=job_name,
|
||||||
max_outstanding_requests=max_outstanding_requests,
|
max_outstanding_requests=max_outstanding_requests,
|
||||||
task_refresh_interval_hint_ms=task_refresh_interval_hint_ms)
|
task_refresh_interval_hint_ms=task_refresh_interval_hint_ms)
|
||||||
super(_DataServiceDatasetV1, self).__init__(self._wrapped)
|
super(_DataServiceDatasetV1, self).__init__(self._wrapped)
|
||||||
@ -122,6 +132,7 @@ else:
|
|||||||
|
|
||||||
def _distribute(processing_mode,
|
def _distribute(processing_mode,
|
||||||
service,
|
service,
|
||||||
|
job_name=None,
|
||||||
max_outstanding_requests=None,
|
max_outstanding_requests=None,
|
||||||
task_refresh_interval_hint_ms=None):
|
task_refresh_interval_hint_ms=None):
|
||||||
"""A transformation that moves dataset processing to the tf.data service.
|
"""A transformation that moves dataset processing to the tf.data service.
|
||||||
@ -136,6 +147,9 @@ def _distribute(processing_mode,
|
|||||||
service: A string indicating how to connect to the tf.data service. The
|
service: A string indicating how to connect to the tf.data service. The
|
||||||
string should be in the format <protocol>://<address>, e.g.
|
string should be in the format <protocol>://<address>, e.g.
|
||||||
grpc://localhost:5000.
|
grpc://localhost:5000.
|
||||||
|
job_name: (Optional.) The name of the job. This argument makes it
|
||||||
|
possible for multiple datasets to share the same job. The default behavior
|
||||||
|
is that the dataset creates anonymous, exclusively owned jobs.
|
||||||
max_outstanding_requests: (Optional.) A limit on how many elements may be
|
max_outstanding_requests: (Optional.) A limit on how many elements may be
|
||||||
requested at the same time. You can use this option to control the amount
|
requested at the same time. You can use this option to control the amount
|
||||||
of memory used, since `distribute` won't use more than `element_size` *
|
of memory used, since `distribute` won't use more than `element_size` *
|
||||||
@ -147,6 +161,12 @@ def _distribute(processing_mode,
|
|||||||
Dataset: A `Dataset` of the elements produced by the data service.
|
Dataset: A `Dataset` of the elements produced by the data service.
|
||||||
"""
|
"""
|
||||||
ProcessingMode.validate(processing_mode)
|
ProcessingMode.validate(processing_mode)
|
||||||
|
if job_name is not None:
|
||||||
|
if not isinstance(job_name, six.string_types):
|
||||||
|
raise ValueError("job_name must be a string, but job_name was of type "
|
||||||
|
"{0}. job_name={1}".format(type(job_name), job_name))
|
||||||
|
if not job_name:
|
||||||
|
raise ValueError("job_name must not be empty")
|
||||||
if not isinstance(service, six.string_types):
|
if not isinstance(service, six.string_types):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"service must be a string, but service was of type {0}. service={1}"
|
"service must be a string, but service was of type {0}. service={1}"
|
||||||
@ -182,41 +202,49 @@ def _distribute(processing_mode,
        processing_mode=processing_mode,
        address=address,
        protocol=protocol,
+        job_name=job_name,
        max_outstanding_requests=max_outstanding_requests,
        task_refresh_interval_hint_ms=task_refresh_interval_hint_ms)

  return _apply_fn


-def distribute(processing_mode, service, max_outstanding_requests=None):
+def distribute(processing_mode,
+               service,
+               job_name=None,
+               max_outstanding_requests=None):
  """A transformation that moves dataset processing to the tf.data service.

-  The `processing_mode` argument controls how data is processed by the
-  tf.data service. Currently, the only supported mode is "parallel_epochs".
+  When you iterate over a dataset containing the `distribute` transformation,
+  the tf.data service creates a "job" which produces data for the dataset
+  iteration.
+
+  The `processing_mode` argument controls what data is produced by a tf.data
+  service job. Currently, the only supported mode is "parallel_epochs".

  processing_mode="parallel_epochs" means that multiple tf.data workers will
  iterate through the dataset in parallel, each producing all elements of the
  dataset. For example, if the dataset contains {0, 1, 2}, every tf.data worker
-  used for execution will produce {0, 1, 2}. If there are 3 workers and one
-  consumer, the consumer will receive the elements {0, 0, 0, 1, 1, 1, 2, 2, 2}
-  (though not necessarily in that order). To account for this, it is recommended
-  to randomly shuffle your dataset, so that different tf.data workers will
-  iterate through the dataset in different orders.
+  used for execution will produce {0, 1, 2}. If there are 3 workers, the job
+  will produce the elements {0, 0, 0, 1, 1, 1, 2, 2, 2} (though not necessarily
+  in that order). To account for this, it is recommended to randomly shuffle
+  your dataset, so that different tf.data workers will iterate through the
+  dataset in different orders.

-  In the future, there will be additional epoch modes. For example,
+  In the future, there will be additional processing modes. For example,
  a "one_epoch" mode which partitions the dataset across the tf.data
  workers, so that the consumers see each element of the dataset only once.

  ```
-  dataset = tf.data.Dataset.range(10)
+  dataset = tf.data.Dataset.range(5)
  dataset = dataset.map(lambda x: x*x)
  dataset = dataset.apply(
      tf.data.experimental.service.distribute("parallel_epochs",
                                              "grpc://dataservice:5000"))
-  dataset = dataset.map(lambda x: x+10)
+  dataset = dataset.map(lambda x: x+1)

  for element in dataset:
-    # process element
+    print(element)  # prints { 1, 2, 5, 10, 17 }
  ```

  In the above example, the first two lines (before the call to `distribute`)
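A short sketch (not part of the commit) of the shuffling recommendation in the docstring above. The address is a placeholder, and it assumes that, with no fixed seed, each tf.data worker picks its own shuffle order.

```
import tensorflow as tf

dataset = tf.data.Dataset.range(1000)
dataset = dataset.shuffle(buffer_size=1000)  # no seed: each worker is assumed to shuffle differently
dataset = dataset.apply(
    tf.data.experimental.service.distribute(
        "parallel_epochs", "grpc://dataservice:5000"))  # placeholder address
# Under "parallel_epochs" with N workers, iterating this dataset yields each
# of the 1000 elements N times, interleaved in a worker-dependent order.
```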
@ -224,6 +252,33 @@ def distribute(processing_mode, service, max_outstanding_requests=None):
  RPC. The remaining transformations (after the call to `distribute`) will be
  executed locally.
+
+  The `job_name` argument allows jobs to be shared across multiple
+  datasets. Instead of each dataset creating its own job, all datasets with the
+  same `job_name` will consume from the same job. A new job will
+  be created for each iteration of the dataset (with each repetition of
+  `Dataset.repeat` counting as a new iteration). The following example
+  demonstrates shared iteration, with the assumption that the tf.data service
+  is running with a single worker.
+
+  ```
+  range5_dataset = tf.data.Dataset.range(5)
+  dataset1 = range5_dataset.apply(tf.data.experimental.service.distribute(
+      "parallel_epochs", "grpc://dataservice:5000", job_name="my_job_name"))
+  dataset2 = range5_dataset.apply(tf.data.experimental.service.distribute(
+      "parallel_epochs", "grpc://dataservice:5000", job_name="my_job_name"))
+  iter_1_1 = iter(dataset1)
+  iter_1_2 = iter(dataset1)
+  iter_2_1 = iter(dataset2)
+  iter_2_2 = iter(dataset2)
+  print(next(iter_1_1))  # Prints "0"
+  # iter_1_2 consumes from the same job as iter_1_1
+  print(next(iter_1_2))  # Prints "1"
+  # iter_2_1 consumes from a new job
+  print(next(iter_2_1))  # Prints "0"
+  # iter_2_2 consumes from the same job as iter_2_1
+  print(next(iter_2_2))  # Prints "1"
+  ```

  Args:
    processing_mode: A string specifying the policy for how data should be
      processed by tf.data workers. Currently, the only supported value is
@ -231,6 +286,9 @@ def distribute(processing_mode, service, max_outstanding_requests=None):
    service: A string indicating how to connect to the tf.data service. The
      string should be in the format <protocol>://<address>, e.g.
      grpc://localhost:5000.
+    job_name: (Optional.) The name of the job. This argument makes it possible
+      for multiple datasets to share the same job. The default behavior is that
+      the dataset creates anonymous, exclusively owned jobs.
    max_outstanding_requests: (Optional.) A limit on how many elements may be
      requested at the same time. You can use this option to control the amount
      of memory used, since `distribute` won't use more than `element_size` *
@ -239,4 +297,5 @@ def distribute(processing_mode, service, max_outstanding_requests=None):
  Returns:
    Dataset: A `Dataset` of the elements produced by the data service.
  """
-  return _distribute(processing_mode, service, max_outstanding_requests)
+  return _distribute(processing_mode, service, job_name,
+                     max_outstanding_requests)
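One consequence of the new parameter order, sketched here as a hypothetical caller rather than anything from the commit: `job_name` now sits third, ahead of `max_outstanding_requests`, so a bare positional third argument is treated as a job name and rejected by the validation added earlier in this diff, while keyword callers are unaffected.

```
import tensorflow as tf

# Fine: max_outstanding_requests passed by keyword.
transform = tf.data.experimental.service.distribute(
    "parallel_epochs", "grpc://dataservice:5000", max_outstanding_requests=10)

# Raises ValueError: the positional 10 lands in job_name, which must be a string.
try:
  tf.data.experimental.service.distribute(
      "parallel_epochs", "grpc://dataservice:5000", 10)
except ValueError as e:
  print(e)
```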
@ -37,11 +37,14 @@ from tensorflow.python.platform import test
PROTOCOL = "grpc"


-def _make_distributed_dataset(dataset, service):
+def _make_distributed_dataset(dataset, service, job_name=None):
  """Creates a distributed dataset with a short task refresh interval."""
  return dataset.apply(
      data_service_ops._distribute(
-          "parallel_epochs", service, task_refresh_interval_hint_ms=20))
+          "parallel_epochs",
+          service,
+          job_name=job_name,
+          task_refresh_interval_hint_ms=20))


class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
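For illustration only (not part of the commit), a hypothetical test-style snippet showing how the extended helper might be called: two datasets pointed at one named job, plus one left anonymous. `self.create_cluster(1)` and `dataset_ops` are the same utilities the tests below rely on; `_example_usage` is an invented name.

```
def _example_usage(self):
  service = self.create_cluster(1)  # single-worker cluster, as in the tests below
  ds = dataset_ops.Dataset.range(10)
  shared_a = _make_distributed_dataset(ds, service, job_name="shared")
  shared_b = _make_distributed_dataset(ds, service, job_name="shared")  # joins shared_a's job
  private = _make_distributed_dataset(ds, service)  # no job_name: anonymous, exclusive job
  return shared_a, shared_b, private
```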
@ -233,6 +236,72 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
    result = list(f().numpy())
    self.assertCountEqual(num_workers * list(range(num_elements)), result)

+  @combinations.generate(test_base.eager_only_combinations())
+  def testSharedJobName(self):
+    num_elements = 10
+    service = self.create_cluster(1)
+    ds = dataset_ops.Dataset.range(num_elements)
+    ds1 = _make_distributed_dataset(ds, service, job_name="job_name")
+    ds2 = _make_distributed_dataset(ds, service, job_name="job_name")
+    iter1 = iter(ds1)
+    iter2 = iter(ds2)
+    results = []
+    for _ in range(3):
+      results.append(next(iter1).numpy())
+      results.append(next(iter2).numpy())
+    for elem in iter1:
+      results.append(elem.numpy())
+    for elem in iter2:
+      results.append(elem.numpy())
+    self.assertCountEqual(list(range(num_elements)), results)
+
+  @combinations.generate(test_base.eager_only_combinations())
+  def testDifferentJobNames(self):
+    num_elements = 10
+    service = self.create_cluster(1)
+    ds = dataset_ops.Dataset.range(num_elements)
+    ds1 = _make_distributed_dataset(ds, service, job_name="job_name1")
+    ds2 = _make_distributed_dataset(ds, service, job_name="job_name2")
+    self.assertDatasetProduces(ds1, list(range(num_elements)))
+    self.assertDatasetProduces(ds2, list(range(num_elements)))
+
+  @combinations.generate(test_base.eager_only_combinations())
+  def testSharedJobNameMultiIteration(self):
+    num_elements = 10
+    service = self.create_cluster(1)
+    ds = dataset_ops.Dataset.range(num_elements)
+    ds1 = _make_distributed_dataset(ds, service, job_name="job_name")
+    ds2 = _make_distributed_dataset(ds, service, job_name="job_name")
+    # iteration 1
+    self.assertDatasetProduces(ds1, list(range(num_elements)))
+    self.assertDatasetProduces(ds2, [])
+    # iteration 2
+    self.assertDatasetProduces(ds2, list(range(num_elements)))
+    self.assertDatasetProduces(ds1, [])
+
+  @combinations.generate(test_base.eager_only_combinations())
+  def testSharedJobNameRepeat(self):
+    num_elements = 10
+    num_repetitions = 3
+    service = self.create_cluster(1)
+    ds = dataset_ops.Dataset.range(num_elements)
+    ds1 = _make_distributed_dataset(ds, service, job_name="job_name")
+    ds1 = ds1.repeat(num_repetitions)
+    ds2 = _make_distributed_dataset(ds, service, job_name="job_name")
+    ds2 = ds2.repeat(num_repetitions)
+    results = []
+    iter1 = iter(ds1)
+    iter2 = iter(ds2)
+    for _ in range(((num_elements * num_repetitions) // 2) - 1):
+      results.append(next(iter1).numpy())
+    for _ in range(((num_elements * num_repetitions) // 2) - 1):
+      results.append(next(iter2).numpy())
+    for elem in iter1:
+      results.append(elem.numpy())
+    for elem in iter2:
+      results.append(elem.numpy())
+    self.assertCountEqual(num_repetitions * list(range(num_elements)), results)
+
  def run_stateful(self, external_state_policy):
    num_elements = 10
    ds = dataset_ops.Dataset.range(num_elements).map(
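A compact restatement (not part of the commit) of what `testSharedJobNameMultiIteration` above verifies, written against the public API. It assumes a single-worker tf.data service is reachable at the placeholder address.

```
import tensorflow as tf

ds = tf.data.Dataset.range(5)
ds1 = ds.apply(tf.data.experimental.service.distribute(
    "parallel_epochs", "grpc://dataservice:5000", job_name="my_job"))
ds2 = ds.apply(tf.data.experimental.service.distribute(
    "parallel_epochs", "grpc://dataservice:5000", job_name="my_job"))

print(list(ds1.as_numpy_iterator()))  # [0, 1, 2, 3, 4] -- first iteration drains the shared job
print(list(ds2.as_numpy_iterator()))  # []              -- same job, already exhausted
print(list(ds2.as_numpy_iterator()))  # [0, 1, 2, 3, 4] -- a new iteration starts a new job
print(list(ds1.as_numpy_iterator()))  # []              -- shares the job ds2 just drained
```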
@ -934,7 +934,7 @@ tf_module {
  }
  member_method {
    name: "DataServiceDataset"
-    argspec: "args=[\'dataset_id\', \'processing_mode\', \'address\', \'protocol\', \'max_outstanding_requests\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
+    argspec: "args=[\'dataset_id\', \'processing_mode\', \'address\', \'protocol\', \'job_name\', \'max_outstanding_requests\', \'iteration_counter\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
  }
  member_method {
    name: "DatasetCardinality"
@ -1176,6 +1176,10 @@ tf_module {
    name: "DrawBoundingBoxesV2"
    argspec: "args=[\'images\', \'boxes\', \'colors\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
+  member_method {
+    name: "DummyIterationCounter"
+    argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
  member_method {
    name: "DummyMemoryCache"
    argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@ -934,7 +934,7 @@ tf_module {
  }
  member_method {
    name: "DataServiceDataset"
-    argspec: "args=[\'dataset_id\', \'processing_mode\', \'address\', \'protocol\', \'max_outstanding_requests\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
+    argspec: "args=[\'dataset_id\', \'processing_mode\', \'address\', \'protocol\', \'job_name\', \'max_outstanding_requests\', \'iteration_counter\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
  }
  member_method {
    name: "DatasetCardinality"
@ -1176,6 +1176,10 @@ tf_module {
    name: "DrawBoundingBoxesV2"
    argspec: "args=[\'images\', \'boxes\', \'colors\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
  }
+  member_method {
+    name: "DummyIterationCounter"
+    argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+  }
  member_method {
    name: "DummyMemoryCache"
    argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], "