[tf.data service] Support __iter__ for tf.data service datasets.
Now we can iterate through tf.data service datasets with `for elem in dataset`, and `distributed_dataset.repeat()` will work correctly. This CL removes the previous method of iteration via CreateJob/MakeDataServiceIterator. Those ops were never made public, so it is OK to remove them.

PiperOrigin-RevId: 309281077
Change-Id: I9531f7d2834ce6669f15896d8c830d23d8277b13
parent 2d26525724
commit 655d6d3cca
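For context, the new usage looks roughly like the sketch below, adapted from the docstring and tests added in this change. The `grpc://dataservice:5000` address is illustrative and assumes a tf.data service master is already running; "parallel_epochs" is the only processing mode supported at this point.

```
import tensorflow as tf

dataset = tf.data.Dataset.range(10)
dataset = dataset.map(lambda x: x * x)
# Everything above this point runs on tf.data workers; elements are
# delivered over RPC, and the transformations below run locally.
dataset = dataset.apply(
    tf.data.experimental.service.distribute("parallel_epochs",
                                            "grpc://dataservice:5000"))
dataset = dataset.map(lambda x: x + 10)

# Plain Python iteration now works; no create_job/create_iterator calls.
for element in dataset:
  print(element.numpy())

# repeat() on the distributed dataset also behaves as expected.
for element in dataset.repeat(2):
  print(element.numpy())
```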
@@ -1,5 +0,0 @@
-op {
-  graph_op_name: "CreateJob"
-  visibility: HIDDEN
-  summary: "Creates a tf.data service job."
-}
@@ -1,5 +0,0 @@
-op {
-  graph_op_name: "MakeDataServiceIterator"
-  visibility: HIDDEN
-  summary: "Creates an iterator for reading from the tf.data service."
-}
@@ -26,6 +26,34 @@ limitations under the License.
 namespace tensorflow {
 namespace data {
 
+namespace {
+constexpr const char kParallelEpochs[] = "parallel_epochs";
+constexpr const char kOneEpoch[] = "one_epoch";
+}  // namespace
+
+Status ParseProcessingMode(absl::string_view s, ProcessingMode* mode) {
+  if (s == kParallelEpochs) {
+    *mode = ProcessingMode::PARALLEL_EPOCHS;
+  } else if (s == kOneEpoch) {
+    *mode = ProcessingMode::ONE_EPOCH;
+  } else {
+    return errors::InvalidArgument("Unrecognized processing mode: ", s);
+  }
+  return Status::OK();
+}
+
+std::string ProcessingModeToString(ProcessingMode mode) {
+  switch (mode) {
+    case ProcessingMode::PARALLEL_EPOCHS:
+      return kParallelEpochs;
+    case ProcessingMode::ONE_EPOCH:
+      return kOneEpoch;
+    default:
+      DCHECK(false);
+      return "Unknown";
+  }
+}
+
 Status DataServiceMasterClient::CreateJob(int64 dataset_id,
                                           ProcessingMode processing_mode,
                                           int64* job_id) {
@@ -33,6 +33,13 @@ enum class ProcessingMode : int64 {
   ONE_EPOCH = 1,
 };
 
+// Parses a string representing a processing mode and stores the result in
+// *mode. Returns an InvalidArgument status if the string is not recognized.
+Status ParseProcessingMode(absl::string_view s, ProcessingMode* mode);
+
+// Converts a processing mode to its corresponding string.
+std::string ProcessingModeToString(ProcessingMode mode);
+
 // Base class for data service clients. Data service clients are
 // thread-compatible, requiring external synchronization when used from multiple
 // threads.
@@ -38,6 +38,30 @@ namespace data {
 namespace {
 constexpr const char kProtocol[] = "grpc+local";
 
+TEST(DataService, ParseParallelEpochsProcessingMode) {
+  ProcessingMode mode;
+  TF_ASSERT_OK(ParseProcessingMode("parallel_epochs", &mode));
+  EXPECT_EQ(mode, ProcessingMode::PARALLEL_EPOCHS);
+}
+
+TEST(DataService, ParseOneEpochProcessingMode) {
+  ProcessingMode mode;
+  TF_ASSERT_OK(ParseProcessingMode("one_epoch", &mode));
+  EXPECT_EQ(mode, ProcessingMode::ONE_EPOCH);
+}
+
+TEST(DataService, ParseInvalidProcessingMode) {
+  ProcessingMode mode;
+  Status s = ParseProcessingMode("invalid", &mode);
+  EXPECT_EQ(s.code(), error::Code::INVALID_ARGUMENT);
+}
+
+TEST(DataService, ProcessingModeToString) {
+  EXPECT_EQ("parallel_epochs",
+            ProcessingModeToString(ProcessingMode::PARALLEL_EPOCHS));
+  EXPECT_EQ("one_epoch", ProcessingModeToString(ProcessingMode::ONE_EPOCH));
+}
+
 Status CheckWorkerOutput(const std::string& worker_address, int64 task_id,
                          std::vector<std::vector<Tensor>> expected_output) {
   DataServiceWorkerClient worker(worker_address, kProtocol);
@@ -117,11 +117,13 @@ Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request,
   }
   std::unique_ptr<standalone::Iterator>& iter = it->second.iterator;
   if (iter == nullptr) {
+    VLOG(3) << "Task " << request->task_id() << " is already finished";
     response->set_end_of_sequence(true);
     return Status::OK();
   }
   TF_RETURN_IF_ERROR(iter->GetNext(&outputs, &end_of_sequence));
   if (end_of_sequence) {
+    VLOG(3) << "Reached end_of_sequence for task " << request->task_id();
     // Release iterator memory and leave a null entry as a tombstone.
     iter.reset();
     pending_completed_tasks_.push_back(request->task_id());
@@ -130,6 +132,7 @@ Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request,
   }
 
   if (!end_of_sequence) {
+    VLOG(3) << "Producing an element for task " << request->task_id();
    TF_RETURN_IF_ERROR(service_util::Compress(
        outputs, response->mutable_compressed_element()));
   }
@@ -292,36 +292,6 @@ class Runner {
   static Runner* get();
 };
 
-// A token for reading from a tf.data service job.
-class JobToken {
- public:
-  JobToken() : is_empty_(true) {}
-
-  explicit JobToken(int64 job_id) : job_id_(job_id), is_empty_(false) {}
-
-  bool is_empty() const { return is_empty_; }
-  int64 job_id() const { return job_id_; }
-  string TypeName() const { return "tensorflow::JobToken"; }
-  void Encode(VariantTensorData* data) const {
-    Tensor job_id = Tensor(DT_INT64, TensorShape({}));
-    job_id.scalar<int64>()() = job_id_;
-    *(data->add_tensors()) = job_id;
-
-    Tensor is_empty = Tensor(DT_BOOL, TensorShape({}));
-    is_empty.scalar<bool>()() = is_empty_;
-    *(data->add_tensors()) = is_empty;
-  }
-  bool Decode(const VariantTensorData& data) {
-    job_id_ = data.tensors(0).scalar<int64>()();
-    is_empty_ = data.tensors(1).scalar<bool>()();
-    return true;
-  }
-
- private:
-  int64 job_id_;
-  bool is_empty_;
-};
-
 // A cut-down version of `OpKernelContext` for running computations in
 // iterators. Note that we cannot simply use `OpKernelContext` here because we
 // might run computation in an iterator whose lifetime is not nested within the
@@ -342,7 +312,6 @@ class IteratorContext {
           env(ctx->env()),
           flr(ctx->flr()),
           function_handle_cache(ctx->function_handle_cache()),
-          job_token(ctx->job_token()),
           resource_mgr(ctx->resource_mgr()),
           model(ctx->model()),
           runner(*(ctx->runner())),
@@ -401,9 +370,6 @@ class IteratorContext {
     // A FunctionHandleCache that owns all the function handles. Not owned.
     FunctionHandleCache* function_handle_cache = nullptr;
 
-    // A token for reading data from a tf.data service job.
-    JobToken job_token;
-
     // A resource manager for storing dataset-related state, e.g. random
     // seeds or cached tensors. Not owned.
     ResourceMgr* resource_mgr = nullptr;
@@ -453,8 +419,6 @@ class IteratorContext {
     return params_.function_handle_cache;
   }
 
-  const JobToken& job_token() { return params_.job_token; }
-
   ResourceMgr* resource_mgr() { return params_.resource_mgr; }
 
   const std::shared_ptr<model::Model>& model() { return params_.model; }
@@ -91,27 +91,4 @@ INSTANTIATE_TEST_SUITE_P(
         {_tf_string_, tensor_strs,
          static_cast<int64>(sizeof(str) + str.size()) /*bytes*/}}));
 
-TEST(DatasetTest, JobServiceTokenIsEmpty) {
-  data::JobToken token;
-  EXPECT_TRUE(token.is_empty());
-}
-
-TEST(DatasetTest, JobTokenHoldsJobId) {
-  int64 job_id = 5;
-  data::JobToken token(job_id);
-  EXPECT_EQ(job_id, token.job_id());
-  EXPECT_FALSE(token.is_empty());
-}
-
-TEST(DatasetTest, JobTokenEncodeDecode) {
-  int64 job_id = 5;
-  data::JobToken token(job_id);
-  VariantTensorData data;
-  token.Encode(&data);
-  data::JobToken decoded;
-  decoded.Decode(data);
-  EXPECT_FALSE(token.is_empty());
-  EXPECT_EQ(job_id, token.job_id());
-}
-
 }  // namespace tensorflow
@@ -43,6 +43,8 @@ namespace tensorflow {
 namespace data {
 
 /* static */ constexpr const char* const DataServiceDatasetOp::kDatasetType;
+/* static */ constexpr const char* const DataServiceDatasetOp::kDatasetId;
+/* static */ constexpr const char* const DataServiceDatasetOp::kProcessingMode;
 /* static */ constexpr const char* const DataServiceDatasetOp::kAddress;
 /* static */ constexpr const char* const DataServiceDatasetOp::kProtocol;
 /* static */ constexpr const char* const
@@ -50,6 +52,7 @@ namespace data {
 /* static */ constexpr const char* const DataServiceDatasetOp::kOutputTypes;
 /* static */ constexpr const char* const DataServiceDatasetOp::kOutputShapes;
 
+namespace {
 // Once we've spent `kRetryTimeoutMicros` in `GetNextInternal`, we will wait for
 // the current attempt to complete and perform no more retries.
 const int64 kRetryTimeoutMicros = 1000LL * 1000 * 60 * 60;  // 60 minutes.
@@ -57,6 +60,8 @@ const int64 kRetryTimeoutMicros = 1000LL * 1000 * 60 * 60;  // 60 minutes.
 // Default interval between task list refreshes.
 const int64 kDefaultTaskRefreshIntervalMs = 1000;  // 1 second.
 
+}  // namespace
+
 // Dataset for reading data from the tf.data service non-deterministically.
 //
 // This dataset interleaves dataset elements produced by multiple tf.data
@@ -64,11 +69,14 @@ const int64 kDefaultTaskRefreshIntervalMs = 1000;  // 1 second.
 // to read from (in case workers are added or removed).
 class DataServiceDatasetOp::Dataset : public DatasetBase {
  public:
-  Dataset(OpKernelContext* ctx, const std::string& address,
+  Dataset(OpKernelContext* ctx, int64 dataset_id,
+          ProcessingMode processing_mode, const std::string& address,
          const std::string& protocol, int64 max_outstanding_requests,
          int64 task_refresh_interval_ms, const DataTypeVector& output_types,
          const std::vector<PartialTensorShape>& output_shapes)
       : DatasetBase(DatasetContext(ctx)),
+        dataset_id_(dataset_id),
+        processing_mode_(processing_mode),
         address_(address),
         protocol_(protocol),
         max_outstanding_requests_(max_outstanding_requests),
@@ -102,6 +110,13 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
   Status AsGraphDefInternal(SerializationContext* ctx,
                             DatasetGraphDefBuilder* b,
                             Node** output) const override {
+    Node* dataset_id;
+    TF_RETURN_IF_ERROR(b->AddScalar(dataset_id_, &dataset_id));
+
+    Node* processing_mode;
+    tstring processing_mode_str = ProcessingModeToString(processing_mode_);
+    TF_RETURN_IF_ERROR(b->AddScalar(processing_mode_str, &processing_mode));
+
     Node* address;
     TF_RETURN_IF_ERROR(b->AddScalar(address_, &address));
 
@@ -117,7 +132,9 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
         &task_refresh_interval_hint_ms);
 
     TF_RETURN_IF_ERROR(
-        b->AddDataset(this, {address, protocol, max_outstanding_requests},
+        b->AddDataset(this,
+                      {dataset_id, processing_mode, address, protocol,
+                       max_outstanding_requests},
                       {std::make_pair(kTaskRefreshIntervalHintMs,
                                       task_refresh_interval_hint_ms)},
                       output));
@@ -132,6 +149,8 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
 
     ~Iterator() override {
       mutex_lock l(mu_);
+      VLOG(1) << "Destroying data service dataset iterator for job id "
+              << job_id_;
      cancelled_ = true;
      cv_.notify_all();
      // Thread destructors will block until the threads finish, no need to wait
@@ -141,14 +160,10 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
     Status Initialize(IteratorContext* ctx) override {
       VLOG(3) << "Connecting to " << dataset()->address_
               << " in data service dataset op";
-      if (ctx->job_token().is_empty()) {
-        return errors::FailedPrecondition(
-            "Expected a job token, but none found. To iterate over a dataset "
-            "containing a `distribute` transformation, call `create_job`, "
-            "which will return a job token that you should then use to iterate "
-            "over the dataset via `create_iterator(dataset, job_token).`");
-      }
-      job_id_ = ctx->job_token().job_id();
+      DataServiceMasterClient master(dataset()->address_, dataset()->protocol_);
+      TF_RETURN_IF_ERROR(master.CreateJob(
+          dataset()->dataset_id_, dataset()->processing_mode_, &job_id_));
+      VLOG(1) << "Created data service job with id " << job_id_;
       return Status::OK();
     }
 
@@ -175,6 +190,7 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
         return Status::OK();
       }
       DCHECK(!results_.empty());
+      *end_of_sequence = false;
      out_tensors->swap(results_.front());
      results_.pop();
      cv_.notify_all();
@@ -287,6 +303,8 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
           continue;
         }
         if (!task_ids.contains(task_thread->task_id)) {
+          VLOG(3) << "Marking removed task thread " << task_thread->task_id
+                  << " as finished";
           task_thread->end_of_sequence = true;
         }
         ++it;
@@ -315,6 +333,8 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
       {
         mutex_lock l(mu_);
         if (task_handler->end_of_sequence) {
+          VLOG(3) << "Task thread " << task_handler->task_id
+                  << " reached end_of_sequence";
           return;
         }
         outstanding_requests_--;
@@ -427,6 +447,8 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
     std::unique_ptr<Thread> task_thread_manager_ GUARDED_BY(mu_);
   };
 
+  const int64 dataset_id_;
+  const ProcessingMode processing_mode_;
   const tstring address_;
   const tstring protocol_;
   const int64 max_outstanding_requests_;
@@ -448,6 +470,16 @@ DataServiceDatasetOp::DataServiceDatasetOp(OpKernelConstruction* ctx)
 
 void DataServiceDatasetOp::MakeDataset(OpKernelContext* ctx,
                                        DatasetBase** output) {
+  int64 dataset_id;
+  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kDatasetId, &dataset_id));
+
+  tstring processing_mode_str;
+  OP_REQUIRES_OK(
+      ctx, ParseScalarArgument(ctx, kProcessingMode, &processing_mode_str));
+  ProcessingMode processing_mode;
+  OP_REQUIRES_OK(ctx,
+                 ParseProcessingMode(processing_mode_str, &processing_mode));
+
   tstring address;
   OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kAddress, &address));
   OP_REQUIRES(ctx, !address.empty(),
@@ -468,9 +500,10 @@ void DataServiceDatasetOp::MakeDataset(OpKernelContext* ctx,
       errors::InvalidArgument(kMaxOutstandingRequests, " must be positive or ",
                               model::kAutotune));
 
-  *output = new Dataset(ctx, address, protocol, max_outstanding_requests,
-                        task_refresh_interval_hint_ms_, output_types_,
-                        output_shapes_);
+  *output =
+      new Dataset(ctx, dataset_id, processing_mode, address, protocol,
+                  max_outstanding_requests, task_refresh_interval_hint_ms_,
+                  output_types_, output_shapes_);
 }
 
 REGISTER_KERNEL_BUILDER(Name("DataServiceDataset").Device(DEVICE_CPU),
@@ -24,6 +24,8 @@ namespace data {
 class DataServiceDatasetOp : public DatasetOpKernel {
  public:
   static constexpr const char* const kDatasetType = "DataService";
+  static constexpr const char* const kDatasetId = "dataset_id";
+  static constexpr const char* const kProcessingMode = "processing_mode";
   static constexpr const char* const kAddress = "address";
   static constexpr const char* const kProtocol = "protocol";
   static constexpr const char* const kMaxOutstandingRequests =
@@ -22,18 +22,6 @@ limitations under the License.
 
 namespace tensorflow {
 namespace data {
-namespace {
-Status ParseProcessingMode(const tstring& s, ProcessingMode* mode) {
-  if (s == "parallel_epochs") {
-    *mode = ProcessingMode::PARALLEL_EPOCHS;
-  } else if (s == "one_epoch") {
-    *mode = ProcessingMode::ONE_EPOCH;
-  } else {
-    return errors::InvalidArgument("Unrecognized processing mode: ", s);
-  }
-  return Status::OK();
-}
-}  // namespace
 
 RegisterDatasetOp::RegisterDatasetOp(OpKernelConstruction* ctx)
     : OpKernel(ctx) {
@@ -75,62 +63,8 @@ void RegisterDatasetOp::Compute(OpKernelContext* ctx) {
   output_dataset_id() = dataset_id;
 }
 
-CreateJobOp::CreateJobOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
-
-void CreateJobOp::Compute(OpKernelContext* ctx) {
-  int64 dataset_id;
-  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kDatasetId, &dataset_id));
-
-  tstring address;
-  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kAddress, &address));
-  OP_REQUIRES(ctx, !address.empty(),
-              errors::InvalidArgument(kAddress, " must be non-empty."));
-
-  tstring protocol;
-  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kProtocol, &protocol));
-  OP_REQUIRES(ctx, !protocol.empty(),
-              errors::InvalidArgument(kProtocol, " must be non-empty."));
-
-  tstring processing_mode_str;
-  OP_REQUIRES_OK(
-      ctx, ParseScalarArgument(ctx, kProcessingMode, &processing_mode_str));
-  ProcessingMode processing_mode;
-  OP_REQUIRES_OK(ctx,
-                 ParseProcessingMode(processing_mode_str, &processing_mode));
-
-  DataServiceMasterClient client(address, protocol);
-  int64 job_id;
-  OP_REQUIRES_OK(ctx, client.CreateJob(dataset_id, processing_mode, &job_id));
-
-  JobToken token(job_id);
-  Tensor* output;
-  OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape{}, &output));
-  auto output_token = output->tensor<Variant, 0>();
-  output_token() = token;
-}
-
-Status MakeDataServiceIteratorOp::DoCompute(OpKernelContext* ctx) {
-  DatasetBase* dataset;
-  TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset));
-
-  const Tensor* token_tensor;
-  TF_RETURN_IF_ERROR(ctx->input(kJobToken, &token_tensor));
-  JobToken token = *token_tensor->scalar<Variant>()().get<JobToken>();
-
-  IteratorResource* iterator_resource;
-  TF_RETURN_IF_ERROR(
-      LookupResource(ctx, HandleFromInput(ctx, 2), &iterator_resource));
-
-  core::ScopedUnref unref_iterator(iterator_resource);
-
-  return iterator_resource->SetIteratorFromDataset(ctx, dataset, token);
-}
-
 REGISTER_KERNEL_BUILDER(Name("RegisterDataset").Device(DEVICE_CPU),
                         RegisterDatasetOp);
-REGISTER_KERNEL_BUILDER(Name("CreateJob").Device(DEVICE_CPU), CreateJobOp);
-REGISTER_KERNEL_BUILDER(Name("MakeDataServiceIterator").Device(DEVICE_CPU),
-                        MakeDataServiceIteratorOp);
 
 }  // namespace data
 }  // namespace tensorflow
@@ -44,41 +44,6 @@ class RegisterDatasetOp : public OpKernel {
   SerializationContext::ExternalStatePolicy external_state_policy_;
 };
 
-// Creates a token for reading from the tf.data service.
-//
-// The dataset_id input identifies which dataset to create a token for.
-// The address and protocol inputs are used to connect to the tf.data service
-// master.
-// The processing_mode defines how the tf.data service should produce data for
-// the token.
-class CreateJobOp : public OpKernel {
- public:
-  static constexpr const char* const kDatasetId = "dataset_id";
-  static constexpr const char* const kAddress = "address";
-  static constexpr const char* const kProtocol = "protocol";
-  static constexpr const char* const kProcessingMode = "processing_mode";
-
-  explicit CreateJobOp(OpKernelConstruction* ctx);
-
-  void Compute(OpKernelContext* ctx) override;
-};
-
-// Creates a new iterator for iterating over a tf.data service dataset.
-//
-// The epoch_id input identifies which epoch to read from. Multiple iterators
-// may read from the same epoch, causing the elements of the epoch to be split
-// across all iterators.
-class MakeDataServiceIteratorOp : public MakeIteratorOp {
- public:
-  static constexpr const char* const kJobToken = "job_token";
-
-  explicit MakeDataServiceIteratorOp(OpKernelConstruction* ctx)
-      : MakeIteratorOp(ctx) {}
-
- protected:
-  Status DoCompute(OpKernelContext* ctx) override;
-};
-
 }  // namespace data
 }  // namespace tensorflow
 #endif  // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_DATA_SERVICE_OPS_H_
@@ -166,8 +166,7 @@ Status IteratorResource::Restore(OpKernelContext* ctx,
 }
 
 Status IteratorResource::SetIteratorFromDataset(OpKernelContext* ctx,
-                                                DatasetBase* dataset,
-                                                JobToken job_token) {
+                                                DatasetBase* dataset) {
   std::shared_ptr<State> new_state;
   {
     tf_shared_lock l(mu_);
@@ -180,7 +179,6 @@ Status IteratorResource::SetIteratorFromDataset(OpKernelContext* ctx,
   IteratorContext::Params params(ctx);
   params.flr = new_state->flr;
   params.function_handle_cache = new_state->function_handle_cache.get();
-  params.job_token = job_token;
   params.resource_mgr = &new_state->resource_mgr;
   params.thread_factory = unbounded_thread_pool_.get_thread_factory();
   params.thread_pool = &unbounded_thread_pool_;
@@ -532,9 +530,7 @@ Status MakeIteratorOp::DoCompute(OpKernelContext* ctx) {
   TF_RETURN_IF_ERROR(
       LookupResource(ctx, HandleFromInput(ctx, 1), &iterator_resource));
   core::ScopedUnref unref_iterator(iterator_resource);
-  JobToken empty_token;
-  return iterator_resource->SetIteratorFromDataset(ctx, dataset,
-                                                   /*job_token=*/empty_token);
+  return iterator_resource->SetIteratorFromDataset(ctx, dataset);
 }
 
 void DeleteIteratorOp::Compute(OpKernelContext* ctx) {
@@ -855,9 +851,7 @@ class OneShotIteratorOp : public AsyncOpKernel {
     // factory function.
     DatasetBase* dataset;
     TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(return_values[0], &dataset));
-    JobToken empty_token;
-    TF_RETURN_IF_ERROR((*iterator)->SetIteratorFromDataset(
-        ctx, dataset, /*job_token=*/empty_token));
+    TF_RETURN_IF_ERROR((*iterator)->SetIteratorFromDataset(ctx, dataset));
     (*iterator)->Ref();
     return Status::OK();
   }
@@ -69,14 +69,9 @@ class IteratorResource : public ResourceBase {
   // Creates an iterator for `dataset`, and associates the iterator with this
   // iterator resource.
   //
-  // The `job_token` will be passed through the IteratorContext when
-  // creating the iterator. This token is used to read from a tf.data service
-  // job.
-  //
   // `SetIteratorFromDataset` should be called before calling `GetNext`, `Save`,
   // or `Restore`.
-  Status SetIteratorFromDataset(OpKernelContext* ctx, DatasetBase* dataset,
-                                JobToken job_token);
+  Status SetIteratorFromDataset(OpKernelContext* ctx, DatasetBase* dataset);
 
   string DebugString() const override { return "Iterator resource"; }
 
@@ -1,23 +0,0 @@
-op {
-  name: "CreateJob"
-  input_arg {
-    name: "dataset_id"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "address"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "protocol"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "processing_mode"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "job_token"
-    type: DT_VARIANT
-  }
-}
@@ -1,16 +0,0 @@
-op {
-  name: "MakeDataServiceIterator"
-  input_arg {
-    name: "dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "job_token"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "iterator"
-    type: DT_RESOURCE
-  }
-  is_stateful: true
-}
@@ -1,23 +0,0 @@
-op {
-  name: "CreateJob"
-  input_arg {
-    name: "dataset_id"
-    type: DT_INT64
-  }
-  input_arg {
-    name: "address"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "protocol"
-    type: DT_STRING
-  }
-  input_arg {
-    name: "processing_mode"
-    type: DT_STRING
-  }
-  output_arg {
-    name: "job_token"
-    type: DT_VARIANT
-  }
-}
@@ -1,16 +0,0 @@
-op {
-  name: "MakeDataServiceIterator"
-  input_arg {
-    name: "dataset"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "job_token"
-    type: DT_VARIANT
-  }
-  input_arg {
-    name: "iterator"
-    type: DT_RESOURCE
-  }
-  is_stateful: true
-}
@@ -1043,6 +1043,8 @@ REGISTER_OP("ExperimentalUniqueDataset")
     .SetShapeFn(shape_inference::ScalarShape);
 
 REGISTER_OP("DataServiceDataset")
+    .Input("dataset_id: int64")
+    .Input("processing_mode: string")
     .Input("address: string")
     .Input("protocol: string")
     .Input("max_outstanding_requests: int64")
@@ -1061,18 +1063,4 @@ REGISTER_OP("RegisterDataset")
     .Attr("external_state_policy: int")
     .SetShapeFn(shape_inference::ScalarShape);
 
-REGISTER_OP("CreateJob")
-    .Input("dataset_id: int64")
-    .Input("address: string")
-    .Input("protocol: string")
-    .Input("processing_mode: string")
-    .Output("job_token: variant")
-    .SetShapeFn(shape_inference::ScalarShape);
-
-REGISTER_OP("MakeDataServiceIterator")
-    .Input("dataset: variant")
-    .Input("job_token: variant")
-    .Input("iterator: resource")
-    .SetShapeFn(shape_inference::NoOutputs);
-
 }  // namespace tensorflow
@@ -24,8 +24,6 @@ import six
 from tensorflow.python import tf2
 from tensorflow.python.data.experimental.ops.distribute_options import ExternalStatePolicy
 from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.ops import iterator_ops
-from tensorflow.python.eager import context
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import gen_experimental_dataset_ops
@@ -36,10 +34,10 @@ class ProcessingMode(object):
 
   @staticmethod
   def validate(mode):
-    """Raises a TypeError if the given object is not a valid processing mode."""
+    """Raises a ValueError if the given object is not a valid processing mode."""
     valid_modes = [ProcessingMode.PARALLEL_EPOCHS]
     if mode not in valid_modes:
-      raise TypeError(
+      raise ValueError(
           "{0} is not a valid processing mode. Valid modes: {1}".format(
               mode, valid_modes))
 
@@ -50,6 +48,7 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource):
   def __init__(self,
                input_dataset,
                dataset_id,
+               processing_mode,
                address,
                protocol,
                max_outstanding_requests=None,
@@ -60,6 +59,9 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource):
       input_dataset: The input dataset, which should be registered with the
        tf.data service under `dataset_id`.
      dataset_id: The dataset id for the dataset to read from.
+      processing_mode: A string specifying the policy for how data should be
+        processed by tf.data workers. Currently, the only supported value is
+        "parallel_epochs".
      address: The tf.data service address, e.g. "localhost:5000".
      protocol: The protocol to use for communicating with the tf.data service,
        e.g. "grpc".
@@ -77,13 +79,10 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource):
       task_refresh_interval_hint_ms = dataset_ops.AUTOTUNE
 
     self._element_spec = input_dataset.element_spec
-    self._dataset_id = dataset_id
-    self._address = address
-    self._protocol = protocol
-    self._max_outstanding_requests = max_outstanding_requests
-    self._task_refresh_interval_hint_ms = task_refresh_interval_hint_ms
 
     variant_tensor = gen_experimental_dataset_ops.data_service_dataset(
+        dataset_id=dataset_id,
+        processing_mode=processing_mode,
         address=address,
         protocol=protocol,
         max_outstanding_requests=max_outstanding_requests,
@@ -91,18 +90,6 @@ class _DataServiceDatasetV2(dataset_ops.DatasetSource):
         **self._flat_structure)
     super(_DataServiceDatasetV2, self).__init__(variant_tensor)
 
-  @property
-  def dataset_id(self):
-    return self._dataset_id
-
-  @property
-  def address(self):
-    return self._address
-
-  @property
-  def protocol(self):
-    return self._protocol
-
   @property
   def element_spec(self):
     return self._element_spec
@@ -112,30 +99,20 @@ class _DataServiceDatasetV1(dataset_ops.DatasetV1Adapter):
   """A `Dataset` that executes its input through the tf.data service."""
 
   @functools.wraps(_DataServiceDatasetV2.__init__)
-  def __init__(self, input_dataset, dataset_id, address, protocol,
-               max_outstanding_requests, task_refresh_interval_hint_ms):
+  def __init__(self, input_dataset, dataset_id, processing_mode, address,
+               protocol, max_outstanding_requests,
+               task_refresh_interval_hint_ms):
 
     self._wrapped = _DataServiceDatasetV2(
         input_dataset=input_dataset,
        dataset_id=dataset_id,
+        processing_mode=processing_mode,
        address=address,
        protocol=protocol,
        max_outstanding_requests=max_outstanding_requests,
        task_refresh_interval_hint_ms=task_refresh_interval_hint_ms)
     super(_DataServiceDatasetV1, self).__init__(self._wrapped)
 
-  @property
-  def dataset_id(self):
-    return self._wrapped.dataset_id
-
-  @property
-  def address(self):
-    return self._wrapped.address
-
-  @property
-  def protocol(self):
-    return self._wrapped.protocol
-
 
 if tf2.enabled():
   _DataServiceDataset = _DataServiceDatasetV2
@@ -143,7 +120,8 @@ else:
   _DataServiceDataset = _DataServiceDatasetV1
 
 
-def _distribute(service,
+def _distribute(processing_mode,
+                service,
                 max_outstanding_requests=None,
                 task_refresh_interval_hint_ms=None):
   """A transformation that moves dataset processing to the tf.data service.
@@ -152,6 +130,9 @@ def _distribute(service,
   parameters which we do not yet want to add to the public Python API.
 
   Args:
+    processing_mode: A string specifying the policy for how data should be
+      processed by tf.data workers. Currently, the only supported value is
+      "parallel_epochs".
     service: A string indicating how to connect to the tf.data service. The
       string should be in the format <protocol>://<address>, e.g.
      grpc://localhost:5000.
@@ -165,6 +146,7 @@ def _distribute(service,
   Returns:
     Dataset: A `Dataset` of the elements produced by the data service.
   """
+  ProcessingMode.validate(processing_mode)
   if not isinstance(service, six.string_types):
     raise ValueError(
         "service must be a string, but service was of type {0}. service={1}"
@@ -197,6 +179,7 @@ def _distribute(service,
     return _DataServiceDataset(
         input_dataset=dataset,
         dataset_id=dataset_id,
+        processing_mode=processing_mode,
        address=address,
        protocol=protocol,
        max_outstanding_requests=max_outstanding_requests,
@@ -205,52 +188,9 @@ def _distribute(service,
   return _apply_fn
 
 
-def distribute(service, max_outstanding_requests=None):
+def distribute(processing_mode, service, max_outstanding_requests=None):
   """A transformation that moves dataset processing to the tf.data service.
 
-  ```
-  dataset = tf.data.Dataset.range(10)
-  dataset = dataset.map(lambda x: x*x)
-  dataset = dataset.apply(
-      tf.data.experimental.service.distribute("grpc://dataservice:5000"))
-  dataset = dataset.map(lambda x: x+10)
-
-  job_token = tf.data.experimental.service.create_job(dataset)
-  it = tf.data.experimental.service.create_iterator(dataset, job_token)
-  for element in it:
-    # process element
-  ```
-
-  In the above example, the first two lines (before the call to `distribute`)
-  will be executed on tf.data workers, and the elements provided over
-  RPC. The remaining transformations (after the call to `distribute`) will be
-  executed locally.
-
-  The token returned from `create_job` may be used to create multiple
-  coordinated iterators which consume data from the same job.
-
-  Args:
-    service: A string indicating how to connect to the tf.data service. The
-      string should be in the format <protocol>://<address>, e.g.
-      grpc://localhost:5000.
-    max_outstanding_requests: (Optional.) A limit on how many elements may be
-      requested at the same time. You can use this option to control the amount
-      of memory used, since `distribute` won't use more than `element_size` *
-      `max_outstanding_requests` of memory.
-
-  Returns:
-    Dataset: A `Dataset` of the elements produced by the data service.
-  """
-  return _distribute(service, max_outstanding_requests)
-
-
-def create_job(dataset, processing_mode):
-  """Creates a job for reading a dataset through the tf.data service.
-
-  The returned token can be used to create iterators for consuming data from
-  the job. `processing_mode` controls what data will be produced. Iterators
-  created from the same token will consume from the same job.
-
   The `processing_mode` argument controls how data is processed by the
   tf.data service. Currently, the only supported mode is "parallel_epochs".
 
@@ -263,77 +203,40 @@ def create_job(dataset, processing_mode):
   to randomly shuffle your dataset, so that different tf.data workers will
   iterate through the dataset in different orders.
 
-  In the future, we plan to add additional epoch modes. For example, we will add
+  In the future, there will be additional epoch modes. For example,
   a "one_epoch" mode which partitions the dataset across the tf.data
   workers, so that the consumers see each element of the dataset only once.
 
+  ```
+  dataset = tf.data.Dataset.range(10)
+  dataset = dataset.map(lambda x: x*x)
+  dataset = dataset.apply(
+      tf.data.experimental.service.distribute("parallel_epochs",
+                                              "grpc://dataservice:5000"))
+  dataset = dataset.map(lambda x: x+10)
+
+  for element in dataset:
+    # process element
+  ```
+
+  In the above example, the first two lines (before the call to `distribute`)
+  will be executed on tf.data workers, and the elements provided over
+  RPC. The remaining transformations (after the call to `distribute`) will be
+  executed locally.
+
   Args:
-    dataset: A `tf.data.Dataset` to create a job for. The dataset must contain a
-      single `distribute` transformation.
     processing_mode: A string specifying the policy for how data should be
       processed by tf.data workers. Currently, the only supported value is
      "parallel_epochs".
+    service: A string indicating how to connect to the tf.data service. The
+      string should be in the format <protocol>://<address>, e.g.
+      grpc://localhost:5000.
+    max_outstanding_requests: (Optional.) A limit on how many elements may be
+      requested at the same time. You can use this option to control the amount
+      of memory used, since `distribute` won't use more than `element_size` *
+      `max_outstanding_requests` of memory.
 
   Returns:
-    A token for reading from the created tf.data service job. To read using the
-      token, call `create_iterator(dataset, token)`
-
-  Raises:
-    ValueError: If the dataset contains no calls to `distribute` or more than 1
-      call to `distribute`.
+    Dataset: A `Dataset` of the elements produced by the data service.
   """
-  datasets = _find_data_service_datasets(dataset)
-  if len(datasets) > 1:
-    raise ValueError(
-        "Datasets containing multiple calls to .distribute(...) are " +
-        "not supported")
-  if not datasets:
-    raise ValueError(
-        "Dataset does not contain any distribute() transformations")
-  ProcessingMode.validate(processing_mode)
-  data_service_dataset = datasets[0]
-  return gen_experimental_dataset_ops.create_job(
-      data_service_dataset.dataset_id, data_service_dataset.address,
-      data_service_dataset.protocol, processing_mode)
-
-
-def create_iterator(dataset, job_token):
-  """Creates an iterator for reading from the tf.data service.
-
-  Args:
-    dataset: A `tf.data.Dataset` object.
-    job_token: A token generated by `create_job`.
-
-  Returns:
-    A dataset iterator.
-
-  Raises:
-    RuntimeError: If called outside of a function in graph mode.
-  """
-  if context.executing_eagerly() or ops.inside_function():
-    return iterator_ops.OwnedIterator(dataset, job_token=job_token)
-  else:
-    raise RuntimeError("create_iterator() is only supported inside of "
-                       "tf.function or when eager execution is enabled.")
-
-
-def _find_data_service_datasets(dataset):
-  """Produces a list of all data service datasets in the given dataset.
-
-  Args:
-    dataset: A `tf.data.Dataset`.
-
-  Returns:
-    A list of all data service datasets.
-  """
-  result = []
-  to_check = [dataset]
-  while to_check:
-    d = to_check.pop()
-    if isinstance(d, dataset_ops.DatasetV1Adapter):
-      d = d._dataset  # pylint: disable=protected-access
-    if isinstance(d, _DataServiceDatasetV1) or isinstance(
-        d, _DataServiceDatasetV2):
-      result.append(d)
-    to_check.extend(d._inputs())  # pylint: disable=protected-access
-  return result
+  return _distribute(processing_mode, service, max_outstanding_requests)
@@ -40,7 +40,8 @@ PROTOCOL = "grpc"
 def _make_distributed_dataset(dataset, service):
   """Creates a distributed dataset with a short task refresh interval."""
   return dataset.apply(
-      data_service_ops._distribute(service, task_refresh_interval_hint_ms=20))
+      data_service_ops._distribute(
+          "parallel_epochs", service, task_refresh_interval_hint_ms=20))
 
 
 class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
@@ -65,27 +66,35 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
 
     return self._master.target
 
-  @combinations.generate(test_base.eager_only_combinations())
-  def testMultipleEpochs(self):
-    service = self.create_cluster(1)
-    ds = dataset_ops.Dataset.range(3)
-    ds = _make_distributed_dataset(ds, service)
-    for _ in range(10):
-      token = data_service_ops.create_job(ds, processing_mode="parallel_epochs")
-      it = data_service_ops.create_iterator(ds, token)
-      self.assertEqual(list(range(3)), [t.numpy() for t in it])
-
   @combinations.generate(test_base.eager_only_combinations())
   def testDistributeBasic(self):
     num_elements = 10
     service = self.create_cluster(1)
     ds = dataset_ops.Dataset.range(num_elements)
     ds = _make_distributed_dataset(ds, service)
-    token = data_service_ops.create_job(ds, processing_mode="parallel_epochs")
-    it = data_service_ops.create_iterator(ds, token)
-    results = [t.numpy() for t in it]
+    results = [elem.numpy() for elem in ds]
     self.assertEqual(list(range(num_elements)), results)
 
+  @combinations.generate(test_base.eager_only_combinations())
+  def testMultipleEpochs(self):
+    num_elements = 3
+    service = self.create_cluster(1)
+    ds = dataset_ops.Dataset.range(num_elements)
+    ds = _make_distributed_dataset(ds, service)
+    for _ in range(10):
+      self.assertEqual(list(range(num_elements)), [elem.numpy() for elem in ds])
+
+  @combinations.generate(test_base.eager_only_combinations())
+  def testRepeatedDataset(self):
+    num_elements = 10
+    num_repetitions = 5
+    service = self.create_cluster(1)
+    ds = dataset_ops.Dataset.range(num_elements)
+    ds = _make_distributed_dataset(ds, service)
+    ds = ds.repeat(num_repetitions)
+    self.assertDatasetProduces(
+        ds, expected_output=num_repetitions * list(range(num_elements)))
+
   @combinations.generate(test_base.eager_only_combinations())
   def testConcurrentEpoch(self):
     num_elements = 10
@@ -96,9 +105,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
     for _ in range(num_datasets):
       ds = dataset_ops.Dataset.range(num_elements)
       ds = _make_distributed_dataset(ds, service)
-      token = data_service_ops.create_job(ds, processing_mode="parallel_epochs")
-      it = data_service_ops.create_iterator(ds, token)
-      iterators.append(it)
+      iterators.append(iter(ds))
       results.append([])
 
     for _ in range(num_elements):
@@ -110,6 +117,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
 
   @combinations.generate(test_base.eager_only_combinations())
   def testSharedEpoch(self):
+    self.skipTest("Not yet implemented")
     num_elements = 10
     num_iterators = 3
     service = self.create_cluster(1)
@@ -117,9 +125,8 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
     ds = _make_distributed_dataset(ds, service)
     result = []
     iterators = []
-    token = data_service_ops.create_job(ds, processing_mode="parallel_epochs")
     for _ in range(num_iterators):
-      iterators.append(data_service_ops.create_iterator(ds, token))
+      iterators.append(iter(ds))
 
     # Alternate reading between the iterators.
     for _ in range(2):
@@ -140,9 +147,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
     service = self.create_cluster(num_workers)
     ds = dataset_ops.Dataset.range(num_elements)
     ds = _make_distributed_dataset(ds, service)
-    token = data_service_ops.create_job(ds, processing_mode="parallel_epochs")
-    iterator = data_service_ops.create_iterator(ds, token)
-    results = [elem.numpy() for elem in iterator]
+    results = [elem.numpy() for elem in ds]
     self.assertCountEqual(num_workers * list(range(num_elements)), results)
 
   @combinations.generate(test_base.eager_only_combinations())
@@ -154,8 +159,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
     num_elements = 100
     ds = dataset_ops.Dataset.range(num_elements)
     ds = _make_distributed_dataset(ds, self._master.target)
-    token = data_service_ops.create_job(ds, processing_mode="parallel_epochs")
-    iterator = data_service_ops.create_iterator(ds, token)
+    iterator = iter(ds)
     results = []
     # Read halfway through the dataset.
     for _ in range(num_elements // 2):
@@ -184,8 +188,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
     num_elements = 100
     ds = dataset_ops.Dataset.range(num_elements)
     ds = _make_distributed_dataset(ds, self._master.target)
-    token = data_service_ops.create_job(ds, processing_mode="parallel_epochs")
token = data_service_ops.create_job(ds, processing_mode="parallel_epochs")
|
iterator = iter(ds)
|
||||||
iterator = data_service_ops.create_iterator(ds, token)
|
|
||||||
# Read halfway through the dataset.
|
# Read halfway through the dataset.
|
||||||
midpoint = num_elements // 2
|
midpoint = num_elements // 2
|
||||||
for i in range(midpoint):
|
for i in range(midpoint):
|
||||||
@ -219,12 +222,10 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
def f():
|
def f():
|
||||||
ds = dataset_ops.Dataset.range(num_elements)
|
ds = dataset_ops.Dataset.range(num_elements)
|
||||||
ds = _make_distributed_dataset(ds, service)
|
ds = _make_distributed_dataset(ds, service)
|
||||||
token = data_service_ops.create_job(ds, processing_mode="parallel_epochs")
|
|
||||||
it = data_service_ops.create_iterator(ds, token)
|
|
||||||
result = tensor_array_ops.TensorArray(
|
result = tensor_array_ops.TensorArray(
|
||||||
dtypes.int64, size=num_workers * num_elements, dynamic_size=True)
|
dtypes.int64, size=num_workers * num_elements, dynamic_size=True)
|
||||||
i = 0
|
i = 0
|
||||||
for elem in it:
|
for elem in ds:
|
||||||
result = result.write(i, elem)
|
result = result.write(i, elem)
|
||||||
i += 1
|
i += 1
|
||||||
return result.stack()
|
return result.stack()
|
||||||
@ -243,9 +244,7 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
|
|
||||||
service = self.create_cluster(3)
|
service = self.create_cluster(3)
|
||||||
ds = _make_distributed_dataset(ds, service)
|
ds = _make_distributed_dataset(ds, service)
|
||||||
token = data_service_ops.create_job(ds, processing_mode="parallel_epochs")
|
next(iter(ds))
|
||||||
iterator = data_service_ops.create_iterator(ds, token)
|
|
||||||
next(iterator)
|
|
||||||
|
|
||||||
@combinations.generate(
|
@combinations.generate(
|
||||||
combinations.times(
|
combinations.times(
|
||||||
@ -262,27 +261,6 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
with self.assertRaises(errors.FailedPreconditionError):
|
with self.assertRaises(errors.FailedPreconditionError):
|
||||||
self.run_stateful(distribute_options.ExternalStatePolicy.FAIL)
|
self.run_stateful(distribute_options.ExternalStatePolicy.FAIL)
|
||||||
|
|
||||||
@combinations.generate(test_base.eager_only_combinations())
|
|
||||||
def testNoDistributeCalls(self):
|
|
||||||
ds = dataset_ops.Dataset.range(1)
|
|
||||||
with self.assertRaisesWithLiteralMatch(
|
|
||||||
ValueError,
|
|
||||||
"Dataset does not contain any distribute() transformations"):
|
|
||||||
data_service_ops.create_job(ds, processing_mode="parallel_epochs")
|
|
||||||
|
|
||||||
@combinations.generate(test_base.eager_only_combinations())
|
|
||||||
def testMultipleDistributeCalls(self):
|
|
||||||
service = self.create_cluster(1)
|
|
||||||
ds1 = dataset_ops.Dataset.range(1)
|
|
||||||
ds1 = _make_distributed_dataset(ds1, service)
|
|
||||||
ds2 = dataset_ops.Dataset.range(1)
|
|
||||||
ds2 = _make_distributed_dataset(ds2, service)
|
|
||||||
ds = dataset_ops.Dataset.zip((ds1, ds2))
|
|
||||||
with self.assertRaisesWithLiteralMatch(
|
|
||||||
ValueError, "Datasets containing multiple calls to .distribute(...) "
|
|
||||||
"are not supported"):
|
|
||||||
data_service_ops.create_job(ds, processing_mode="parallel_epochs")
|
|
||||||
|
|
||||||
@combinations.generate(test_base.eager_only_combinations())
|
@combinations.generate(test_base.eager_only_combinations())
|
||||||
def testDistributeFromInterleave(self):
|
def testDistributeFromInterleave(self):
|
||||||
service = self.create_cluster(1)
|
service = self.create_cluster(1)
|
||||||
@ -302,14 +280,27 @@ class DataServiceOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
|
|||||||
def testDistributeNonStringAddresses(self):
|
def testDistributeNonStringAddresses(self):
|
||||||
ds = dataset_ops.Dataset.range(10)
|
ds = dataset_ops.Dataset.range(10)
|
||||||
with self.assertRaisesRegex(ValueError, "service must be a string"):
|
with self.assertRaisesRegex(ValueError, "service must be a string"):
|
||||||
ds = ds.apply(data_service_ops.distribute(service=1))
|
ds = ds.apply(
|
||||||
|
data_service_ops.distribute(
|
||||||
|
processing_mode="parallel_epochs", service=1))
|
||||||
|
|
||||||
@combinations.generate(test_base.eager_only_combinations())
|
@combinations.generate(test_base.eager_only_combinations())
|
||||||
def testDistributeEmptyAddress(self):
|
def testDistributeEmptyAddress(self):
|
||||||
ds = dataset_ops.Dataset.range(10)
|
ds = dataset_ops.Dataset.range(10)
|
||||||
with self.assertRaisesWithLiteralMatch(ValueError,
|
with self.assertRaisesWithLiteralMatch(ValueError,
|
||||||
"service must not be empty"):
|
"service must not be empty"):
|
||||||
ds = ds.apply(data_service_ops.distribute(service=""))
|
ds = ds.apply(
|
||||||
|
data_service_ops.distribute(
|
||||||
|
processing_mode="parallel_epochs", service=""))
|
||||||
|
|
||||||
|
@combinations.generate(test_base.eager_only_combinations())
|
||||||
|
def testDistributeInvalidProcessingMode(self):
|
||||||
|
ds = dataset_ops.Dataset.range(10)
|
||||||
|
with self.assertRaisesRegex(ValueError,
|
||||||
|
"invalid is not a valid processing mode"):
|
||||||
|
ds = ds.apply(
|
||||||
|
data_service_ops.distribute(
|
||||||
|
processing_mode="invalid", service="grpc://localhost:5000"))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
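The updated tests above iterate distributed datasets directly with `iter(ds)` and `for elem in ds` in place of the `data_service_ops.create_job(...)` / `create_iterator(...)` calls they previously used. For orientation, a minimal sketch of that pattern outside the test harness: the service address is a placeholder (the tests obtain a real one from `self.create_cluster(...)`), and the internal import paths are the conventional ones for these modules, assumed here rather than shown in this excerpt.

# Minimal sketch of the iteration pattern exercised by the updated tests
# (eager mode). The service address is a placeholder; a real tf.data service
# cluster must be running at that address for iteration to produce elements.
from tensorflow.python.data.experimental.ops import data_service_ops
from tensorflow.python.data.ops import dataset_ops

service = "grpc://localhost:5000"  # placeholder address, for illustration only

ds = dataset_ops.Dataset.range(10)
ds = ds.apply(
    data_service_ops.distribute(
        processing_mode="parallel_epochs", service=service))

# Plain Python iteration over the distributed dataset.
results = [elem.numpy() for elem in ds]

# An explicit iterator works the same way.
it = iter(ds)
first = next(it)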
@@ -840,10 +840,6 @@ tf_module {
     name: "CountUpTo"
     argspec: "args=[\'ref\', \'limit\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
   }
-  member_method {
-    name: "CreateJob"
-    argspec: "args=[\'dataset_id\', \'address\', \'protocol\', \'processing_mode\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
   member_method {
     name: "CreateSummaryDbWriter"
     argspec: "args=[\'writer\', \'db_uri\', \'experiment_name\', \'run_name\', \'user_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -938,7 +934,7 @@ tf_module {
   }
   member_method {
     name: "DataServiceDataset"
-    argspec: "args=[\'address\', \'protocol\', \'max_outstanding_requests\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
+    argspec: "args=[\'dataset_id\', \'processing_mode\', \'address\', \'protocol\', \'max_outstanding_requests\', \'output_types\', \'output_shapes\', \'task_refresh_interval_hint_ms\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
   }
   member_method {
     name: "DatasetCardinality"
@@ -2200,10 +2196,6 @@ tf_module {
     name: "Lu"
     argspec: "args=[\'input\', \'output_idx_type\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int32\'>\", \'None\'], "
   }
-  member_method {
-    name: "MakeDataServiceIterator"
-    argspec: "args=[\'dataset\', \'job_token\', \'iterator\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
-  }
   member_method {
     name: "MakeIterator"
     argspec: "args=[\'dataset\', \'iterator\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
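The golden entries above (which appear to be the raw-op API golden) show `DataServiceDataset` now taking `dataset_id` and `processing_mode` ahead of the address/protocol arguments. A hypothetical call against exactly that argspec might look like the sketch below; the dataset id and address are placeholders, and other versions of this op may take a different argument list.

# Illustrative only: placeholder dataset_id/address, argument list mirrors the
# argspec shown in the golden file above.
import tensorflow as tf

variant = tf.raw_ops.DataServiceDataset(
    dataset_id=0,
    processing_mode="parallel_epochs",
    address="localhost:5000",
    protocol="grpc",
    max_outstanding_requests=-1,
    output_types=[tf.int64],
    output_shapes=[tf.TensorShape([])],
    task_refresh_interval_hint_ms=-1)  # returns a variant tensor for the dataset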