[tf.data service] Support __iter__ for tf.data service datasets.

Now we can iterate through tf.data service datasets with `for elem in dataset`, and `distributed_dataset.repeat()` will work correctly.
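A minimal sketch of the new usage, mirroring the docstring example in this change (the address and element values are placeholders):

```
import tensorflow as tf

dataset = tf.data.Dataset.range(10)
dataset = dataset.map(lambda x: x*x)
# Everything before `distribute` runs on tf.data workers; the rest runs locally.
dataset = dataset.apply(
    tf.data.experimental.service.distribute("parallel_epochs",
                                            "grpc://dataservice:5000"))
dataset = dataset.map(lambda x: x+10)
dataset = dataset.repeat(2)  # repeat() now composes with the distributed dataset.

for element in dataset:
  ...  # process element
```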

This CL removes the previous method of iteration via the CreateJob/MakeDataServiceIterator ops. They were never made public, so it is OK to remove them.

PiperOrigin-RevId: 309281077
Change-Id: I9531f7d2834ce6669f15896d8c830d23d8277b13
Andrew Audibert 2020-04-30 12:48:16 -07:00 committed by TensorFlower Gardener
parent 2d26525724
commit 655d6d3cca
23 changed files with 211 additions and 507 deletions

View File

@ -1,5 +0,0 @@
op {
graph_op_name: "CreateJob"
visibility: HIDDEN
summary: "Creates a tf.data service job."
}

View File

@ -1,5 +0,0 @@
op {
graph_op_name: "MakeDataServiceIterator"
visibility: HIDDEN
summary: "Creates an iterator for reading from the tf.data service."
}

View File

@ -26,6 +26,34 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
namespace data { namespace data {
namespace {
constexpr const char kParallelEpochs[] = "parallel_epochs";
constexpr const char kOneEpoch[] = "one_epoch";
} // namespace
Status ParseProcessingMode(absl::string_view s, ProcessingMode* mode) {
if (s == kParallelEpochs) {
*mode = ProcessingMode::PARALLEL_EPOCHS;
} else if (s == kOneEpoch) {
*mode = ProcessingMode::ONE_EPOCH;
} else {
return errors::InvalidArgument("Unrecognized processing mode: ", s);
}
return Status::OK();
}
std::string ProcessingModeToString(ProcessingMode mode) {
switch (mode) {
case ProcessingMode::PARALLEL_EPOCHS:
return kParallelEpochs;
case ProcessingMode::ONE_EPOCH:
return kOneEpoch;
default:
DCHECK(false);
return "Unknown";
}
}
Status DataServiceMasterClient::CreateJob(int64 dataset_id, Status DataServiceMasterClient::CreateJob(int64 dataset_id,
ProcessingMode processing_mode, ProcessingMode processing_mode,
int64* job_id) { int64* job_id) {

View File

@ -33,6 +33,13 @@ enum class ProcessingMode : int64 {
ONE_EPOCH = 1, ONE_EPOCH = 1,
}; };
// Parses a string representing a processing mode and stores the result in
// *mode. Returns an InvalidArgument status if the string is not recognized.
Status ParseProcessingMode(absl::string_view s, ProcessingMode* mode);
// Converts a processing mode to its corresponding string.
std::string ProcessingModeToString(ProcessingMode mode);
// Base class for data service clients. Data service clients are // Base class for data service clients. Data service clients are
// thread-compatible, requiring external synchronization when used from multiple // thread-compatible, requiring external synchronization when used from multiple
// threads. // threads.

View File

@ -38,6 +38,30 @@ namespace data {
namespace { namespace {
constexpr const char kProtocol[] = "grpc+local"; constexpr const char kProtocol[] = "grpc+local";
TEST(DataService, ParseParallelEpochsProcessingMode) {
ProcessingMode mode;
TF_ASSERT_OK(ParseProcessingMode("parallel_epochs", &mode));
EXPECT_EQ(mode, ProcessingMode::PARALLEL_EPOCHS);
}
TEST(DataService, ParseOneEpochProcessingMode) {
ProcessingMode mode;
TF_ASSERT_OK(ParseProcessingMode("one_epoch", &mode));
EXPECT_EQ(mode, ProcessingMode::ONE_EPOCH);
}
TEST(DataService, ParseInvalidProcessingMode) {
ProcessingMode mode;
Status s = ParseProcessingMode("invalid", &mode);
EXPECT_EQ(s.code(), error::Code::INVALID_ARGUMENT);
}
TEST(DataService, ProcessingModeToString) {
EXPECT_EQ("parallel_epochs",
ProcessingModeToString(ProcessingMode::PARALLEL_EPOCHS));
EXPECT_EQ("one_epoch", ProcessingModeToString(ProcessingMode::ONE_EPOCH));
}
Status CheckWorkerOutput(const std::string& worker_address, int64 task_id, Status CheckWorkerOutput(const std::string& worker_address, int64 task_id,
std::vector<std::vector<Tensor>> expected_output) { std::vector<std::vector<Tensor>> expected_output) {
DataServiceWorkerClient worker(worker_address, kProtocol); DataServiceWorkerClient worker(worker_address, kProtocol);

View File

@ -117,11 +117,13 @@ Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request,
} }
std::unique_ptr<standalone::Iterator>& iter = it->second.iterator; std::unique_ptr<standalone::Iterator>& iter = it->second.iterator;
if (iter == nullptr) { if (iter == nullptr) {
VLOG(3) << "Task " << request->task_id() << " is already finished";
response->set_end_of_sequence(true); response->set_end_of_sequence(true);
return Status::OK(); return Status::OK();
} }
TF_RETURN_IF_ERROR(iter->GetNext(&outputs, &end_of_sequence)); TF_RETURN_IF_ERROR(iter->GetNext(&outputs, &end_of_sequence));
if (end_of_sequence) { if (end_of_sequence) {
VLOG(3) << "Reached end_of_sequence for task " << request->task_id();
// Release iterator memory and leave a null entry as a tombstone. // Release iterator memory and leave a null entry as a tombstone.
iter.reset(); iter.reset();
pending_completed_tasks_.push_back(request->task_id()); pending_completed_tasks_.push_back(request->task_id());
@ -130,6 +132,7 @@ Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request,
} }
if (!end_of_sequence) { if (!end_of_sequence) {
VLOG(3) << "Producing an element for task " << request->task_id();
TF_RETURN_IF_ERROR(service_util::Compress( TF_RETURN_IF_ERROR(service_util::Compress(
outputs, response->mutable_compressed_element())); outputs, response->mutable_compressed_element()));
} }

View File

@ -292,36 +292,6 @@ class Runner {
static Runner* get(); static Runner* get();
}; };
// A token for reading from a tf.data service job.
class JobToken {
public:
JobToken() : is_empty_(true) {}
explicit JobToken(int64 job_id) : job_id_(job_id), is_empty_(false) {}
bool is_empty() const { return is_empty_; }
int64 job_id() const { return job_id_; }
string TypeName() const { return "tensorflow::JobToken"; }
void Encode(VariantTensorData* data) const {
Tensor job_id = Tensor(DT_INT64, TensorShape({}));
job_id.scalar<int64>()() = job_id_;
*(data->add_tensors()) = job_id;
Tensor is_empty = Tensor(DT_BOOL, TensorShape({}));
is_empty.scalar<bool>()() = is_empty_;
*(data->add_tensors()) = is_empty;
}
bool Decode(const VariantTensorData& data) {
job_id_ = data.tensors(0).scalar<int64>()();
is_empty_ = data.tensors(1).scalar<bool>()();
return true;
}
private:
int64 job_id_;
bool is_empty_;
};
// A cut-down version of `OpKernelContext` for running computations in // A cut-down version of `OpKernelContext` for running computations in
// iterators. Note that we cannot simply use `OpKernelContext` here because we // iterators. Note that we cannot simply use `OpKernelContext` here because we
// might run computation in an iterator whose lifetime is not nested within the // might run computation in an iterator whose lifetime is not nested within the
@ -342,7 +312,6 @@ class IteratorContext {
env(ctx->env()), env(ctx->env()),
flr(ctx->flr()), flr(ctx->flr()),
function_handle_cache(ctx->function_handle_cache()), function_handle_cache(ctx->function_handle_cache()),
job_token(ctx->job_token()),
resource_mgr(ctx->resource_mgr()), resource_mgr(ctx->resource_mgr()),
model(ctx->model()), model(ctx->model()),
runner(*(ctx->runner())), runner(*(ctx->runner())),
@ -401,9 +370,6 @@ class IteratorContext {
// A FunctionHandleCache that owns all the function handles. Not owned. // A FunctionHandleCache that owns all the function handles. Not owned.
FunctionHandleCache* function_handle_cache = nullptr; FunctionHandleCache* function_handle_cache = nullptr;
// A token for reading data from a tf.data service job.
JobToken job_token;
// A resource manager for storing dataset-related state, e.g. random // A resource manager for storing dataset-related state, e.g. random
// seeds or cached tensors. Not owned. // seeds or cached tensors. Not owned.
ResourceMgr* resource_mgr = nullptr; ResourceMgr* resource_mgr = nullptr;
@ -453,8 +419,6 @@ class IteratorContext {
return params_.function_handle_cache; return params_.function_handle_cache;
} }
const JobToken& job_token() { return params_.job_token; }
ResourceMgr* resource_mgr() { return params_.resource_mgr; } ResourceMgr* resource_mgr() { return params_.resource_mgr; }
const std::shared_ptr<model::Model>& model() { return params_.model; } const std::shared_ptr<model::Model>& model() { return params_.model; }

View File

@ -91,27 +91,4 @@ INSTANTIATE_TEST_SUITE_P(
{_tf_string_, tensor_strs, {_tf_string_, tensor_strs,
static_cast<int64>(sizeof(str) + str.size()) /*bytes*/}})); static_cast<int64>(sizeof(str) + str.size()) /*bytes*/}}));
TEST(DatasetTest, JobServiceTokenIsEmpty) {
data::JobToken token;
EXPECT_TRUE(token.is_empty());
}
TEST(DatasetTest, JobTokenHoldsJobId) {
int64 job_id = 5;
data::JobToken token(job_id);
EXPECT_EQ(job_id, token.job_id());
EXPECT_FALSE(token.is_empty());
}
TEST(DatasetTest, JobTokenEncodeDecode) {
int64 job_id = 5;
data::JobToken token(job_id);
VariantTensorData data;
token.Encode(&data);
data::JobToken decoded;
decoded.Decode(data);
EXPECT_FALSE(token.is_empty());
EXPECT_EQ(job_id, token.job_id());
}
} // namespace tensorflow } // namespace tensorflow

View File

@ -43,6 +43,8 @@ namespace tensorflow {
namespace data { namespace data {
/* static */ constexpr const char* const DataServiceDatasetOp::kDatasetType; /* static */ constexpr const char* const DataServiceDatasetOp::kDatasetType;
/* static */ constexpr const char* const DataServiceDatasetOp::kDatasetId;
/* static */ constexpr const char* const DataServiceDatasetOp::kProcessingMode;
/* static */ constexpr const char* const DataServiceDatasetOp::kAddress; /* static */ constexpr const char* const DataServiceDatasetOp::kAddress;
/* static */ constexpr const char* const DataServiceDatasetOp::kProtocol; /* static */ constexpr const char* const DataServiceDatasetOp::kProtocol;
/* static */ constexpr const char* const /* static */ constexpr const char* const
@ -50,6 +52,7 @@ namespace data {
/* static */ constexpr const char* const DataServiceDatasetOp::kOutputTypes; /* static */ constexpr const char* const DataServiceDatasetOp::kOutputTypes;
/* static */ constexpr const char* const DataServiceDatasetOp::kOutputShapes; /* static */ constexpr const char* const DataServiceDatasetOp::kOutputShapes;
namespace {
// Once we've spent `kRetryTimeoutMicros` in `GetNextInternal`, we will wait for // Once we've spent `kRetryTimeoutMicros` in `GetNextInternal`, we will wait for
// the current attempt to complete and perform no more retries. // the current attempt to complete and perform no more retries.
const int64 kRetryTimeoutMicros = 1000LL * 1000 * 60 * 60; // 60 minutes. const int64 kRetryTimeoutMicros = 1000LL * 1000 * 60 * 60; // 60 minutes.
@ -57,6 +60,8 @@ const int64 kRetryTimeoutMicros = 1000LL * 1000 * 60 * 60; // 60 minutes.
// Default interval between task list refreshes. // Default interval between task list refreshes.
const int64 kDefaultTaskRefreshIntervalMs = 1000; // 1 second. const int64 kDefaultTaskRefreshIntervalMs = 1000; // 1 second.
} // namespace
// Dataset for reading data from the tf.data service non-deterministically. // Dataset for reading data from the tf.data service non-deterministically.
// //
// This dataset interleaves dataset elements produced by multiple tf.data // This dataset interleaves dataset elements produced by multiple tf.data
@ -64,11 +69,14 @@ const int64 kDefaultTaskRefreshIntervalMs = 1000; // 1 second.
// to read from (in case workers are added or removed). // to read from (in case workers are added or removed).
class DataServiceDatasetOp::Dataset : public DatasetBase { class DataServiceDatasetOp::Dataset : public DatasetBase {
public: public:
Dataset(OpKernelContext* ctx, const std::string& address,
Dataset(OpKernelContext* ctx, int64 dataset_id,
ProcessingMode processing_mode, const std::string& address,
const std::string& protocol, int64 max_outstanding_requests, const std::string& protocol, int64 max_outstanding_requests,
int64 task_refresh_interval_ms, const DataTypeVector& output_types, int64 task_refresh_interval_ms, const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes) const std::vector<PartialTensorShape>& output_shapes)
: DatasetBase(DatasetContext(ctx)), : DatasetBase(DatasetContext(ctx)),
dataset_id_(dataset_id),
processing_mode_(processing_mode),
address_(address), address_(address),
protocol_(protocol), protocol_(protocol),
max_outstanding_requests_(max_outstanding_requests), max_outstanding_requests_(max_outstanding_requests),
@ -102,6 +110,13 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
Status AsGraphDefInternal(SerializationContext* ctx, Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b, DatasetGraphDefBuilder* b,
Node** output) const override { Node** output) const override {
Node* dataset_id;
TF_RETURN_IF_ERROR(b->AddScalar(dataset_id_, &dataset_id));
Node* processing_mode;
tstring processing_mode_str = ProcessingModeToString(processing_mode_);
TF_RETURN_IF_ERROR(b->AddScalar(processing_mode_str, &processing_mode));
Node* address; Node* address;
TF_RETURN_IF_ERROR(b->AddScalar(address_, &address)); TF_RETURN_IF_ERROR(b->AddScalar(address_, &address));
@ -117,7 +132,9 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
&task_refresh_interval_hint_ms); &task_refresh_interval_hint_ms);
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
b->AddDataset(this, {address, protocol, max_outstanding_requests},
b->AddDataset(this,
{dataset_id, processing_mode, address, protocol,
max_outstanding_requests},
{std::make_pair(kTaskRefreshIntervalHintMs, {std::make_pair(kTaskRefreshIntervalHintMs,
task_refresh_interval_hint_ms)}, task_refresh_interval_hint_ms)},
output)); output));
@ -132,6 +149,8 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
~Iterator() override { ~Iterator() override {
mutex_lock l(mu_); mutex_lock l(mu_);
VLOG(1) << "Destroying data service dataset iterator for job id "
<< job_id_;
cancelled_ = true; cancelled_ = true;
cv_.notify_all(); cv_.notify_all();
// Thread destructors will block until the threads finish, no need to wait // Thread destructors will block until the threads finish, no need to wait
@ -141,14 +160,10 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
Status Initialize(IteratorContext* ctx) override { Status Initialize(IteratorContext* ctx) override {
VLOG(3) << "Connecting to " << dataset()->address_ VLOG(3) << "Connecting to " << dataset()->address_
<< " in data service dataset op"; << " in data service dataset op";
if (ctx->job_token().is_empty()) {
return errors::FailedPrecondition(
"Expected a job token, but none found. To iterate over a dataset "
"containing a `distribute` transformation, call `create_job`, "
"which will return a job token that you should then use to iterate "
"over the dataset via `create_iterator(dataset, job_token).`");
}
job_id_ = ctx->job_token().job_id();
DataServiceMasterClient master(dataset()->address_, dataset()->protocol_);
TF_RETURN_IF_ERROR(master.CreateJob(
dataset()->dataset_id_, dataset()->processing_mode_, &job_id_));
VLOG(1) << "Created data service job with id " << job_id_;
return Status::OK(); return Status::OK();
} }
@ -175,6 +190,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
return Status::OK(); return Status::OK();
} }
DCHECK(!results_.empty()); DCHECK(!results_.empty());
*end_of_sequence = false;
out_tensors->swap(results_.front()); out_tensors->swap(results_.front());
results_.pop(); results_.pop();
cv_.notify_all(); cv_.notify_all();
@ -287,6 +303,8 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
continue; continue;
} }
if (!task_ids.contains(task_thread->task_id)) { if (!task_ids.contains(task_thread->task_id)) {
VLOG(3) << "Marking removed task thread " << task_thread->task_id
<< " as finished";
task_thread->end_of_sequence = true; task_thread->end_of_sequence = true;
} }
++it; ++it;
@ -315,6 +333,8 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
{ {
mutex_lock l(mu_); mutex_lock l(mu_);
if (task_handler->end_of_sequence) { if (task_handler->end_of_sequence) {
VLOG(3) << "Task thread " << task_handler->task_id
<< " reached end_of_sequence";
return; return;
} }
outstanding_requests_--; outstanding_requests_--;
@ -427,6 +447,8 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
std::unique_ptr<Thread> task_thread_manager_ GUARDED_BY(mu_); std::unique_ptr<Thread> task_thread_manager_ GUARDED_BY(mu_);
}; };
const int64 dataset_id_;
const ProcessingMode processing_mode_;
const tstring address_; const tstring address_;
const tstring protocol_; const tstring protocol_;
const int64 max_outstanding_requests_; const int64 max_outstanding_requests_;
@ -448,6 +470,16 @@ DataServiceDatasetOp::DataServiceDatasetOp(OpKernelConstruction* ctx)
void DataServiceDatasetOp::MakeDataset(OpKernelContext* ctx, void DataServiceDatasetOp::MakeDataset(OpKernelContext* ctx,
DatasetBase** output) { DatasetBase** output) {
int64 dataset_id;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kDatasetId, &dataset_id));
tstring processing_mode_str;
OP_REQUIRES_OK(
ctx, ParseScalarArgument(ctx, kProcessingMode, &processing_mode_str));
ProcessingMode processing_mode;
OP_REQUIRES_OK(ctx,
ParseProcessingMode(processing_mode_str, &processing_mode));
tstring address; tstring address;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kAddress, &address)); OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kAddress, &address));
OP_REQUIRES(ctx, !address.empty(), OP_REQUIRES(ctx, !address.empty(),
@ -468,9 +500,10 @@ void DataServiceDatasetOp::MakeDataset(OpKernelContext* ctx,
errors::InvalidArgument(kMaxOutstandingRequests, " must be positive or ", errors::InvalidArgument(kMaxOutstandingRequests, " must be positive or ",
model::kAutotune)); model::kAutotune));
*output = new Dataset(ctx, address, protocol, max_outstanding_requests,
task_refresh_interval_hint_ms_, output_types_,
output_shapes_);
*output =
new Dataset(ctx, dataset_id, processing_mode, address, protocol,
max_outstanding_requests, task_refresh_interval_hint_ms_,
output_types_, output_shapes_);
} }
REGISTER_KERNEL_BUILDER(Name("DataServiceDataset").Device(DEVICE_CPU), REGISTER_KERNEL_BUILDER(Name("DataServiceDataset").Device(DEVICE_CPU),

View File

@ -24,6 +24,8 @@ namespace data {
class DataServiceDatasetOp : public DatasetOpKernel { class DataServiceDatasetOp : public DatasetOpKernel {
public: public:
static constexpr const char* const kDatasetType = "DataService"; static constexpr const char* const kDatasetType = "DataService";
static constexpr const char* const kDatasetId = "dataset_id";
static constexpr const char* const kProcessingMode = "processing_mode";
static constexpr const char* const kAddress = "address"; static constexpr const char* const kAddress = "address";
static constexpr const char* const kProtocol = "protocol"; static constexpr const char* const kProtocol = "protocol";
static constexpr const char* const kMaxOutstandingRequests = static constexpr const char* const kMaxOutstandingRequests =

View File

@ -22,18 +22,6 @@ limitations under the License.
namespace tensorflow { namespace tensorflow {
namespace data { namespace data {
namespace {
Status ParseProcessingMode(const tstring& s, ProcessingMode* mode) {
if (s == "parallel_epochs") {
*mode = ProcessingMode::PARALLEL_EPOCHS;
} else if (s == "one_epoch") {
*mode = ProcessingMode::ONE_EPOCH;
} else {
return errors::InvalidArgument("Unrecognized processing mode: ", s);
}
return Status::OK();
}
} // namespace
RegisterDatasetOp::RegisterDatasetOp(OpKernelConstruction* ctx) RegisterDatasetOp::RegisterDatasetOp(OpKernelConstruction* ctx)
: OpKernel(ctx) { : OpKernel(ctx) {
@ -75,62 +63,8 @@ void RegisterDatasetOp::Compute(OpKernelContext* ctx) {
output_dataset_id() = dataset_id; output_dataset_id() = dataset_id;
} }
CreateJobOp::CreateJobOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void CreateJobOp::Compute(OpKernelContext* ctx) {
int64 dataset_id;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kDatasetId, &dataset_id));
tstring address;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kAddress, &address));
OP_REQUIRES(ctx, !address.empty(),
errors::InvalidArgument(kAddress, " must be non-empty."));
tstring protocol;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kProtocol, &protocol));
OP_REQUIRES(ctx, !protocol.empty(),
errors::InvalidArgument(kProtocol, " must be non-empty."));
tstring processing_mode_str;
OP_REQUIRES_OK(
ctx, ParseScalarArgument(ctx, kProcessingMode, &processing_mode_str));
ProcessingMode processing_mode;
OP_REQUIRES_OK(ctx,
ParseProcessingMode(processing_mode_str, &processing_mode));
DataServiceMasterClient client(address, protocol);
int64 job_id;
OP_REQUIRES_OK(ctx, client.CreateJob(dataset_id, processing_mode, &job_id));
JobToken token(job_id);
Tensor* output;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape{}, &output));
auto output_token = output->tensor<Variant, 0>();
output_token() = token;
}
Status MakeDataServiceIteratorOp::DoCompute(OpKernelContext* ctx) {
DatasetBase* dataset;
TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset));
const Tensor* token_tensor;
TF_RETURN_IF_ERROR(ctx->input(kJobToken, &token_tensor));
JobToken token = *token_tensor->scalar<Variant>()().get<JobToken>();
IteratorResource* iterator_resource;
TF_RETURN_IF_ERROR(
LookupResource(ctx, HandleFromInput(ctx, 2), &iterator_resource));
core::ScopedUnref unref_iterator(iterator_resource);
return iterator_resource->SetIteratorFromDataset(ctx, dataset, token);
}
REGISTER_KERNEL_BUILDER(Name("RegisterDataset").Device(DEVICE_CPU), REGISTER_KERNEL_BUILDER(Name("RegisterDataset").Device(DEVICE_CPU),
RegisterDatasetOp); RegisterDatasetOp);
REGISTER_KERNEL_BUILDER(Name("CreateJob").Device(DEVICE_CPU), CreateJobOp);
REGISTER_KERNEL_BUILDER(Name("MakeDataServiceIterator").Device(DEVICE_CPU),
MakeDataServiceIteratorOp);
} // namespace data } // namespace data
} // namespace tensorflow } // namespace tensorflow

View File

@ -44,41 +44,6 @@ class RegisterDatasetOp : public OpKernel {
SerializationContext::ExternalStatePolicy external_state_policy_; SerializationContext::ExternalStatePolicy external_state_policy_;
}; };
// Creates a token for reading from the tf.data service.
//
// The dataset_id input identifies which dataset to create a token for.
// The address and protocol inputs are used to connect to the tf.data service
// master.
// The processing_mode defines how the tf.data service should produce data for
// the token.
class CreateJobOp : public OpKernel {
public:
static constexpr const char* const kDatasetId = "dataset_id";
static constexpr const char* const kAddress = "address";
static constexpr const char* const kProtocol = "protocol";
static constexpr const char* const kProcessingMode = "processing_mode";
explicit CreateJobOp(OpKernelConstruction* ctx);
void Compute(OpKernelContext* ctx) override;
};
// Creates a new iterator for iterating over a tf.data service dataset.
//
// The epoch_id input identifies which epoch to read from. Multiple iterators
// may read from the same epoch, causing the elements of the epoch to be split
// across all iterators.
class MakeDataServiceIteratorOp : public MakeIteratorOp {
public:
static constexpr const char* const kJobToken = "job_token";
explicit MakeDataServiceIteratorOp(OpKernelConstruction* ctx)
: MakeIteratorOp(ctx) {}
protected:
Status DoCompute(OpKernelContext* ctx) override;
};
} // namespace data } // namespace data
} // namespace tensorflow } // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_DATA_SERVICE_OPS_H_ #endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_DATA_SERVICE_OPS_H_

View File

@ -166,8 +166,7 @@ Status IteratorResource::Restore(OpKernelContext* ctx,
} }
Status IteratorResource::SetIteratorFromDataset(OpKernelContext* ctx, Status IteratorResource::SetIteratorFromDataset(OpKernelContext* ctx,
DatasetBase* dataset,
JobToken job_token) {
DatasetBase* dataset) {
std::shared_ptr<State> new_state; std::shared_ptr<State> new_state;
{ {
tf_shared_lock l(mu_); tf_shared_lock l(mu_);
@ -180,7 +179,6 @@ Status IteratorResource::SetIteratorFromDataset(OpKernelContext* ctx,
IteratorContext::Params params(ctx); IteratorContext::Params params(ctx);
params.flr = new_state->flr; params.flr = new_state->flr;
params.function_handle_cache = new_state->function_handle_cache.get(); params.function_handle_cache = new_state->function_handle_cache.get();
params.job_token = job_token;
params.resource_mgr = &new_state->resource_mgr; params.resource_mgr = &new_state->resource_mgr;
params.thread_factory = unbounded_thread_pool_.get_thread_factory(); params.thread_factory = unbounded_thread_pool_.get_thread_factory();
params.thread_pool = &unbounded_thread_pool_; params.thread_pool = &unbounded_thread_pool_;
@ -532,9 +530,7 @@ Status MakeIteratorOp::DoCompute(OpKernelContext* ctx) {
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
LookupResource(ctx, HandleFromInput(ctx, 1), &iterator_resource)); LookupResource(ctx, HandleFromInput(ctx, 1), &iterator_resource));
core::ScopedUnref unref_iterator(iterator_resource); core::ScopedUnref unref_iterator(iterator_resource);
JobToken empty_token;
return iterator_resource->SetIteratorFromDataset(ctx, dataset,
/*job_token=*/empty_token);
return iterator_resource->SetIteratorFromDataset(ctx, dataset);
} }
void DeleteIteratorOp::Compute(OpKernelContext* ctx) { void DeleteIteratorOp::Compute(OpKernelContext* ctx) {
@ -855,9 +851,7 @@ class OneShotIteratorOp : public AsyncOpKernel {
// factory function. // factory function.
DatasetBase* dataset; DatasetBase* dataset;
TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(return_values[0], &dataset)); TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(return_values[0], &dataset));
JobToken empty_token;
TF_RETURN_IF_ERROR((*iterator)->SetIteratorFromDataset(
ctx, dataset, /*job_token=*/empty_token));
TF_RETURN_IF_ERROR((*iterator)->SetIteratorFromDataset(ctx, dataset));
(*iterator)->Ref(); (*iterator)->Ref();
return Status::OK(); return Status::OK();
} }

View File

@ -69,14 +69,9 @@ class IteratorResource : public ResourceBase {
// Creates an iterator for `dataset`, and associates the iterator with this // Creates an iterator for `dataset`, and associates the iterator with this
// iterator resource. // iterator resource.
// //
// The `job_token` will be passed through the IteratorContext when
// creating the iterator. This token is used to read from a tf.data service
// job.
//
// `SetIteratorFromDataset` should be called before calling `GetNext`, `Save`, // `SetIteratorFromDataset` should be called before calling `GetNext`, `Save`,
// or `Restore`. // or `Restore`.
Status SetIteratorFromDataset(OpKernelContext* ctx, DatasetBase* dataset,
JobToken job_token);
Status SetIteratorFromDataset(OpKernelContext* ctx, DatasetBase* dataset);
string DebugString() const override { return "Iterator resource"; } string DebugString() const override { return "Iterator resource"; }

View File

@ -1,23 +0,0 @@
op {
name: "CreateJob"
input_arg {
name: "dataset_id"
type: DT_INT64
}
input_arg {
name: "address"
type: DT_STRING
}
input_arg {
name: "protocol"
type: DT_STRING
}
input_arg {
name: "processing_mode"
type: DT_STRING
}
output_arg {
name: "job_token"
type: DT_VARIANT
}
}

View File

@ -1,16 +0,0 @@
op {
name: "MakeDataServiceIterator"
input_arg {
name: "dataset"
type: DT_VARIANT
}
input_arg {
name: "job_token"
type: DT_VARIANT
}
input_arg {
name: "iterator"
type: DT_RESOURCE
}
is_stateful: true
}

View File

@ -1,23 +0,0 @@
op {
name: "CreateJob"
input_arg {
name: "dataset_id"
type: DT_INT64
}
input_arg {
name: "address"
type: DT_STRING
}
input_arg {
name: "protocol"
type: DT_STRING
}
input_arg {
name: "processing_mode"
type: DT_STRING
}
output_arg {
name: "job_token"
type: DT_VARIANT
}
}

View File

@ -1,16 +0,0 @@
op {
name: "MakeDataServiceIterator"
input_arg {
name: "dataset"
type: DT_VARIANT
}
input_arg {
name: "job_token"
type: DT_VARIANT
}
input_arg {
name: "iterator"
type: DT_RESOURCE
}
is_stateful: true
}

View File

@ -1043,6 +1043,8 @@ REGISTER_OP("ExperimentalUniqueDataset")
.SetShapeFn(shape_inference::ScalarShape); .SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("DataServiceDataset") REGISTER_OP("DataServiceDataset")
.Input("dataset_id: int64")
.Input("processing_mode: string")
.Input("address: string") .Input("address: string")
.Input("protocol: string") .Input("protocol: string")
.Input("max_outstanding_requests: int64") .Input("max_outstanding_requests: int64")
@ -1061,18 +1063,4 @@ REGISTER_OP("RegisterDataset")
.Attr("external_state_policy: int") .Attr("external_state_policy: int")
.SetShapeFn(shape_inference::ScalarShape); .SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("CreateJob")
.Input("dataset_id: int64")
.Input("address: string")
.Input("protocol: string")
.Input("processing_mode: string")
.Output("job_token: variant")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("MakeDataServiceIterator")
.Input("dataset: variant")
.Input("job_token: variant")
.Input("iterator: resource")
.SetShapeFn(shape_inference::NoOutputs);
} // namespace tensorflow } // namespace tensorflow

View File

@ -24,8 +24,6 @@ import six
from tensorflow.python import tf2 from tensorflow.python import tf2
from tensorflow.python.data.experimental.ops.distribute_options import ExternalStatePolicy from tensorflow.python.data.experimental.ops.distribute_options import ExternalStatePolicy
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_experimental_dataset_ops from tensorflow.python.ops import gen_experimental_dataset_ops
@ -36,10 +34,10 @@ class ProcessingMode(object):
@staticmethod @staticmethod
def validate(mode): def validate(mode):
"""Raises a TypeError if the given object is not a valid processing mode.""" """Raises a ValueError if the given object is not a valid processing mode."""
valid_modes = [ProcessingMode.PARALLEL_EPOCHS] valid_modes = [ProcessingMode.PARALLEL_EPOCHS]
if mode not in valid_modes: if mode not in valid_modes:
raise TypeError(
raise ValueError(
"{0} is not a valid processing mode. Valid modes: {1}".format( "{0} is not a valid processing mode. Valid modes: {1}".format(
mode, valid_modes)) mode, valid_modes))
@ -50,6 +48,7 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource):
def __init__(self, def __init__(self,
input_dataset, input_dataset,
dataset_id, dataset_id,
processing_mode,
address, address,
protocol, protocol,
max_outstanding_requests=None, max_outstanding_requests=None,
@ -60,6 +59,9 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource):
input_dataset: The input dataset, which should be registered with the input_dataset: The input dataset, which should be registered with the
tf.data service under `dataset_id`. tf.data service under `dataset_id`.
dataset_id: The dataset id for the dataset to read from. dataset_id: The dataset id for the dataset to read from.
processing_mode: A string specifying the policy for how data should be
processed by tf.data workers. Currently, the only supported value is
"parallel_epochs".
address: The tf.data service address, e.g. "localhost:5000". address: The tf.data service address, e.g. "localhost:5000".
protocol: The protocol to use for communicating with the tf.data service, protocol: The protocol to use for communicating with the tf.data service,
e.g. "grpc". e.g. "grpc".
@ -77,13 +79,10 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource):
task_refresh_interval_hint_ms = dataset_ops.AUTOTUNE task_refresh_interval_hint_ms = dataset_ops.AUTOTUNE
self._element_spec = input_dataset.element_spec self._element_spec = input_dataset.element_spec
self._dataset_id = dataset_id
self._address = address
self._protocol = protocol
self._max_outstanding_requests = max_outstanding_requests
self._task_refresh_interval_hint_ms = task_refresh_interval_hint_ms
variant_tensor = gen_experimental_dataset_ops.data_service_dataset( variant_tensor = gen_experimental_dataset_ops.data_service_dataset(
dataset_id=dataset_id,
processing_mode=processing_mode,
address=address, address=address,
protocol=protocol, protocol=protocol,
max_outstanding_requests=max_outstanding_requests, max_outstanding_requests=max_outstanding_requests,
@ -91,18 +90,6 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource):
**self._flat_structure) **self._flat_structure)
super(_DataServiceDatasetV2, self).__init__(variant_tensor) super(_DataServiceDatasetV2, self).__init__(variant_tensor)
@property
def dataset_id(self):
return self._dataset_id
@property
def address(self):
return self._address
@property
def protocol(self):
return self._protocol
@property @property
def element_spec(self): def element_spec(self):
return self._element_spec return self._element_spec
@ -112,30 +99,20 @@ class _DataServiceDatasetV1(dataset_ops.DatasetV1Adapter):
"""A `Dataset` that executes its input through the tf.data service.""" """A `Dataset` that executes its input through the tf.data service."""
@functools.wraps(_DataServiceDatasetV2.__init__) @functools.wraps(_DataServiceDatasetV2.__init__)
def __init__(self, input_dataset, dataset_id, address, protocol,
max_outstanding_requests, task_refresh_interval_hint_ms):
def __init__(self, input_dataset, dataset_id, processing_mode, address,
protocol, max_outstanding_requests,
task_refresh_interval_hint_ms):
self._wrapped = _DataServiceDatasetV2( self._wrapped = _DataServiceDatasetV2(
input_dataset=input_dataset, input_dataset=input_dataset,
dataset_id=dataset_id, dataset_id=dataset_id,
processing_mode=processing_mode,
address=address, address=address,
protocol=protocol, protocol=protocol,
max_outstanding_requests=max_outstanding_requests, max_outstanding_requests=max_outstanding_requests,
task_refresh_interval_hint_ms=task_refresh_interval_hint_ms) task_refresh_interval_hint_ms=task_refresh_interval_hint_ms)
super(_DataServiceDatasetV1, self).__init__(self._wrapped) super(_DataServiceDatasetV1, self).__init__(self._wrapped)
@property
def dataset_id(self):
return self._wrapped.dataset_id
@property
def address(self):
return self._wrapped.address
@property
def protocol(self):
return self._wrapped.protocol
if tf2.enabled(): if tf2.enabled():
_DataServiceDataset = _DataServiceDatasetV2 _DataServiceDataset = _DataServiceDatasetV2
@ -143,7 +120,8 @@ else:
_DataServiceDataset = _DataServiceDatasetV1 _DataServiceDataset = _DataServiceDatasetV1
def _distribute(service,
def _distribute(processing_mode,
service,
max_outstanding_requests=None, max_outstanding_requests=None,
task_refresh_interval_hint_ms=None): task_refresh_interval_hint_ms=None):
"""A transformation that moves dataset processing to the tf.data service. """A transformation that moves dataset processing to the tf.data service.
@ -152,6 +130,9 @@ def _distribute(service,
parameters which we do not yet want to add to the public Python API. parameters which we do not yet want to add to the public Python API.
Args: Args:
processing_mode: A string specifying the policy for how data should be
processed by tf.data workers. Currently, the only supported value is
"parallel_epochs".
service: A string indicating how to connect to the tf.data service. The service: A string indicating how to connect to the tf.data service. The
string should be in the format <protocol>://<address>, e.g. string should be in the format <protocol>://<address>, e.g.
grpc://localhost:5000. grpc://localhost:5000.
@ -165,6 +146,7 @@ def _distribute(service,
Returns: Returns:
Dataset: A `Dataset` of the elements produced by the data service. Dataset: A `Dataset` of the elements produced by the data service.
""" """
ProcessingMode.validate(processing_mode)
if not isinstance(service, six.string_types): if not isinstance(service, six.string_types):
raise ValueError( raise ValueError(
"service must be a string, but service was of type {0}. service={1}" "service must be a string, but service was of type {0}. service={1}"
@ -197,6 +179,7 @@ def _distribute(service,
return _DataServiceDataset( return _DataServiceDataset(
input_dataset=dataset, input_dataset=dataset,
dataset_id=dataset_id, dataset_id=dataset_id,
processing_mode=processing_mode,
address=address, address=address,
protocol=protocol, protocol=protocol,
max_outstanding_requests=max_outstanding_requests, max_outstanding_requests=max_outstanding_requests,
@ -205,52 +188,9 @@ def _distribute(service,
return _apply_fn return _apply_fn
def distribute(service, max_outstanding_requests=None):
def distribute(processing_mode, service, max_outstanding_requests=None):
"""A transformation that moves dataset processing to the tf.data service. """A transformation that moves dataset processing to the tf.data service.
```
dataset = tf.data.Dataset.range(10)
dataset = dataset.map(lambda x: x*x)
dataset = dataset.apply(
tf.data.experimental.service.distribute("grpc://dataservice:5000"))
dataset = dataset.map(lambda x: x+10)
job_token = tf.data.experimental.service.create_job(dataset)
it = tf.data.experimental.service.create_iterator(dataset, job_token)
for element in it:
# process element
```
In the above example, the first two lines (before the call to `distribute`)
will be executed on tf.data workers, and the elements provided over
RPC. The remaining transformations (after the call to `distribute`) will be
executed locally.
The token returned from `create_job` may be used to create multiple
coordinated iterators which consume data from the same job.
Args:
service: A string indicating how to connect to the tf.data service. The
string should be in the format <protocol>://<address>, e.g.
grpc://localhost:5000.
max_outstanding_requests: (Optional.) A limit on how many elements may be
requested at the same time. You can use this option to control the amount
of memory used, since `distribute` won't use more than `element_size` *
`max_outstanding_requests` of memory.
Returns:
Dataset: A `Dataset` of the elements produced by the data service.
"""
return _distribute(service, max_outstanding_requests)
def create_job(dataset, processing_mode):
"""Creates a job for reading a dataset through the tf.data service.
The returned token can be used to create iterators for consuming data from
the job. `processing_mode` controls what data will be produced. Iterators
created from the same token will consume from the same job.
The `processing_mode` argument controls how data is processed by the The `processing_mode` argument controls how data is processed by the
tf.data service. Currently, the only supported mode is "parallel_epochs". tf.data service. Currently, the only supported mode is "parallel_epochs".
@ -263,77 +203,40 @@ def create_job(dataset, processing_mode):
to randomly shuffle your dataset, so that different tf.data workers will to randomly shuffle your dataset, so that different tf.data workers will
iterate through the dataset in different orders. iterate through the dataset in different orders.
In the future, we plan to add additional epoch modes. For example, we will add
In the future, there will be additional epoch modes. For example,
a "one_epoch" mode which partitions the dataset across the tf.data a "one_epoch" mode which partitions the dataset across the tf.data
workers, so that the consumers see each element of the dataset only once. workers, so that the consumers see each element of the dataset only once.
```
dataset = tf.data.Dataset.range(10)
dataset = dataset.map(lambda x: x*x)
dataset = dataset.apply(
tf.data.experimental.service.distribute("parallel_epochs",
"grpc://dataservice:5000"))
dataset = dataset.map(lambda x: x+10)
for element in dataset:
# process element
```
In the above example, the first two lines (before the call to `distribute`)
will be executed on tf.data workers, and the elements provided over
RPC. The remaining transformations (after the call to `distribute`) will be
executed locally.
Args: Args:
dataset: A `tf.data.Dataset` to create a job for. The dataset must contain a
single `distribute` transformation.
processing_mode: A string specifying the policy for how data should be processing_mode: A string specifying the policy for how data should be
processed by tf.data workers. Currently, the only supported value is processed by tf.data workers. Currently, the only supported value is
"parallel_epochs". "parallel_epochs".
service: A string indicating how to connect to the tf.data service. The
string should be in the format <protocol>://<address>, e.g.
grpc://localhost:5000.
max_outstanding_requests: (Optional.) A limit on how many elements may be
requested at the same time. You can use this option to control the amount
of memory used, since `distribute` won't use more than `element_size` *
`max_outstanding_requests` of memory.
Returns: Returns:
A token for reading from the created tf.data service job. To read using the
token, call `create_iterator(dataset, token)`
Raises:
ValueError: If the dataset contains no calls to `distribute` or more than 1
call to `distribute`.
Dataset: A `Dataset` of the elements produced by the data service.
""" """
datasets = _find_data_service_datasets(dataset)
return _distribute(processing_mode, service, max_outstanding_requests)
if len(datasets) > 1:
raise ValueError(
"Datasets containing multiple calls to .distribute(...) are " +
"not supported")
if not datasets:
raise ValueError(
"Dataset does not contain any distribute() transformations")
ProcessingMode.validate(processing_mode)
data_service_dataset = datasets[0]
return gen_experimental_dataset_ops.create_job(
data_service_dataset.dataset_id, data_service_dataset.address,
data_service_dataset.protocol, processing_mode)
def create_iterator(dataset, job_token):
"""Creates an iterator for reading from the tf.data service.
Args:
dataset: A `tf.data.Dataset` object.
job_token: A token generated by `create_job`.
Returns:
A dataset iterator.
Raises:
RuntimeError: If called outside of a function in graph mode.
"""
if context.executing_eagerly() or ops.inside_function():
return iterator_ops.OwnedIterator(dataset, job_token=job_token)
else:
raise RuntimeError("create_iterator() is only supported inside of "
"tf.function or when eager execution is enabled.")
def _find_data_service_datasets(dataset):
"""Produces a list of all data service datasets in the given dataset.
Args:
dataset: A `tf.data.Dataset`.
Returns:
A list of all data service datasets.
"""
result = []
to_check = [dataset]
while to_check:
d = to_check.pop()
if isinstance(d, dataset_ops.DatasetV1Adapter):
d = d._dataset # pylint: disable=protected-access
if isinstance(d, _DataServiceDatasetV1) or isinstance(
d, _DataServiceDatasetV2):
result.append(d)
to_check.extend(d._inputs()) # pylint: disable=protected-access
return result

View File

@ -40,7 +40,8 @@ PROTOCOL = "grpc"
def _make_distributed_dataset(dataset, service): def _make_distributed_dataset(dataset, service):
"""Creates a distributed dataset with a short task refresh interval.""" """Creates a distributed dataset with a short task refresh interval."""
return dataset.apply( return dataset.apply(
data_service_ops._distribute(service, task_refresh_interval_hint_ms=20))
data_service_ops._distribute(
"parallel_epochs", service, task_refresh_interval_hint_ms=20))
class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase): class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
@ -65,27 +66,35 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
return self._master.target return self._master.target
@combinations.generate(test_base.eager_only_combinations())
def testMultipleEpochs(self):
service = self.create_cluster(1)
ds = dataset_ops.Dataset.range(3)
ds = _make_distributed_dataset(ds, service)
for _ in range(10):
token = data_service_ops.create_job(ds, processing_mode="parallel_epochs")
it = data_service_ops.create_iterator(ds, token)
self.assertEqual(list(range(3)), [t.numpy() for t in it])
@combinations.generate(test_base.eager_only_combinations()) @combinations.generate(test_base.eager_only_combinations())
def testDistributeBasic(self): def testDistributeBasic(self):
num_elements = 10 num_elements = 10
service = self.create_cluster(1) service = self.create_cluster(1)
ds = dataset_ops.Dataset.range(num_elements) ds = dataset_ops.Dataset.range(num_elements)
ds = _make_distributed_dataset(ds, service) ds = _make_distributed_dataset(ds, service)
token = data_service_ops.create_job(ds, processing_mode="parallel_epochs")
it = data_service_ops.create_iterator(ds, token)
results = [t.numpy() for t in it]
results = [elem.numpy() for elem in ds]
self.assertEqual(list(range(num_elements)), results) self.assertEqual(list(range(num_elements)), results)
@combinations.generate(test_base.eager_only_combinations())
def testMultipleEpochs(self):
num_elements = 3
service = self.create_cluster(1)
ds = dataset_ops.Dataset.range(num_elements)
ds = _make_distributed_dataset(ds, service)
for _ in range(10):
self.assertEqual(list(range(num_elements)), [elem.numpy() for elem in ds])
@combinations.generate(test_base.eager_only_combinations())
def testRepeatedDataset(self):
num_elements = 10
num_repetitions = 5
service = self.create_cluster(1)
ds = dataset_ops.Dataset.range(num_elements)
ds = _make_distributed_dataset(ds, service)
ds = ds.repeat(num_repetitions)
self.assertDatasetProduces(
ds, expected_output=num_repetitions * list(range(num_elements)))
@combinations.generate(test_base.eager_only_combinations()) @combinations.generate(test_base.eager_only_combinations())
def testConcurrentEpoch(self): def testConcurrentEpoch(self):
num_elements = 10 num_elements = 10
@ -96,9 +105,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
for _ in range(num_datasets): for _ in range(num_datasets):
ds = dataset_ops.Dataset.range(num_elements) ds = dataset_ops.Dataset.range(num_elements)
ds = _make_distributed_dataset(ds, service) ds = _make_distributed_dataset(ds, service)
token = data_service_ops.create_job(ds, processing_mode="parallel_epochs")
it = data_service_ops.create_iterator(ds, token)
iterators.append(it)
iterators.append(iter(ds))
results.append([]) results.append([])
for _ in range(num_elements): for _ in range(num_elements):
@ -110,6 +117,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.eager_only_combinations()) @combinations.generate(test_base.eager_only_combinations())
def testSharedEpoch(self): def testSharedEpoch(self):
self.skipTest("Not yet implemented")
num_elements = 10 num_elements = 10
num_iterators = 3 num_iterators = 3
service = self.create_cluster(1) service = self.create_cluster(1)
@ -117,9 +125,8 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
ds = _make_distributed_dataset(ds, service) ds = _make_distributed_dataset(ds, service)
result = [] result = []
iterators = [] iterators = []
token = data_service_ops.create_job(ds, processing_mode="parallel_epochs")
for _ in range(num_iterators): for _ in range(num_iterators):
iterators.append(data_service_ops.create_iterator(ds, token))
iterators.append(iter(ds))
# Alternate reading between the iterators. # Alternate reading between the iterators.
for _ in range(2): for _ in range(2):
@ -140,9 +147,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
service = self.create_cluster(num_workers) service = self.create_cluster(num_workers)
ds = dataset_ops.Dataset.range(num_elements) ds = dataset_ops.Dataset.range(num_elements)
ds = _make_distributed_dataset(ds, service) ds = _make_distributed_dataset(ds, service)
token = data_service_ops.create_job(ds, processing_mode="parallel_epochs")
iterator = data_service_ops.create_iterator(ds, token)
results = [elem.numpy() for elem in iterator]
results = [elem.numpy() for elem in ds]
self.assertCountEqual(num_workers * list(range(num_elements)), results) self.assertCountEqual(num_workers * list(range(num_elements)), results)
@combinations.generate(test_base.eager_only_combinations()) @combinations.generate(test_base.eager_only_combinations())
@ -154,8 +159,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
num_elements = 100 num_elements = 100
ds = dataset_ops.Dataset.range(num_elements) ds = dataset_ops.Dataset.range(num_elements)
ds = _make_distributed_dataset(ds, self._master.target) ds = _make_distributed_dataset(ds, self._master.target)
token = data_service_ops.create_job(ds, processing_mode="parallel_epochs")
iterator = data_service_ops.create_iterator(ds, token)
iterator = iter(ds)
results = [] results = []
# Read halfway through the dataset. # Read halfway through the dataset.
for _ in range(num_elements // 2): for _ in range(num_elements // 2):
@ -184,8 +188,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
num_elements = 100 num_elements = 100
ds = dataset_ops.Dataset.range(num_elements) ds = dataset_ops.Dataset.range(num_elements)
ds = _make_distributed_dataset(ds, self._master.target) ds = _make_distributed_dataset(ds, self._master.target)
token = data_service_ops.create_job(ds, processing_mode="parallel_epochs")
iterator = data_service_ops.create_iterator(ds, token)
iterator = iter(ds)
# Read halfway through the dataset. # Read halfway through the dataset.
midpoint = num_elements // 2 midpoint = num_elements // 2
for i in range(midpoint): for i in range(midpoint):
@ -219,12 +222,10 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
def f(): def f():
ds = dataset_ops.Dataset.range(num_elements) ds = dataset_ops.Dataset.range(num_elements)
ds = _make_distributed_dataset(ds, service) ds = _make_distributed_dataset(ds, service)
token = data_service_ops.create_job(ds, processing_mode="parallel_epochs")
it = data_service_ops.create_iterator(ds, token)
result = tensor_array_ops.TensorArray( result = tensor_array_ops.TensorArray(
dtypes.int64, size=num_workers * num_elements, dynamic_size=True) dtypes.int64, size=num_workers * num_elements, dynamic_size=True)
i = 0 i = 0
for elem in it:
for elem in ds:
result = result.write(i, elem) result = result.write(i, elem)
i += 1 i += 1
return result.stack() return result.stack()
@ -243,9 +244,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
service = self.create_cluster(3) service = self.create_cluster(3)
ds = _make_distributed_dataset(ds, service) ds = _make_distributed_dataset(ds, service)
token = data_service_ops.create_job(ds, processing_mode="parallel_epochs")
iterator = data_service_ops.create_iterator(ds, token)
next(iterator)
next(iter(ds))
@combinations.generate( @combinations.generate(
combinations.times( combinations.times(
@ -262,27 +261,6 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
with self.assertRaises(errors.FailedPreconditionError): with self.assertRaises(errors.FailedPreconditionError):
self.run_stateful(distribute_options.ExternalStatePolicy.FAIL) self.run_stateful(distribute_options.ExternalStatePolicy.FAIL)
@combinations.generate(test_base.eager_only_combinations())
def testNoDistributeCalls(self):
ds = dataset_ops.Dataset.range(1)
with self.assertRaisesWithLiteralMatch(
ValueError,
"Dataset does not contain any distribute() transformations"):
data_service_ops.create_job(ds, processing_mode="parallel_epochs")
@combinations.generate(test_base.eager_only_combinations())
def testMultipleDistributeCalls(self):
service = self.create_cluster(1)
ds1 = dataset_ops.Dataset.range(1)
ds1 = _make_distributed_dataset(ds1, service)
ds2 = dataset_ops.Dataset.range(1)
ds2 = _make_distributed_dataset(ds2, service)
ds = dataset_ops.Dataset.zip((ds1, ds2))
with self.assertRaisesWithLiteralMatch(
ValueError, "Datasets containing multiple calls to .distribute(...) "
"are not supported"):
data_service_ops.create_job(ds, processing_mode="parallel_epochs")
@combinations.generate(test_base.eager_only_combinations()) @combinations.generate(test_base.eager_only_combinations())
def testDistributeFromInterleave(self): def testDistributeFromInterleave(self):
service = self.create_cluster(1) service = self.create_cluster(1)
@ -302,14 +280,27 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
def testDistributeNonStringAddresses(self): def testDistributeNonStringAddresses(self):
ds = dataset_ops.Dataset.range(10) ds = dataset_ops.Dataset.range(10)
with self.assertRaisesRegex(ValueError, "service must be a string"): with self.assertRaisesRegex(ValueError, "service must be a string"):
ds = ds.apply(data_service_ops.distribute(service=1))
ds = ds.apply(
data_service_ops.distribute(
processing_mode="parallel_epochs", service=1))
@combinations.generate(test_base.eager_only_combinations()) @combinations.generate(test_base.eager_only_combinations())
def testDistributeEmptyAddress(self): def testDistributeEmptyAddress(self):
ds = dataset_ops.Dataset.range(10) ds = dataset_ops.Dataset.range(10)
with self.assertRaisesWithLiteralMatch(ValueError, with self.assertRaisesWithLiteralMatch(ValueError,
"service must not be empty"): "service must not be empty"):
ds = ds.apply(data_service_ops.distribute(service=""))
ds = ds.apply(
data_service_ops.distribute(
processing_mode="parallel_epochs", service=""))
@combinations.generate(test_base.eager_only_combinations())
def testDistributeInvalidProcessingMode(self):
ds = dataset_ops.Dataset.range(10)
with self.assertRaisesRegex(ValueError,
"invalid is not a valid processing mode"):
ds = ds.apply(
data_service_ops.distribute(
processing_mode="invalid", service="grpc://localhost:5000"))
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -840,10 +840,6 @@ tf_module {
name: "CountUpTo" name: "CountUpTo"
argspec: "args=[\'ref\', \'limit\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'ref\', \'limit\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
} }
member_method {
name: "CreateJob"
argspec: "args=[\'dataset_id\', \'address\', \'protocol\', \'processing_mode\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method { member_method {
name: "CreateSummaryDbWriter" name: "CreateSummaryDbWriter"
argspec: "args=[\'writer\', \'db_uri\', \'experiment_name\', \'run_name\', \'user_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'writer\', \'db_uri\', \'experiment_name\', \'run_name\', \'user_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@ -938,7 +934,7 @@ tf_module {
} }
member_method { member_method {
name: "DataServiceDataset" name: "DataServiceDataset"
argspec: "args=[\'address\', \'protocol\', \'max_outstanding_requests\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], " argspec: "args=[\'dataset_id\', \'processing_mode\', \'address\', \'protocol\', \'max_outstanding_requests\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
} }
member_method { member_method {
name: "DatasetCardinality" name: "DatasetCardinality"
@ -2200,10 +2196,6 @@ tf_module {
name: "Lu" name: "Lu"
argspec: "args=[\'input\', \'output_idx_type\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int32\'>\", \'None\'], " argspec: "args=[\'input\', \'output_idx_type\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int32\'>\", \'None\'], "
} }
member_method {
name: "MakeDataServiceIterator"
argspec: "args=[\'dataset\', \'job_token\', \'iterator\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method { member_method {
name: "MakeIterator" name: "MakeIterator"
argspec: "args=[\'dataset\', \'iterator\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'dataset\', \'iterator\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "

View File

@ -840,10 +840,6 @@ tf_module {
name: "CountUpTo" name: "CountUpTo"
argspec: "args=[\'ref\', \'limit\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'ref\', \'limit\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
} }
member_method {
name: "CreateJob"
argspec: "args=[\'dataset_id\', \'address\', \'protocol\', \'processing_mode\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method { member_method {
name: "CreateSummaryDbWriter" name: "CreateSummaryDbWriter"
argspec: "args=[\'writer\', \'db_uri\', \'experiment_name\', \'run_name\', \'user_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'writer\', \'db_uri\', \'experiment_name\', \'run_name\', \'user_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@ -938,7 +934,7 @@ tf_module {
} }
member_method { member_method {
name: "DataServiceDataset" name: "DataServiceDataset"
argspec: "args=[\'address\', \'protocol\', \'max_outstanding_requests\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], " argspec: "args=[\'dataset_id\', \'processing_mode\', \'address\', \'protocol\', \'max_outstanding_requests\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
} }
member_method { member_method {
name: "DatasetCardinality" name: "DatasetCardinality"
@ -2200,10 +2196,6 @@ tf_module {
name: "Lu" name: "Lu"
argspec: "args=[\'input\', \'output_idx_type\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int32\'>\", \'None\'], " argspec: "args=[\'input\', \'output_idx_type\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int32\'>\", \'None\'], "
} }
member_method {
name: "MakeDataServiceIterator"
argspec: "args=[\'dataset\', \'job_token\', \'iterator\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method { member_method {
name: "MakeIterator" name: "MakeIterator"
argspec: "args=[\'dataset\', \'iterator\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " argspec: "args=[\'dataset\', \'iterator\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "