[tf.data service] Implement dataset registration and iteration creation ops.

PiperOrigin-RevId: 304264765
Change-Id: Iaaa3ea3f8e125f287b67a985bef4d8f8fb658803
This commit is contained in:
Andrew Audibert 2020-04-01 14:37:53 -07:00 committed by TensorFlower Gardener
parent 4e6905a35a
commit c1f9a95117
8 changed files with 294 additions and 13 deletions

View File

@ -4,6 +4,7 @@ load(
"tf_additional_all_protos",
"tf_proto_library",
)
load("//tensorflow:tensorflow.bzl", "tf_grpc_cc_dependency")
load(
"//tensorflow:tensorflow.bzl",
"cc_header_only_library",
@ -59,7 +60,6 @@ cc_library(
":master_proto_cc",
":worker_cc_grpc_proto",
":worker_proto_cc",
"//tensorflow:grpc++",
"//tensorflow/c:c_api_internal",
"//tensorflow/c:tf_status_helper",
"//tensorflow/core:core_cpu",
@ -71,6 +71,7 @@ cc_library(
"//tensorflow/core/kernels/data:dataset_utils",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
tf_grpc_cc_dependency(),
],
)
@ -88,7 +89,6 @@ cc_library(
":master_cc_grpc_proto",
":master_proto_cc",
":worker_proto_cc",
"//tensorflow:grpc++",
"//tensorflow/c:c_api_internal",
"//tensorflow/c:tf_status_helper",
"//tensorflow/core:core_cpu",
@ -100,6 +100,7 @@ cc_library(
"//tensorflow/core/data:standalone",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
tf_grpc_cc_dependency(),
],
)
@ -110,8 +111,8 @@ cc_library(
"grpc_util.h",
],
deps = [
"//tensorflow:grpc++",
"//tensorflow/core:lib",
tf_grpc_cc_dependency(),
],
)
@ -165,9 +166,9 @@ cc_library(
srcs = ["credentials_factory.cc"],
hdrs = ["credentials_factory.h"],
deps = [
"//tensorflow:grpc++",
"//tensorflow/core:lib",
"@com_google_absl//absl/strings",
tf_grpc_cc_dependency(),
],
)
@ -189,7 +190,7 @@ cc_library(
srcs = ["local_credentials_factory.cc"],
deps = [
":credentials_factory",
"//tensorflow:grpc++",
tf_grpc_cc_dependency(),
],
alwayslink = 1,
)
@ -231,8 +232,8 @@ cc_library(
deps = [
":master_cc_grpc_proto",
":master_impl",
"//tensorflow:grpc++",
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
tf_grpc_cc_dependency(),
],
)
@ -243,8 +244,8 @@ cc_library(
deps = [
":worker_cc_grpc_proto",
":worker_impl",
"//tensorflow:grpc++",
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
tf_grpc_cc_dependency(),
],
)
@ -266,9 +267,9 @@ cc_library(
":credentials_factory",
":grpc_master_impl",
":grpc_worker_impl",
"//tensorflow:grpc++",
"//tensorflow/core:lib",
"//tensorflow/core:tensorflow",
tf_grpc_cc_dependency(),
],
)
@ -288,12 +289,12 @@ tf_cc_test(
":test_util",
":worker_cc_grpc_proto",
":worker_proto_cc",
"//tensorflow:grpc++",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/kernels/data:dataset_test_base",
"@com_google_absl//absl/strings",
tf_grpc_cc_dependency(),
],
)

View File

@ -303,6 +303,10 @@ class Runner {
// are doing is safe. We should formalize the properties here.
class IteratorContext {
public:
// Epoch IDs are only used for tf.data service datasets. Other datasets use
// an epoch ID value of -1.
static constexpr const int64 kNoEpochId = -1;
struct Params {
explicit Params(IteratorContext* ctx)
: allocator_getter(ctx->allocator_getter()),
@ -310,6 +314,7 @@ class IteratorContext {
env(ctx->env()),
flr(ctx->flr()),
function_handle_cache(ctx->function_handle_cache()),
epoch_id(ctx->epoch_id()),
resource_mgr(ctx->resource_mgr()),
model(ctx->model()),
runner(*(ctx->runner())),
@ -366,6 +371,10 @@ class IteratorContext {
// A FunctionHandleCache that owns all the function handles. Not owned.
FunctionHandleCache* function_handle_cache = nullptr;
// Identifies the epoch this iterator was created for. It is used for
// reading from the tf.data service.
int64 epoch_id = kNoEpochId;
// A resource manager for storing dataset-related state, e.g. random
// seeds or cached tensors. Not owned.
ResourceMgr* resource_mgr = nullptr;
@ -415,6 +424,8 @@ class IteratorContext {
return params_.function_handle_cache;
}
int64 epoch_id() { return params_.epoch_id; }
ResourceMgr* resource_mgr() { return params_.resource_mgr; }
const std::shared_ptr<model::Model>& model() { return params_.model; }

View File

@ -1,6 +1,7 @@
# Description:
# Contains experimental kernels for datasets and iterators.
load("//tensorflow:tensorflow.bzl", "tf_grpc_cc_dependency")
load(
"//tensorflow:tensorflow.bzl",
"tf_cc_test",
@ -120,6 +121,25 @@ tf_kernel_library(
],
)
# Kernels for the tf.data service client ops implemented in data_service_ops.cc.
tf_kernel_library(
    name = "data_service_ops",
    srcs = ["data_service_ops.cc"],
    hdrs = ["data_service_ops.h"],
    deps = [
        "//tensorflow/core:experimental_dataset_ops_op_lib",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core/data/service:credentials_factory",
        "//tensorflow/core/data/service:grpc_util",
        "//tensorflow/core/data/service:master_cc_grpc_proto",
        "//tensorflow/core/data/service:master_proto_cc",
        "//tensorflow/core/kernels/data:dataset_utils",
        "//tensorflow/core/kernels/data:iterator_ops",
        # data_service_ops.cc calls absl::StrCat directly, so depend on it
        # directly instead of relying on a transitive dependency.
        "@com_google_absl//absl/strings",
        tf_grpc_cc_dependency(),
    ],
)
tf_kernel_library(
name = "dense_to_sparse_batch_dataset_op",
srcs = ["dense_to_sparse_batch_dataset_op.cc"],

View File

@ -0,0 +1,141 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/experimental/data_service_ops.h"
#include "grpcpp/create_channel.h"
#include "grpcpp/security/credentials.h"
#include "tensorflow/core/data/service/credentials_factory.h"
#include "tensorflow/core/data/service/grpc_util.h"
#include "tensorflow/core/data/service/master.grpc.pb.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
namespace tensorflow {
namespace data {
// Reads the `external_state_policy` attribute at kernel construction time and
// caches it for use by `Compute`. Construction fails (via OP_REQUIRES_OK) if
// the attribute cannot be read.
RegisterDatasetOp::RegisterDatasetOp(OpKernelConstruction* ctx)
    : OpKernel(ctx) {
  // The attribute is read as an integer and converted to the enum type.
  int64 policy_attr = 0;
  OP_REQUIRES_OK(ctx, ctx->GetAttr(kExternalStatePolicy, &policy_attr));
  external_state_policy_ =
      static_cast<SerializationContext::ExternalStatePolicy>(policy_attr);
}
// Serializes the input dataset to a GraphDef and registers it with the tf.data
// service master.
//
// Reads the dataset from input 0 and the `address` / `protocol` arguments
// (both must be non-empty). `protocol` selects the client credentials used to
// open a gRPC channel to `address`. On success, output 0 is a scalar int64
// holding the dataset id returned by the master.
//
// Any failure (invalid argument, serialization error, credential creation
// error, or RPC error) is recorded on `ctx` and no output is produced.
void RegisterDatasetOp::Compute(OpKernelContext* ctx) {
  DatasetBase* dataset;
  OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));

  tstring address;
  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kAddress, &address));
  OP_REQUIRES(ctx, !address.empty(),
              errors::InvalidArgument(kAddress, " must be non-empty."));

  tstring protocol;
  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kProtocol, &protocol));
  OP_REQUIRES(ctx, !protocol.empty(),
              errors::InvalidArgument(kProtocol, " must be non-empty."));

  // Serialize the dataset graph directly into the request message so that the
  // (potentially large) GraphDef is not built separately and then copied.
  SerializationContext::Params params;
  params.external_state_policy = external_state_policy_;
  SerializationContext serialization_ctx(params);
  GetOrRegisterDatasetRequest req;
  OP_REQUIRES_OK(ctx, AsGraphDef(ctx, dataset, std::move(serialization_ctx),
                                 req.mutable_dataset()->mutable_graph()));

  std::shared_ptr<::grpc::ChannelCredentials> credentials;
  OP_REQUIRES_OK(
      ctx, CredentialsFactory::CreateClientCredentials(protocol, &credentials));
  auto channel = ::grpc::CreateChannel(address, credentials);
  auto master_stub = MasterService::NewStub(channel);

  GetOrRegisterDatasetResponse resp;
  ::grpc::ClientContext client_ctx;
  auto status = master_stub->GetOrRegisterDataset(&client_ctx, req, &resp);
  if (!status.ok()) {
    // Convert the gRPC status into a TensorFlow status and fail the op.
    ctx->CtxFailure(grpc_util::WrapError("Failed to register dataset", status));
    return;
  }

  // Only allocate the output once the RPC has succeeded.
  Tensor* output;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape{}, &output));
  auto output_dataset_id = output->tensor<int64, 0>();
  output_dataset_id() = resp.dataset_id();
}
// BeginEpochOp reads no attributes; construction only initializes the base
// OpKernel.
BeginEpochOp::BeginEpochOp(OpKernelConstruction* ctx)
    : OpKernel(ctx) {}
// Asks the tf.data service master to begin a new epoch for `dataset_id`.
//
// `address` and `protocol` must be non-empty; `protocol` selects the client
// credentials used to open a gRPC channel to `address`. On success, output 0
// is a scalar int64 holding the epoch id returned by the master. Any failure
// is recorded on `ctx` and no output is produced.
void BeginEpochOp::Compute(OpKernelContext* ctx) {
  // --- Parse and validate arguments. ---
  int64 dataset_id;
  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kDatasetId, &dataset_id));

  tstring master_address;
  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kAddress, &master_address));
  OP_REQUIRES(ctx, !master_address.empty(),
              errors::InvalidArgument(kAddress, " must be non-empty."));

  tstring rpc_protocol;
  OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kProtocol, &rpc_protocol));
  OP_REQUIRES(ctx, !rpc_protocol.empty(),
              errors::InvalidArgument(kProtocol, " must be non-empty."));

  // --- Connect to the master. ---
  std::shared_ptr<::grpc::ChannelCredentials> creds;
  OP_REQUIRES_OK(
      ctx, CredentialsFactory::CreateClientCredentials(rpc_protocol, &creds));
  auto stub =
      MasterService::NewStub(::grpc::CreateChannel(master_address, creds));

  // --- Issue the BeginEpoch RPC. ---
  BeginEpochRequest request;
  request.set_dataset_id(dataset_id);
  BeginEpochResponse response;
  ::grpc::ClientContext rpc_ctx;
  auto rpc_status = stub->BeginEpoch(&rpc_ctx, request, &response);
  if (!rpc_status.ok()) {
    // Convert the gRPC status into a TensorFlow status and fail the op.
    ctx->CtxFailure(grpc_util::WrapError(
        absl::StrCat("Failed to begin epoch for dataset id ", dataset_id),
        rpc_status));
    return;
  }

  // --- Emit the epoch id. ---
  Tensor* epoch_id_out;
  OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape{}, &epoch_id_out));
  epoch_id_out->tensor<int64, 0>()() = response.epoch_id();
}
// Creates an iterator over the dataset in input 0 for the epoch given by the
// `epoch_id` input, and stores it in the iterator resource whose handle is
// input 2.
//
// Returns an error status if the dataset cannot be extracted, `epoch_id`
// cannot be parsed, the iterator resource cannot be found, or iterator
// creation fails.
Status MakeDataServiceIteratorOp::DoCompute(OpKernelContext* ctx) {
  DatasetBase* dataset;
  TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset));

  // Parse `epoch_id` with the same helper the other ops in this file use for
  // their scalar arguments, rather than calling `Tensor::scalar()` on the raw
  // input: `scalar()` CHECK-fails on a malformed tensor, whereas the helper
  // reports the problem through the returned Status.
  int64 epoch_id;
  TF_RETURN_IF_ERROR(ParseScalarArgument(ctx, kEpochId, &epoch_id));

  IteratorResource* iterator_resource;
  TF_RETURN_IF_ERROR(
      LookupResource(ctx, HandleFromInput(ctx, 2), &iterator_resource));
  // Release the reference acquired by LookupResource when leaving this scope.
  core::ScopedUnref unref_iterator(iterator_resource);

  return iterator_resource->SetIteratorFromDataset(ctx, dataset, epoch_id);
}
// Register the CPU kernels for the tf.data service ops implemented in this
// file: dataset registration, epoch creation, and iterator creation.
REGISTER_KERNEL_BUILDER(Name("RegisterDataset").Device(DEVICE_CPU),
RegisterDatasetOp);
REGISTER_KERNEL_BUILDER(Name("BeginEpoch").Device(DEVICE_CPU), BeginEpochOp);
REGISTER_KERNEL_BUILDER(Name("MakeDataServiceIterator").Device(DEVICE_CPU),
MakeDataServiceIteratorOp);
} // namespace data
} // namespace tensorflow

View File

@ -0,0 +1,82 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_DATA_SERVICE_OPS_H_
#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_DATA_SERVICE_OPS_H_
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/data/iterator_ops.h"
namespace tensorflow {
namespace data {
// Registers a dataset with the tf.data service.
//
// Inputs:
//   input 0:  the dataset to register.
//   address:  address of the tf.data master to connect to; must be non-empty.
//   protocol: protocol used to create the client credentials for the
//             connection; must be non-empty.
// Attributes:
//   external_state_policy: determines whether to ignore, warn, or error out
//     when the dataset contains external state. Read as an integer and
//     converted to SerializationContext::ExternalStatePolicy.
// Outputs:
//   output 0: scalar int64 dataset id identifying the registered dataset.
class RegisterDatasetOp : public OpKernel {
public:
// Names of the op's `address` and `protocol` inputs.
static constexpr const char* const kAddress = "address";
static constexpr const char* const kProtocol = "protocol";
// Name of the op's external state policy attribute.
static constexpr const char* const kExternalStatePolicy =
"external_state_policy";
// Reads and caches the external state policy attribute.
explicit RegisterDatasetOp(OpKernelConstruction* ctx);
// Serializes the dataset, sends it to the master, and outputs the dataset id.
void Compute(OpKernelContext* ctx) override;
private:
// Policy applied to external state when serializing the dataset in `Compute`.
SerializationContext::ExternalStatePolicy external_state_policy_;
};
// Begins a new epoch for a tf.data service dataset.
//
// Inputs:
//   dataset_id: identifies which dataset to start a new epoch for.
//   address:    address of the tf.data service master to connect to; must be
//               non-empty.
//   protocol:   protocol used to create the client credentials for the
//               connection; must be non-empty.
// Outputs:
//   output 0: scalar int64 epoch id identifying the newly created epoch.
class BeginEpochOp : public OpKernel {
public:
// Names of the op's inputs.
static constexpr const char* const kDatasetId = "dataset_id";
static constexpr const char* const kAddress = "address";
static constexpr const char* const kProtocol = "protocol";
explicit BeginEpochOp(OpKernelConstruction* ctx);
// Sends a BeginEpoch request to the master and outputs the returned epoch id.
void Compute(OpKernelContext* ctx) override;
};
// Creates a new iterator for iterating over a tf.data service dataset.
//
// The epoch_id input identifies which epoch to read from. Multiple iterators
// may read from the same epoch, causing the elements of the epoch to be split
// across all iterators.
//
// The implementation reads the dataset from input 0 and the iterator resource
// handle from input 2; the epoch id is looked up by input name.
class MakeDataServiceIteratorOp : public MakeIteratorOp {
public:
// Name of the op's epoch id input.
static constexpr const char* const kEpochId = "epoch_id";
explicit MakeDataServiceIteratorOp(OpKernelConstruction* ctx)
: MakeIteratorOp(ctx) {}
protected:
// Overrides MakeIteratorOp::DoCompute to pass the epoch id through to
// IteratorResource::SetIteratorFromDataset.
Status DoCompute(OpKernelContext* ctx) override;
};
} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_DATA_SERVICE_OPS_H_

View File

@ -166,7 +166,8 @@ Status IteratorResource::Restore(OpKernelContext* ctx,
}
Status IteratorResource::SetIteratorFromDataset(OpKernelContext* ctx,
DatasetBase* dataset) {
DatasetBase* dataset,
int64 epoch_id) {
std::shared_ptr<State> new_state;
{
tf_shared_lock l(mu_);
@ -179,6 +180,7 @@ Status IteratorResource::SetIteratorFromDataset(OpKernelContext* ctx,
IteratorContext::Params params(ctx);
params.flr = new_state->flr;
params.function_handle_cache = new_state->function_handle_cache.get();
params.epoch_id = epoch_id;
params.resource_mgr = &new_state->resource_mgr;
params.thread_factory = unbounded_thread_pool_.get_thread_factory();
params.thread_pool = &unbounded_thread_pool_;
@ -530,7 +532,8 @@ Status MakeIteratorOp::DoCompute(OpKernelContext* ctx) {
TF_RETURN_IF_ERROR(
LookupResource(ctx, HandleFromInput(ctx, 1), &iterator_resource));
core::ScopedUnref unref_iterator(iterator_resource);
return iterator_resource->SetIteratorFromDataset(ctx, dataset);
return iterator_resource->SetIteratorFromDataset(
ctx, dataset, /*epoch_id=*/IteratorContext::kNoEpochId);
}
void DeleteIteratorOp::Compute(OpKernelContext* ctx) {
@ -857,7 +860,8 @@ class OneShotIteratorOp : public AsyncOpKernel {
// factory function.
DatasetBase* dataset;
TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(return_values[0], &dataset));
TF_RETURN_IF_ERROR((*iterator)->SetIteratorFromDataset(ctx, dataset));
TF_RETURN_IF_ERROR((*iterator)->SetIteratorFromDataset(
ctx, dataset, /*epoch_id=*/IteratorContext::kNoEpochId));
(*iterator)->Ref();
return Status::OK();
}

View File

@ -50,14 +50,33 @@ class IteratorResource : public ResourceBase {
~IteratorResource() override { VLOG(2) << "destructor"; }
// Gets the next output from the iterator managed by this iterator resource.
//
// If at least one output remains, that output will be stored in
// `*out_tensors` and `false` will be stored in `*end_of_sequence`.
//
// If no more outputs remain, `true` will be stored in `*end_of_sequence`, and
// the content of `*out_tensors` will be undefined.
Status GetNext(OpKernelContext* ctx, std::vector<Tensor>* out_tensors,
bool* end_of_sequence);
// Saves a checkpoint of the state of the iterator through the given `writer`.
Status Save(SerializationContext* ctx, IteratorStateWriter* writer);
// Restores the state of the iterator from a checkpoint created by `Save`.
Status Restore(OpKernelContext* ctx, IteratorStateReader* reader);
Status SetIteratorFromDataset(OpKernelContext* ctx, DatasetBase* dataset);
// Creates an iterator for `dataset`, and associates the iterator with this
// iterator resource.
//
// The `epoch_id` will be passed through the IteratorContext when creating
// the iterator. This id is used by the tf.data service to determine which
// epoch to iterate through.
//
// `SetIteratorFromDataset` should be called before calling `GetNext`, `Save`,
// or `Restore`.
Status SetIteratorFromDataset(OpKernelContext* ctx, DatasetBase* dataset,
int64 epoch_id);
string DebugString() const override { return "Iterator resource"; }

View File

@ -2797,3 +2797,6 @@ def filegroup_as_file(name, dep, visibility = []):
srcs = [name],
visibility = visibility,
)
def tf_grpc_cc_dependency():
    """Returns the label of the gRPC C++ library target.

    Build rules that need gRPC call this instead of hard-coding the label, so
    the dependency is named in a single place.
    """
    grpc_cc_label = "//tensorflow:grpc++"
    return grpc_cc_label