From 22546b562db7c002e1ab969cd130138c920fcede Mon Sep 17 00:00:00 2001 From: Andrew Audibert Date: Tue, 5 May 2020 15:54:30 -0700 Subject: [PATCH] [tf.data service] Add support for shared job names. Shared job names give users a way to share tf.data service job output across multiple datasets. PiperOrigin-RevId: 310037762 Change-Id: I5a8f9130806a361ada46b3f0b62ea71b0f9b2e8e --- .../api_def_DummyIterationCounter.pbtxt | 4 + tensorflow/core/data/service/BUILD | 1 + tensorflow/core/data/service/data_service.cc | 46 +++++-- tensorflow/core/data/service/data_service.h | 21 ++- .../core/data/service/grpc_master_impl.cc | 1 + .../core/data/service/grpc_master_impl.h | 1 + tensorflow/core/data/service/master.proto | 22 +++- tensorflow/core/data/service/master_impl.cc | 121 +++++++++++++----- tensorflow/core/data/service/master_impl.h | 48 ++++++- .../experimental/data_service_dataset_op.cc | 103 +++++++++++++-- .../experimental/data_service_dataset_op.h | 25 ++++ .../ops_history_v2/DataServiceDataset.pbtxt | 47 ------- .../core/ops/experimental_dataset_ops.cc | 9 ++ .../data/experimental/ops/data_service_ops.py | 87 +++++++++++-- .../kernel_tests/data_service_ops_test.py | 73 ++++++++++- .../api/golden/v1/tensorflow.raw_ops.pbtxt | 6 +- .../api/golden/v2/tensorflow.raw_ops.pbtxt | 6 +- 17 files changed, 493 insertions(+), 128 deletions(-) create mode 100644 tensorflow/core/api_def/base_api/api_def_DummyIterationCounter.pbtxt delete mode 100644 tensorflow/core/ops/compat/ops_history_v2/DataServiceDataset.pbtxt diff --git a/tensorflow/core/api_def/base_api/api_def_DummyIterationCounter.pbtxt b/tensorflow/core/api_def/base_api/api_def_DummyIterationCounter.pbtxt new file mode 100644 index 00000000000..dcaf11ef54b --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_DummyIterationCounter.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "DummyIterationCounter" + visibility: HIDDEN +} diff --git a/tensorflow/core/data/service/BUILD b/tensorflow/core/data/service/BUILD index 
4a973423519..5413493cb78 100644 --- a/tensorflow/core/data/service/BUILD +++ b/tensorflow/core/data/service/BUILD @@ -56,6 +56,7 @@ cc_library( deps = [ ":common_proto_cc", ":credentials_factory", + ":data_service", ":grpc_util", ":master_proto_cc", ":worker_cc_grpc_proto", diff --git a/tensorflow/core/data/service/data_service.cc b/tensorflow/core/data/service/data_service.cc index 688e3214a47..915435d8fcb 100644 --- a/tensorflow/core/data/service/data_service.cc +++ b/tensorflow/core/data/service/data_service.cc @@ -31,7 +31,7 @@ constexpr const char kParallelEpochs[] = "parallel_epochs"; constexpr const char kOneEpoch[] = "one_epoch"; } // namespace -Status ParseProcessingMode(absl::string_view s, ProcessingMode* mode) { +Status ParseProcessingMode(const std::string& s, ProcessingMode* mode) { if (s == kParallelEpochs) { *mode = ProcessingMode::PARALLEL_EPOCHS; } else if (s == kOneEpoch) { @@ -54,6 +54,21 @@ std::string ProcessingModeToString(ProcessingMode mode) { } } +Status DataServiceMasterClient::RegisterDataset(GraphDef dataset, + int64* dataset_id) { + TF_RETURN_IF_ERROR(EnsureInitialized()); + GetOrRegisterDatasetRequest req; + *req.mutable_dataset()->mutable_graph() = dataset; + GetOrRegisterDatasetResponse resp; + grpc::ClientContext client_ctx; + grpc::Status status = stub_->GetOrRegisterDataset(&client_ctx, req, &resp); + if (!status.ok()) { + return grpc_util::WrapError("Failed to register dataset", status); + } + *dataset_id = resp.dataset_id(); + return Status::OK(); +} + Status DataServiceMasterClient::CreateJob(int64 dataset_id, ProcessingMode processing_mode, int64* job_id) { @@ -73,18 +88,27 @@ Status DataServiceMasterClient::CreateJob(int64 dataset_id, return Status::OK(); } -Status DataServiceMasterClient::RegisterDataset(GraphDef dataset, - int64* dataset_id) { +Status DataServiceMasterClient::GetOrCreateJob(int64 dataset_id, + ProcessingMode processing_mode, + const std::string& job_name, + int job_name_index, + int64* job_id) { 
TF_RETURN_IF_ERROR(EnsureInitialized()); - GetOrRegisterDatasetRequest req; - *req.mutable_dataset()->mutable_graph() = dataset; - GetOrRegisterDatasetResponse resp; + GetOrCreateJobRequest req; + req.set_dataset_id(dataset_id); + req.set_processing_mode(ProcessingModeDef(processing_mode)); + req.set_job_name(job_name); + req.set_job_name_index(job_name_index); + GetOrCreateJobResponse resp; grpc::ClientContext client_ctx; - grpc::Status status = stub_->GetOrRegisterDataset(&client_ctx, req, &resp); + grpc::Status status = stub_->GetOrCreateJob(&client_ctx, req, &resp); if (!status.ok()) { - return grpc_util::WrapError("Failed to register dataset", status); + return grpc_util::WrapError( + absl::StrCat("Failed to get or create job for dataset with id ", + dataset_id), + status); } - *dataset_id = resp.dataset_id(); + *job_id = resp.job_id(); return Status::OK(); } @@ -148,7 +172,7 @@ Status DataServiceWorkerClient::EnsureInitialized() { } Status CreateDataServiceMasterClient( - absl::string_view address, absl::string_view protocol, + const std::string& address, const std::string& protocol, std::unique_ptr* out) { auto client = absl::make_unique(address, protocol); TF_RETURN_IF_ERROR(client->Initialize()); @@ -157,7 +181,7 @@ Status CreateDataServiceMasterClient( } Status CreateDataServiceWorkerClient( - absl::string_view address, absl::string_view protocol, + const std::string& address, const std::string& protocol, std::unique_ptr* out) { auto client = absl::make_unique(address, protocol); TF_RETURN_IF_ERROR(client->Initialize()); diff --git a/tensorflow/core/data/service/data_service.h b/tensorflow/core/data/service/data_service.h index 009e6d25f60..d205b4d9ebf 100644 --- a/tensorflow/core/data/service/data_service.h +++ b/tensorflow/core/data/service/data_service.h @@ -35,7 +35,7 @@ enum class ProcessingMode : int64 { // Parses a string representing a processing mode and stores the result in // *mode. 
Returns an InvalidArgument status if the string is not recognized. -Status ParseProcessingMode(absl::string_view s, ProcessingMode* mode); +Status ParseProcessingMode(const std::string& s, ProcessingMode* mode); // Converts a processing mode to its corresponding string. std::string ProcessingModeToString(ProcessingMode mode); @@ -45,7 +45,7 @@ std::string ProcessingModeToString(ProcessingMode mode); // threads. class DataServiceClientBase { public: - DataServiceClientBase(absl::string_view address, absl::string_view protocol) + DataServiceClientBase(const std::string& address, const std::string& protocol) : address_(address), protocol_(protocol) {} virtual ~DataServiceClientBase() = default; @@ -70,7 +70,8 @@ class DataServiceClientBase { // Client for communicating with the tf.data service master. class DataServiceMasterClient : public DataServiceClientBase { public: - DataServiceMasterClient(absl::string_view address, absl::string_view protocol) + DataServiceMasterClient(const std::string& address, + const std::string& protocol) : DataServiceClientBase(address, protocol) {} // Registers a dataset with the tf.data service, and stores the generated @@ -82,6 +83,13 @@ class DataServiceMasterClient : public DataServiceClientBase { Status CreateJob(int64 dataset_id, ProcessingMode processing_mode, int64* job_id); + // Gets the job id for the job represented by the tuple + // (job_name, job_name_index), and stores the id in *job_id. If the + // job doesn't exist yet, it will be created. + Status GetOrCreateJob(int64 dataset_id, ProcessingMode processing_mode, + const std::string& job_name, int job_name_index, + int64* job_id); + // Queries the master for the tasks associated with the specified job. // The tasks will be stored in *tasks, and whether the job is finished will // be stored in `*job_finished`. @@ -98,7 +106,8 @@ class DataServiceMasterClient : public DataServiceClientBase { // Client for communicating with the tf.data service worker. 
class DataServiceWorkerClient : public DataServiceClientBase { public: - DataServiceWorkerClient(absl::string_view address, absl::string_view protocol) + DataServiceWorkerClient(const std::string& address, + const std::string& protocol) : DataServiceClientBase(address, protocol) {} // Fetches the next element for the specified task_id. The element's @@ -116,12 +125,12 @@ class DataServiceWorkerClient : public DataServiceClientBase { // Creates and initializes a new tf.data service master client. Status CreateDataServiceMasterClient( - absl::string_view address, absl::string_view protocol, + const std::string& address, const std::string& protocol, std::unique_ptr* out); // Creates and initializes a new tf.data service worker client. Status CreateDataServiceWorkerClient( - absl::string_view address, absl::string_view protocol, + const std::string& address, const std::string& protocol, std::unique_ptr* out); } // namespace data diff --git a/tensorflow/core/data/service/grpc_master_impl.cc b/tensorflow/core/data/service/grpc_master_impl.cc index 4e5e9f45cea..ba27959fee7 100644 --- a/tensorflow/core/data/service/grpc_master_impl.cc +++ b/tensorflow/core/data/service/grpc_master_impl.cc @@ -42,6 +42,7 @@ HANDLER(RegisterWorker); HANDLER(WorkerUpdate); HANDLER(GetOrRegisterDataset); HANDLER(CreateJob); +HANDLER(GetOrCreateJob); HANDLER(GetTasks); #undef HANDLER diff --git a/tensorflow/core/data/service/grpc_master_impl.h b/tensorflow/core/data/service/grpc_master_impl.h index 2f775f8fd88..32eb0f3fc6a 100644 --- a/tensorflow/core/data/service/grpc_master_impl.h +++ b/tensorflow/core/data/service/grpc_master_impl.h @@ -46,6 +46,7 @@ class GrpcMasterImpl : public MasterService::Service { HANDLER(WorkerUpdate); HANDLER(GetOrRegisterDataset); HANDLER(CreateJob); + HANDLER(GetOrCreateJob); HANDLER(GetTasks); #undef HANDLER diff --git a/tensorflow/core/data/service/master.proto b/tensorflow/core/data/service/master.proto index 9361b7b6629..005e5affb7d 100644 --- 
a/tensorflow/core/data/service/master.proto +++ b/tensorflow/core/data/service/master.proto @@ -56,7 +56,24 @@ message CreateJobRequest { } message CreateJobResponse { - // An id for the begun job. + // An id for the created job. + int64 job_id = 1; +} + +message GetOrCreateJobRequest { + // The id of the dataset to create a job for. + int64 dataset_id = 1; + // A mode controlling how the tf.data service produces data for the job. + ProcessingModeDef processing_mode = 2; + // A name for the job. + string job_name = 3; + // An index for the job. Multiple jobs can be created for the same name, if + // they have different indices. + int64 job_name_index = 4; +} + +message GetOrCreateJobResponse { + // The id of the (potentially newly created) job. int64 job_id = 1; } @@ -96,6 +113,9 @@ service MasterService { rpc GetOrRegisterDataset(GetOrRegisterDatasetRequest) returns (GetOrRegisterDatasetResponse); + // Gets a job if it already exists, otherwise creates it. + rpc GetOrCreateJob(GetOrCreateJobRequest) returns (GetOrCreateJobResponse); + // Creates a job for reading from the tf.data service. rpc CreateJob(CreateJobRequest) returns (CreateJobResponse); diff --git a/tensorflow/core/data/service/master_impl.cc b/tensorflow/core/data/service/master_impl.cc index 4141b260ac9..6e2c95c475e 100644 --- a/tensorflow/core/data/service/master_impl.cc +++ b/tensorflow/core/data/service/master_impl.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/core/data/service/common.pb.h" #include "tensorflow/core/data/service/credentials_factory.h" +#include "tensorflow/core/data/service/data_service.h" #include "tensorflow/core/data/service/grpc_util.h" #include "tensorflow/core/data/service/master.pb.h" #include "tensorflow/core/data/service/worker.grpc.pb.h" @@ -65,17 +66,17 @@ Status DataServiceMasterImpl::RegisterWorker( // Allocate tasks to the worker. 
for (auto& entry : jobs_) { - Job& job = entry.second; - if (job.finished()) { + std::shared_ptr job = entry.second; + if (job->finished()) { continue; } - int64 task_id = CreateTask(&job, request->worker_address()); + int64 task_id = CreateTask(job.get(), request->worker_address()); TaskDef* task_def = response->add_tasks(); *task_def->mutable_dataset() = - datasets_by_id_[job.dataset_id()]->dataset_def(); - task_def->set_dataset_id(job.dataset_id()); - task_def->set_job_id(job.job_id()); + datasets_by_id_[job->dataset_id()]->dataset_def(); + task_def->set_dataset_id(job->dataset_id()); + task_def->set_job_id(job->job_id()); task_def->set_task_id(task_id); } @@ -96,7 +97,7 @@ Status DataServiceMasterImpl::WorkerUpdate(const WorkerUpdateRequest* request, if (update.completed()) { int64 job_id = tasks_.at(task_id).job_id(); DCHECK(jobs_.contains(job_id)); - jobs_.at(job_id).task_finished(task_id); + jobs_.at(job_id)->task_finished(task_id); VLOG(3) << "Task " << task_id << " from job " << job_id << " completed"; } } @@ -135,49 +136,109 @@ int64 DataServiceMasterImpl::RegisterDataset(uint64 fingerprint, DCHECK(!datasets_by_id_.contains(dataset_id)); datasets_by_id_[dataset_id] = new_dataset; DCHECK(!datasets_by_fingerprint_.contains(fingerprint)); - datasets_by_fingerprint_[dataset_id] = new_dataset; + datasets_by_fingerprint_[fingerprint] = new_dataset; return dataset_id; } Status DataServiceMasterImpl::CreateJob(const CreateJobRequest* request, CreateJobResponse* response) { - VLOG(3) << "Received begin job request for dataset id " + VLOG(3) << "Received create job request for dataset id " << request->dataset_id(); - switch (request->processing_mode()) { - case PARALLEL_EPOCHS: + ProcessingMode processing_mode = ProcessingMode(request->processing_mode()); + mutex_lock l(mu_); + int64 job_id; + TF_RETURN_IF_ERROR(CreateJob(request->dataset_id(), processing_mode, + absl::optional(), &job_id)); + response->set_job_id(job_id); + + VLOG(3) << "Creating job " << job_id 
<< " for dataset " + << request->dataset_id(); + return Status::OK(); +} + +Status DataServiceMasterImpl::GetOrCreateJob( + const GetOrCreateJobRequest* request, GetOrCreateJobResponse* response) { + VLOG(3) << "Received get or create job request for dataset id " + << request->dataset_id() << " with name " << request->job_name() + << " and index " << request->job_name_index(); + mutex_lock l(mu_); + NamedJobKey key(request->job_name(), request->job_name_index()); + ProcessingMode requested_processing_mode = + ProcessingMode(request->processing_mode()); + std::shared_ptr* job = gtl::FindOrNull(named_jobs_, key); + if (job != nullptr) { + TF_RETURN_IF_ERROR(ValidateMatchingJob(**job, requested_processing_mode, + request->dataset_id())); + response->set_job_id((*job)->job_id()); + return Status::OK(); + } + int64 job_id; + TF_RETURN_IF_ERROR(CreateJob(request->dataset_id(), requested_processing_mode, + request->job_name(), &job_id)); + named_jobs_[key] = jobs_[job_id]; + response->set_job_id(job_id); + return Status::OK(); +} + +// Validates that the job matches the given processing_mode and dataset_id. 
+Status DataServiceMasterImpl::ValidateMatchingJob( + const Job& job, ProcessingMode processing_mode, int64 dataset_id) { + DCHECK(job.name().has_value()); + std::string job_name = job.name().value(); + if (job.processing_mode() != processing_mode) { + std::string requested = ProcessingModeToString(processing_mode); + std::string actual = ProcessingModeToString(job.processing_mode()); + return errors::FailedPrecondition( + "Found a job with name ", job_name, ", but the processing mode <", + actual, "> doesn't match the requested processing mode <", requested, + ">."); + } + if (job.dataset_id() != dataset_id) { + return errors::FailedPrecondition( + "Found a job with name ", job_name, ", but the dataset id <", + job.dataset_id(), "> doesn't match the requested dataset id <", + dataset_id, ">."); + } + return Status::OK(); +} + +Status DataServiceMasterImpl::CreateJob(int64 dataset_id, + ProcessingMode processing_mode, + absl::optional job_name, + int64* out_job_id) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + switch (processing_mode) { + case ProcessingMode::PARALLEL_EPOCHS: break; - case ONE_EPOCH: + case ProcessingMode::ONE_EPOCH: return errors::Unimplemented( "CreateJob only supports the PARALLEL_EPOCHS job mode. " "ONE_EPOCH is not currently supported."); default: - return errors::Unimplemented( - "ProcessingMode ", request->processing_mode(), " not recognized"); + return errors::Unimplemented("ProcessingMode ", + ProcessingModeToString(processing_mode), + " not recognized"); } - mutex_lock l(mu_); - if (!datasets_by_id_.contains(request->dataset_id())) { - return errors::NotFound("CreateJob failed. 
Dataset id: <", - request->dataset_id(), "> not found."); + if (!datasets_by_id_.contains(dataset_id)) { + return errors::NotFound("Dataset id: <", dataset_id, "> not found."); } int64 job_id = next_job_id_++; DCHECK(!jobs_.contains(job_id)); - auto result = - jobs_.emplace(std::piecewise_construct, std::forward_as_tuple(job_id), - std::forward_as_tuple(job_id, request->dataset_id())); - DCHECK(result.second); - Job& job = result.first->second; - response->set_job_id(job_id); + auto job = + std::make_shared(job_id, dataset_id, processing_mode, job_name); + jobs_[job_id] = job; for (auto& worker : workers_) { - int64 task_id = CreateTask(&job, worker.address()); + int64 task_id = CreateTask(job.get(), worker.address()); // TODO(aaudibert): perform these calls asynchronously. + // TODO(aaudibert): clean up in case some calls succeed, but later calls + // fail TF_RETURN_IF_ERROR(AllocateTaskToWorker(tasks_.at(task_id), &worker)); } - VLOG(3) << "Beginning job " << job_id << " for dataset " - << request->dataset_id(); + *out_job_id = job_id; return Status::OK(); } @@ -233,8 +294,8 @@ Status DataServiceMasterImpl::GetTasks(const GetTasksRequest* request, return errors::NotFound("GetTasks failed. 
Job id <", request->job_id(), "> not found."); } - Job& job = it->second; - for (const auto& task_id : job.task_ids()) { + std::shared_ptr job = it->second; + for (const auto& task_id : job->task_ids()) { auto task_iter = tasks_.find(task_id); DCHECK(task_iter != tasks_.end()); Task& task = task_iter->second; @@ -242,7 +303,7 @@ Status DataServiceMasterImpl::GetTasks(const GetTasksRequest* request, task_info->set_worker_address(task.worker_address()); task_info->set_id(task.task_id()); } - response->set_job_finished(job.finished()); + response->set_job_finished(job->finished()); VLOG(3) << "Found " << response->task_info_size() << " tasks for job id " << request->job_id(); return Status::OK(); diff --git a/tensorflow/core/data/service/master_impl.h b/tensorflow/core/data/service/master_impl.h index b7cfc496e69..de25ea0d6a8 100644 --- a/tensorflow/core/data/service/master_impl.h +++ b/tensorflow/core/data/service/master_impl.h @@ -18,6 +18,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "tensorflow/core/data/service/common.pb.h" +#include "tensorflow/core/data/service/data_service.h" #include "tensorflow/core/data/service/master.pb.h" #include "tensorflow/core/data/service/worker.grpc.pb.h" #include "tensorflow/core/lib/core/status.h" @@ -56,6 +57,8 @@ class DataServiceMasterImpl { GetOrRegisterDatasetResponse* response); Status CreateJob(const CreateJobRequest* request, CreateJobResponse* response); + Status GetOrCreateJob(const GetOrCreateJobRequest* request, + GetOrCreateJobResponse* response); Status GetTasks(const GetTasksRequest* request, GetTasksResponse* response); private: @@ -100,11 +103,17 @@ class DataServiceMasterImpl { class Job { public: - Job(int64 job_id, int64 dataset_id) - : job_id_(job_id), dataset_id_(dataset_id) {} + Job(int64 job_id, int64 dataset_id, ProcessingMode processing_mode, + absl::optional job_name) + : job_id_(job_id), + dataset_id_(dataset_id), + processing_mode_(processing_mode), + 
job_name_(job_name) {} int64 job_id() const { return job_id_; } int64 dataset_id() const { return dataset_id_; } + ProcessingMode processing_mode() const { return processing_mode_; } + absl::optional name() const { return job_name_; } const std::vector& task_ids() const { return task_ids_; } void add_task_id(int64 task_id) { task_ids_.push_back(task_id); } void task_finished(int64 task_id) { @@ -118,11 +127,32 @@ class DataServiceMasterImpl { private: const int64 job_id_; const int64 dataset_id_; + const ProcessingMode processing_mode_; + const absl::optional job_name_; std::vector task_ids_; std::vector finished_tasks_; bool finished_ = false; }; + class NamedJobKey { + public: + NamedJobKey(absl::string_view name, int64 index) + : name_(name), index_(index) {} + + friend bool operator==(const NamedJobKey& lhs, const NamedJobKey& rhs) { + return lhs.name_ == rhs.name_ && lhs.index_ == rhs.index_; + } + + template + friend H AbslHashValue(H h, const NamedJobKey& k) { + return H::combine(std::move(h), k.name_, k.index_); + } + + private: + const std::string name_; + const int64 index_; + }; + class Task { public: Task(int64 task_id, int64 job_id, int64 dataset_id, @@ -150,9 +180,15 @@ class DataServiceMasterImpl { Status EnsureWorkerStubInitialized(Worker* worker); // Instructs a worker to begin processing a task. Status AllocateTaskToWorker(const Task& task_id, Worker* worker); + // Creates a job and stores its job_id in `*out_job_id`. + Status CreateJob(int64 dataset_id, ProcessingMode processing_mode, + absl::optional job_name, int64* out_job_id); // Creates a new task for a job, returning the new task's id. int64 CreateTask(Job* job, const std::string& worker_address); - + // Validates that an existing job matches the given processing_mode and + // dataset_id, returning an error status describing any difference. + Status ValidateMatchingJob(const Job& job, ProcessingMode processing_mode, + int64 dataset_id); // Protocol to use for communicating with workers. 
const std::string protocol_; @@ -172,9 +208,13 @@ class DataServiceMasterImpl { absl::flat_hash_map> datasets_by_fingerprint_ TF_GUARDED_BY(mu_); // Information about jobs, keyed by job ids. - absl::flat_hash_map jobs_ TF_GUARDED_BY(mu_); + absl::flat_hash_map> jobs_ TF_GUARDED_BY(mu_); // Information about tasks, keyed by task ids. absl::flat_hash_map tasks_ TF_GUARDED_BY(mu_); + // Named jobs, keyed by their names and indices. Not all jobs have names, so + // this is a subset of the jobs stored in `jobs_`. + absl::flat_hash_map> named_jobs_ + TF_GUARDED_BY(mu_); TF_DISALLOW_COPY_AND_ASSIGN(DataServiceMasterImpl); }; diff --git a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc index ba5c3a54871..8c336686deb 100644 --- a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.cc @@ -47,8 +47,11 @@ namespace data { /* static */ constexpr const char* const DataServiceDatasetOp::kProcessingMode; /* static */ constexpr const char* const DataServiceDatasetOp::kAddress; /* static */ constexpr const char* const DataServiceDatasetOp::kProtocol; +/* static */ constexpr const char* const DataServiceDatasetOp::kJobName; /* static */ constexpr const char* const DataServiceDatasetOp::kMaxOutstandingRequests; +/* static */ constexpr const char* const + DataServiceDatasetOp::kIterationCounter; /* static */ constexpr const char* const DataServiceDatasetOp::kOutputTypes; /* static */ constexpr const char* const DataServiceDatasetOp::kOutputShapes; @@ -71,23 +74,45 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { public: Dataset(OpKernelContext* ctx, int64 dataset_id, ProcessingMode processing_mode, const std::string& address, - const std::string& protocol, int64 max_outstanding_requests, - int64 task_refresh_interval_ms, const DataTypeVector& output_types, + const std::string& protocol, const 
std::string& job_name, + int64 max_outstanding_requests, int64 task_refresh_interval_ms, + IterationCounter* iteration_counter, bool owns_resource, + ResourceHandle iteration_counter_handle, + const DataTypeVector& output_types, const std::vector& output_shapes) : DatasetBase(DatasetContext(ctx)), dataset_id_(dataset_id), processing_mode_(processing_mode), address_(address), protocol_(protocol), + job_name_(job_name), max_outstanding_requests_(max_outstanding_requests), task_refresh_interval_ms_(task_refresh_interval_ms), + iteration_counter_(iteration_counter), + owns_resource_(owns_resource), + iteration_counter_handle_(iteration_counter_handle), + resource_mgr_(ctx->resource_manager()), output_types_(output_types), output_shapes_(output_shapes) {} + ~Dataset() override { + iteration_counter_->Unref(); + if (owns_resource_) { + Status s = resource_mgr_->Delete( + iteration_counter_handle_.container(), + iteration_counter_handle_.name()); + if (!s.ok()) { + LOG(WARNING) << "Failed to delete iteration counter resource: " << s; + } + } + } + std::unique_ptr MakeIteratorInternal( const string& prefix) const override { - return absl::make_unique(Iterator::Params{ - this, name_utils::IteratorPrefix(kDatasetType, prefix)}); + return absl::make_unique( + Iterator::Params{this, + name_utils::IteratorPrefix(kDatasetType, prefix)}, + iteration_counter_->GetAndIncrement()); } const DataTypeVector& output_dtypes() const override { return output_types_; } @@ -123,18 +148,26 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { Node* protocol; TF_RETURN_IF_ERROR(b->AddScalar(protocol_, &protocol)); + Node* job_name; + TF_RETURN_IF_ERROR(b->AddScalar(job_name_, &job_name)); + Node* max_outstanding_requests; TF_RETURN_IF_ERROR( b->AddScalar(max_outstanding_requests_, &max_outstanding_requests)); + Node* iteration_counter_handle = nullptr; + Tensor handle(DT_RESOURCE, TensorShape({})); + handle.scalar()() = iteration_counter_handle_; + 
TF_RETURN_IF_ERROR(b->AddTensor(handle, &iteration_counter_handle)); + AttrValue task_refresh_interval_hint_ms; b->BuildAttrValue(task_refresh_interval_ms_, &task_refresh_interval_hint_ms); TF_RETURN_IF_ERROR( b->AddDataset(this, - {dataset_id, processing_mode, address, protocol, - max_outstanding_requests}, + {dataset_id, processing_mode, address, protocol, job_name, + max_outstanding_requests, iteration_counter_handle}, {std::make_pair(kTaskRefreshIntervalHintMs, task_refresh_interval_hint_ms)}, output)); @@ -144,8 +177,8 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { private: class Iterator : public DatasetIterator { public: - explicit Iterator(const Params& params) - : DatasetIterator(params) {} + explicit Iterator(const Params& params, int64 iterator_index) + : DatasetIterator(params), iterator_index_(iterator_index) {} ~Iterator() override { mutex_lock l(mu_); @@ -161,8 +194,14 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { VLOG(3) << "Connecting to " << dataset()->address_ << " in data service dataset op"; DataServiceMasterClient master(dataset()->address_, dataset()->protocol_); - TF_RETURN_IF_ERROR(master.CreateJob( - dataset()->dataset_id_, dataset()->processing_mode_, &job_id_)); + if (dataset()->job_name_.empty()) { + TF_RETURN_IF_ERROR(master.CreateJob( + dataset()->dataset_id_, dataset()->processing_mode_, &job_id_)); + } else { + TF_RETURN_IF_ERROR(master.GetOrCreateJob( + dataset()->dataset_id_, dataset()->processing_mode_, + dataset()->job_name_, iterator_index_, &job_id_)); + } VLOG(1) << "Created data service job with id " << job_id_; return Status::OK(); } @@ -418,6 +457,8 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { return Status::OK(); } + const int64 iterator_index_; + mutex mu_; // TODO(aaudibert): split this into a couple cvs for different conditions // so that we can use notify_one and avoid unnecessary wakeups. 
@@ -449,8 +490,13 @@ class DataServiceDatasetOp::Dataset : public DatasetBase { const ProcessingMode processing_mode_; const tstring address_; const tstring protocol_; + const tstring job_name_; const int64 max_outstanding_requests_; const int64 task_refresh_interval_ms_; + IterationCounter* const iteration_counter_; // Owned + const bool owns_resource_; + const ResourceHandle iteration_counter_handle_; + ResourceMgr* const resource_mgr_; // Not owned const DataTypeVector output_types_; const std::vector output_shapes_; }; @@ -488,9 +534,41 @@ void DataServiceDatasetOp::MakeDataset(OpKernelContext* ctx, OP_REQUIRES(ctx, !protocol.empty(), errors::InvalidArgument(kProtocol, " must be non-empty.")); + tstring job_name; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kJobName, &job_name)); + int64 max_outstanding_requests; OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kMaxOutstandingRequests, &max_outstanding_requests)); + + ResourceHandle iteration_counter_handle; + OP_REQUIRES_OK( + ctx, HandleFromInput(ctx, kIterationCounter, &iteration_counter_handle)); + IterationCounter* iteration_counter = nullptr; + Status s = ctx->resource_manager()->Lookup( + iteration_counter_handle.container(), iteration_counter_handle.name(), + &iteration_counter); + bool owns_resource = false; + if (errors::IsNotFound(s)) { + owns_resource = true; + static std::atomic resource_id_counter(0); + const std::string& container = ctx->resource_manager()->default_container(); + std::string name = + strings::StrCat(ctx->op_kernel().name(), "/", kIterationCounter, "_", + resource_id_counter.fetch_add(1)); + OP_REQUIRES_OK(ctx, + ctx->resource_manager()->LookupOrCreate( + container, name, &iteration_counter, + [](IterationCounter** counter) { + *counter = new IterationCounter(); + return Status::OK(); + })); + iteration_counter_handle = + MakeResourceHandle(ctx, container, name); + } else { + OP_REQUIRES_OK(ctx, s); + } + OP_REQUIRES( ctx, max_outstanding_requests == model::kAutotune || @@ -499,13 
+577,16 @@ void DataServiceDatasetOp::MakeDataset(OpKernelContext* ctx, model::kAutotune)); *output = - new Dataset(ctx, dataset_id, processing_mode, address, protocol, + new Dataset(ctx, dataset_id, processing_mode, address, protocol, job_name, max_outstanding_requests, task_refresh_interval_hint_ms_, + iteration_counter, owns_resource, iteration_counter_handle, output_types_, output_shapes_); } REGISTER_KERNEL_BUILDER(Name("DataServiceDataset").Device(DEVICE_CPU), DataServiceDatasetOp); +REGISTER_KERNEL_BUILDER(Name("DummyIterationCounter").Device(DEVICE_CPU), + DummyResourceOp); } // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.h b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.h index d64ca92bc64..b2c7f368c8e 100644 --- a/tensorflow/core/kernels/data/experimental/data_service_dataset_op.h +++ b/tensorflow/core/kernels/data/experimental/data_service_dataset_op.h @@ -15,11 +15,34 @@ limitations under the License. #ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_DATA_SERVICE_DATASET_OP_H_ #define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_DATA_SERVICE_DATASET_OP_H_ +#include "absl/strings/str_cat.h" #include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/resource_mgr.h" namespace tensorflow { namespace data { +// A resource which counts how many iterators have been created. This is used +// by the DataServiceDataset to coordinate jobs across multiple iterations. +class IterationCounter : public ResourceBase { + public: + IterationCounter() : counter_(0) {} + + std::string DebugString() const override { + mutex_lock l(mu_); + return absl::StrCat(counter_); + } + + int64 GetAndIncrement() { + mutex_lock l(mu_); + return ++counter_; + } + + private: + mutable mutex mu_; + int64 counter_ TF_GUARDED_BY(mu_) = 0; +}; + // Creates a dataset for reading from the tf.data service. 
class DataServiceDatasetOp : public DatasetOpKernel { public: @@ -28,10 +51,12 @@ class DataServiceDatasetOp : public DatasetOpKernel { static constexpr const char* const kProcessingMode = "processing_mode"; static constexpr const char* const kAddress = "address"; static constexpr const char* const kProtocol = "protocol"; + static constexpr const char* const kJobName = "job_name"; static constexpr const char* const kMaxOutstandingRequests = "max_outstanding_requests"; static constexpr const char* const kTaskRefreshIntervalHintMs = "task_refresh_interval_hint_ms"; + static constexpr const char* const kIterationCounter = "iteration_counter"; static constexpr const char* const kOutputTypes = "output_types"; static constexpr const char* const kOutputShapes = "output_shapes"; diff --git a/tensorflow/core/ops/compat/ops_history_v2/DataServiceDataset.pbtxt b/tensorflow/core/ops/compat/ops_history_v2/DataServiceDataset.pbtxt deleted file mode 100644 index 3c84defb799..00000000000 --- a/tensorflow/core/ops/compat/ops_history_v2/DataServiceDataset.pbtxt +++ /dev/null @@ -1,47 +0,0 @@ -op { - name: "DataServiceDataset" - input_arg { - name: "dataset_id" - type: DT_INT64 - } - input_arg { - name: "processing_mode" - type: DT_STRING - } - input_arg { - name: "address" - type: DT_STRING - } - input_arg { - name: "protocol" - type: DT_STRING - } - input_arg { - name: "max_outstanding_requests" - type: DT_INT64 - } - output_arg { - name: "handle" - type: DT_VARIANT - } - attr { - name: "task_refresh_interval_hint_ms" - type: "int" - default_value { - i: -1 - } - } - attr { - name: "output_types" - type: "list(type)" - has_minimum: true - minimum: 1 - } - attr { - name: "output_shapes" - type: "list(shape)" - has_minimum: true - minimum: 1 - } - is_stateful: true -} diff --git a/tensorflow/core/ops/experimental_dataset_ops.cc b/tensorflow/core/ops/experimental_dataset_ops.cc index 09ec273a844..2c9cbe2f416 100644 --- a/tensorflow/core/ops/experimental_dataset_ops.cc +++ 
b/tensorflow/core/ops/experimental_dataset_ops.cc @@ -1037,12 +1037,21 @@ REGISTER_OP("ExperimentalUniqueDataset") .Attr("output_shapes: list(shape) >= 1") .SetShapeFn(shape_inference::ScalarShape); +REGISTER_OP("DummyIterationCounter") + .Output("handle: resource") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->Scalar()); + return Status::OK(); + }); + REGISTER_OP("DataServiceDataset") .Input("dataset_id: int64") .Input("processing_mode: string") .Input("address: string") .Input("protocol: string") + .Input("job_name: string") .Input("max_outstanding_requests: int64") + .Input("iteration_counter: resource") .Output("handle: variant") .Attr("task_refresh_interval_hint_ms: int = -1") .Attr("output_types: list(type) >= 1") diff --git a/tensorflow/python/data/experimental/ops/data_service_ops.py b/tensorflow/python/data/experimental/ops/data_service_ops.py index b5b54e3e94a..c1c23668db0 100644 --- a/tensorflow/python/data/experimental/ops/data_service_ops.py +++ b/tensorflow/python/data/experimental/ops/data_service_ops.py @@ -51,6 +51,7 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource): processing_mode, address, protocol, + job_name=None, max_outstanding_requests=None, task_refresh_interval_hint_ms=None): """Constructs a _DataServiceDatasetV2. @@ -65,6 +66,9 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource): address: The tf.data service address, e.g. "localhost:5000". protocol: The protocol to use for communicating with the tf.data service, e.g. "grpc". + job_name: (Optional.) The name of the job. This argument makes it + possible for multiple datasets to share the same job. The default + behavior is that the dataset creates anonymous, exclusively owned jobs. max_outstanding_requests: (Optional.) A limit on how many elements may be requested at the same time. 
You can use this option to control the amount of memory used, since `distribute` won't use more than @@ -73,6 +77,8 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource): the master for task changes. """ + if job_name is None: + job_name = "" if max_outstanding_requests is None: max_outstanding_requests = dataset_ops.AUTOTUNE if task_refresh_interval_hint_ms is None: @@ -85,8 +91,11 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource): processing_mode=processing_mode, address=address, protocol=protocol, + job_name=job_name, max_outstanding_requests=max_outstanding_requests, task_refresh_interval_hint_ms=task_refresh_interval_hint_ms, + iteration_counter=gen_experimental_dataset_ops.dummy_iteration_counter( + ), **self._flat_structure) super(_DataServiceDatasetV2, self).__init__(variant_tensor) @@ -100,7 +109,7 @@ class _DataServiceDatasetV1(dataset_ops.DatasetV1Adapter): @functools.wraps(_DataServiceDatasetV2.__init__) def __init__(self, input_dataset, dataset_id, processing_mode, address, - protocol, max_outstanding_requests, + protocol, job_name, max_outstanding_requests, task_refresh_interval_hint_ms): self._wrapped = _DataServiceDatasetV2( @@ -109,6 +118,7 @@ class _DataServiceDatasetV1(dataset_ops.DatasetV1Adapter): processing_mode=processing_mode, address=address, protocol=protocol, + job_name=job_name, max_outstanding_requests=max_outstanding_requests, task_refresh_interval_hint_ms=task_refresh_interval_hint_ms) super(_DataServiceDatasetV1, self).__init__(self._wrapped) @@ -122,6 +132,7 @@ else: def _distribute(processing_mode, service, + job_name=None, max_outstanding_requests=None, task_refresh_interval_hint_ms=None): """A transformation that moves dataset processing to the tf.data service. @@ -136,6 +147,9 @@ def _distribute(processing_mode, service: A string indicating how to connect to the tf.data service. The string should be in the format ://
, e.g. grpc://localhost:5000. + job_name: (Optional.) The name of the job. This argument makes it + possible for multiple datasets to share the same job. The default behavior + is that the dataset creates anonymous, exclusively owned jobs. max_outstanding_requests: (Optional.) A limit on how many elements may be requested at the same time. You can use this option to control the amount of memory used, since `distribute` won't use more than `element_size` * @@ -147,6 +161,12 @@ def _distribute(processing_mode, Dataset: A `Dataset` of the elements produced by the data service. """ ProcessingMode.validate(processing_mode) + if job_name is not None: + if not isinstance(job_name, six.string_types): + raise ValueError("job_name must be a string, but job_name was of type " + "{0}. job_name={1}".format(type(job_name), job_name)) + if not job_name: + raise ValueError("job_name must not be empty") if not isinstance(service, six.string_types): raise ValueError( "service must be a string, but service was of type {0}. service={1}" @@ -182,41 +202,49 @@ def _distribute(processing_mode, processing_mode=processing_mode, address=address, protocol=protocol, + job_name=job_name, max_outstanding_requests=max_outstanding_requests, task_refresh_interval_hint_ms=task_refresh_interval_hint_ms) return _apply_fn -def distribute(processing_mode, service, max_outstanding_requests=None): +def distribute(processing_mode, + service, + job_name=None, + max_outstanding_requests=None): """A transformation that moves dataset processing to the tf.data service. - The `processing_mode` argument controls how data is processed by the - tf.data service. Currently, the only supported mode is "parallel_epochs". + When you iterate over a dataset containing the `distribute` transformation, + the tf.data service creates a "job" which produces data for the dataset + iteration. + + The `processing_mode` argument controls what data is produced by a tf.data + service job. 
Currently, the only supported mode is "parallel_epochs". processing_mode="parallel_epochs" means that multiple tf.data workers will iterate through the dataset in parallel, each producing all elements of the dataset. For example, if the dataset contains {0, 1, 2}, every tf.data worker - used for execution will produce {0, 1, 2}. If there are 3 workers and one - consumer, the consumer will receive the elements {0, 0, 0, 1, 1, 1, 2, 2, 2} - (though not necessarily in that order). To account for this, it is recommended - to randomly shuffle your dataset, so that different tf.data workers will - iterate through the dataset in different orders. + used for execution will produce {0, 1, 2}. If there are 3 workers, the job + will produce the elements {0, 0, 0, 1, 1, 1, 2, 2, 2} (though not necessarily + in that order). To account for this, it is recommended to randomly shuffle + your dataset, so that different tf.data workers will iterate through the + dataset in different orders. - In the future, there will be additional epoch modes. For example, + In the future, there will be additional processing modes. For example, a "one_epoch" mode which partitions the dataset across the tf.data workers, so that the consumers see each element of the dataset only once. ``` - dataset = tf.data.Dataset.range(10) + dataset = tf.data.Dataset.range(5) dataset = dataset.map(lambda x: x*x) dataset = dataset.apply( tf.data.experimental.service.distribute("parallel_epochs", "grpc://dataservice:5000")) - dataset = dataset.map(lambda x: x+10) + dataset = dataset.map(lambda x: x+1) for element in dataset: - # process element + print(element) # prints { 1, 2, 5, 10, 17 } ``` In the above example, the first two lines (before the call to `distribute`) @@ -224,6 +252,33 @@ def distribute(processing_mode, service, max_outstanding_requests=None): RPC. The remaining transformations (after the call to `distribute`) will be executed locally. 
+ The `job_name` argument allows jobs to be shared across multiple + datasets. Instead of each dataset creating its own job, all datasets with the + same `job_name` will consume from the same job. A new job will + be created for each iteration of the dataset (with each repetition of + `Dataset.repeat` counting as a new iteration). The following example + demonstrates shared iteration, with the assumption that the tf.data service is + running with a single worker. + + ``` + range5_dataset = tf.data.Dataset.range(5) + dataset1 = range5_dataset.apply(tf.data.experimental.service.distribute( + "parallel_epochs", "grpc://dataservice:5000", job_name="my_job_name")) + dataset2 = range5_dataset.apply(tf.data.experimental.service.distribute( + "parallel_epochs", "grpc://dataservice:5000", job_name="my_job_name")) + iter_1_1 = iter(dataset1) + iter_1_2 = iter(dataset1) + iter_2_1 = iter(dataset2) + iter_2_2 = iter(dataset2) + print(next(iter_1_1)) # Prints "0" + # iter_1_2 is a second iteration of dataset1, so it consumes from a new job + print(next(iter_1_2)) # Prints "0" + # iter_2_1 is the first iteration of dataset2, so it shares iter_1_1's job + print(next(iter_2_1)) # Prints "1" + # iter_2_2 is the second iteration of dataset2, so it shares iter_1_2's job + print(next(iter_2_2)) # Prints "1" + ``` + Args: processing_mode: A string specifying the policy for how data should be processed by tf.data workers. Currently, the only supported value is @@ -231,6 +286,9 @@ def distribute(processing_mode, service, max_outstanding_requests=None): service: A string indicating how to connect to the tf.data service. The string should be in the format <protocol>://<address>
, e.g. grpc://localhost:5000. + job_name: (Optional.) The name of the job. This argument makes it possible + for multiple datasets to share the same job. The default behavior is that + the dataset creates anonymous, exclusively owned jobs. max_outstanding_requests: (Optional.) A limit on how many elements may be requested at the same time. You can use this option to control the amount of memory used, since `distribute` won't use more than `element_size` * @@ -239,4 +297,5 @@ def distribute(processing_mode, service, max_outstanding_requests=None): Returns: Dataset: A `Dataset` of the elements produced by the data service. """ - return _distribute(processing_mode, service, max_outstanding_requests) + return _distribute(processing_mode, service, job_name, + max_outstanding_requests) diff --git a/tensorflow/python/data/kernel_tests/data_service_ops_test.py b/tensorflow/python/data/kernel_tests/data_service_ops_test.py index b6e963959e4..eac1c674b2d 100644 --- a/tensorflow/python/data/kernel_tests/data_service_ops_test.py +++ b/tensorflow/python/data/kernel_tests/data_service_ops_test.py @@ -37,11 +37,14 @@ from tensorflow.python.platform import test PROTOCOL = "grpc" -def _make_distributed_dataset(dataset, service): +def _make_distributed_dataset(dataset, service, job_name=None): """Creates a distributed dataset with a short task refresh interval.""" return dataset.apply( data_service_ops._distribute( - "parallel_epochs", service, task_refresh_interval_hint_ms=20)) + "parallel_epochs", + service, + job_name=job_name, + task_refresh_interval_hint_ms=20)) class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase): @@ -233,6 +236,72 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase): result = list(f().numpy()) self.assertCountEqual(num_workers * list(range(num_elements)), result) + @combinations.generate(test_base.eager_only_combinations()) + def testSharedJobName(self): + num_elements = 10 + service = self.create_cluster(1) + 
ds = dataset_ops.Dataset.range(num_elements) + ds1 = _make_distributed_dataset(ds, service, job_name="job_name") + ds2 = _make_distributed_dataset(ds, service, job_name="job_name") + iter1 = iter(ds1) + iter2 = iter(ds2) + results = [] + for _ in range(3): + results.append(next(iter1).numpy()) + results.append(next(iter2).numpy()) + for elem in iter1: + results.append(elem.numpy()) + for elem in iter2: + results.append(elem.numpy()) + self.assertCountEqual(list(range(num_elements)), results) + + @combinations.generate(test_base.eager_only_combinations()) + def testDifferentJobNames(self): + num_elements = 10 + service = self.create_cluster(1) + ds = dataset_ops.Dataset.range(num_elements) + ds1 = _make_distributed_dataset(ds, service, job_name="job_name1") + ds2 = _make_distributed_dataset(ds, service, job_name="job_name2") + self.assertDatasetProduces(ds1, list(range(num_elements))) + self.assertDatasetProduces(ds2, list(range(num_elements))) + + @combinations.generate(test_base.eager_only_combinations()) + def testSharedJobNameMultiIteration(self): + num_elements = 10 + service = self.create_cluster(1) + ds = dataset_ops.Dataset.range(num_elements) + ds1 = _make_distributed_dataset(ds, service, job_name="job_name") + ds2 = _make_distributed_dataset(ds, service, job_name="job_name") + # iteration 1 + self.assertDatasetProduces(ds1, list(range(num_elements))) + self.assertDatasetProduces(ds2, []) + # iteration 2 + self.assertDatasetProduces(ds2, list(range(num_elements))) + self.assertDatasetProduces(ds1, []) + + @combinations.generate(test_base.eager_only_combinations()) + def testSharedJobNameRepeat(self): + num_elements = 10 + num_repetitions = 3 + service = self.create_cluster(1) + ds = dataset_ops.Dataset.range(num_elements) + ds1 = _make_distributed_dataset(ds, service, job_name="job_name") + ds1 = ds1.repeat(num_repetitions) + ds2 = _make_distributed_dataset(ds, service, job_name="job_name") + ds2 = ds2.repeat(num_repetitions) + results = [] + iter1 = 
iter(ds1) + iter2 = iter(ds2) + for _ in range(((num_elements * num_repetitions) // 2) - 1): + results.append(next(iter1).numpy()) + for _ in range(((num_elements * num_repetitions) // 2) - 1): + results.append(next(iter2).numpy()) + for elem in iter1: + results.append(elem.numpy()) + for elem in iter2: + results.append(elem.numpy()) + self.assertCountEqual(num_repetitions * list(range(num_elements)), results) + def run_stateful(self, external_state_policy): num_elements = 10 ds = dataset_ops.Dataset.range(num_elements).map( diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index 9a074786690..061db56fd19 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -934,7 +934,7 @@ tf_module { } member_method { name: "DataServiceDataset" - argspec: "args=[\'dataset_id\', \'processing_mode\', \'address\', \'protocol\', \'max_outstanding_requests\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], " + argspec: "args=[\'dataset_id\', \'processing_mode\', \'address\', \'protocol\', \'job_name\', \'max_outstanding_requests\', \'iteration_counter\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], " } member_method { name: "DatasetCardinality" @@ -1176,6 +1176,10 @@ tf_module { name: "DrawBoundingBoxesV2" argspec: "args=[\'images\', \'boxes\', \'colors\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "DummyIterationCounter" + argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "DummyMemoryCache" argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt 
b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index 9a074786690..061db56fd19 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -934,7 +934,7 @@ tf_module { } member_method { name: "DataServiceDataset" - argspec: "args=[\'dataset_id\', \'processing_mode\', \'address\', \'protocol\', \'max_outstanding_requests\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], " + argspec: "args=[\'dataset_id\', \'processing_mode\', \'address\', \'protocol\', \'job_name\', \'max_outstanding_requests\', \'iteration_counter\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], " } member_method { name: "DatasetCardinality" @@ -1176,6 +1176,10 @@ tf_module { name: "DrawBoundingBoxesV2" argspec: "args=[\'images\', \'boxes\', \'colors\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "DummyIterationCounter" + argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], " + } member_method { name: "DummyMemoryCache" argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'None\'], "