[tf.data service] Implement dataset registration and iteration creation ops.
PiperOrigin-RevId: 304264765 Change-Id: Iaaa3ea3f8e125f287b67a985bef4d8f8fb658803
This commit is contained in:
parent
4e6905a35a
commit
c1f9a95117
@ -4,6 +4,7 @@ load(
|
||||
"tf_additional_all_protos",
|
||||
"tf_proto_library",
|
||||
)
|
||||
load("//tensorflow:tensorflow.bzl", "tf_grpc_cc_dependency")
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"cc_header_only_library",
|
||||
@ -59,7 +60,6 @@ cc_library(
|
||||
":master_proto_cc",
|
||||
":worker_cc_grpc_proto",
|
||||
":worker_proto_cc",
|
||||
"//tensorflow:grpc++",
|
||||
"//tensorflow/c:c_api_internal",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/core:core_cpu",
|
||||
@ -71,6 +71,7 @@ cc_library(
|
||||
"//tensorflow/core/kernels/data:dataset_utils",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/memory",
|
||||
tf_grpc_cc_dependency(),
|
||||
],
|
||||
)
|
||||
|
||||
@ -88,7 +89,6 @@ cc_library(
|
||||
":master_cc_grpc_proto",
|
||||
":master_proto_cc",
|
||||
":worker_proto_cc",
|
||||
"//tensorflow:grpc++",
|
||||
"//tensorflow/c:c_api_internal",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/core:core_cpu",
|
||||
@ -100,6 +100,7 @@ cc_library(
|
||||
"//tensorflow/core/data:standalone",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/memory",
|
||||
tf_grpc_cc_dependency(),
|
||||
],
|
||||
)
|
||||
|
||||
@ -110,8 +111,8 @@ cc_library(
|
||||
"grpc_util.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow:grpc++",
|
||||
"//tensorflow/core:lib",
|
||||
tf_grpc_cc_dependency(),
|
||||
],
|
||||
)
|
||||
|
||||
@ -165,9 +166,9 @@ cc_library(
|
||||
srcs = ["credentials_factory.cc"],
|
||||
hdrs = ["credentials_factory.h"],
|
||||
deps = [
|
||||
"//tensorflow:grpc++",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/strings",
|
||||
tf_grpc_cc_dependency(),
|
||||
],
|
||||
)
|
||||
|
||||
@ -189,7 +190,7 @@ cc_library(
|
||||
srcs = ["local_credentials_factory.cc"],
|
||||
deps = [
|
||||
":credentials_factory",
|
||||
"//tensorflow:grpc++",
|
||||
tf_grpc_cc_dependency(),
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
@ -231,8 +232,8 @@ cc_library(
|
||||
deps = [
|
||||
":master_cc_grpc_proto",
|
||||
":master_impl",
|
||||
"//tensorflow:grpc++",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
|
||||
tf_grpc_cc_dependency(),
|
||||
],
|
||||
)
|
||||
|
||||
@ -243,8 +244,8 @@ cc_library(
|
||||
deps = [
|
||||
":worker_cc_grpc_proto",
|
||||
":worker_impl",
|
||||
"//tensorflow:grpc++",
|
||||
"//tensorflow/core/distributed_runtime/rpc:grpc_util",
|
||||
tf_grpc_cc_dependency(),
|
||||
],
|
||||
)
|
||||
|
||||
@ -266,9 +267,9 @@ cc_library(
|
||||
":credentials_factory",
|
||||
":grpc_master_impl",
|
||||
":grpc_worker_impl",
|
||||
"//tensorflow:grpc++",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:tensorflow",
|
||||
tf_grpc_cc_dependency(),
|
||||
],
|
||||
)
|
||||
|
||||
@ -288,12 +289,12 @@ tf_cc_test(
|
||||
":test_util",
|
||||
":worker_cc_grpc_proto",
|
||||
":worker_proto_cc",
|
||||
"//tensorflow:grpc++",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/kernels/data:dataset_test_base",
|
||||
"@com_google_absl//absl/strings",
|
||||
tf_grpc_cc_dependency(),
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -303,6 +303,10 @@ class Runner {
|
||||
// are doing is safe. We should formalize the properties here.
|
||||
class IteratorContext {
|
||||
public:
|
||||
// Epoch IDs are only used for tf.data service datasets. Other datasets use
|
||||
// an epoch ID value of -1.
|
||||
static constexpr const int64 kNoEpochId = -1;
|
||||
|
||||
struct Params {
|
||||
explicit Params(IteratorContext* ctx)
|
||||
: allocator_getter(ctx->allocator_getter()),
|
||||
@ -310,6 +314,7 @@ class IteratorContext {
|
||||
env(ctx->env()),
|
||||
flr(ctx->flr()),
|
||||
function_handle_cache(ctx->function_handle_cache()),
|
||||
epoch_id(ctx->epoch_id()),
|
||||
resource_mgr(ctx->resource_mgr()),
|
||||
model(ctx->model()),
|
||||
runner(*(ctx->runner())),
|
||||
@ -366,6 +371,10 @@ class IteratorContext {
|
||||
// A FunctionHandleCache that owns all the function handles. Not owned.
|
||||
FunctionHandleCache* function_handle_cache = nullptr;
|
||||
|
||||
// Identifies the epoch this iterator was created for. It is used for
|
||||
// reading from the tf.data service.
|
||||
int64 epoch_id = kNoEpochId;
|
||||
|
||||
// A resource manager for storing dataset-related state, e.g. random
|
||||
// seeds or cached tensors. Not owned.
|
||||
ResourceMgr* resource_mgr = nullptr;
|
||||
@ -415,6 +424,8 @@ class IteratorContext {
|
||||
return params_.function_handle_cache;
|
||||
}
|
||||
|
||||
int64 epoch_id() { return params_.epoch_id; }
|
||||
|
||||
ResourceMgr* resource_mgr() { return params_.resource_mgr; }
|
||||
|
||||
const std::shared_ptr<model::Model>& model() { return params_.model; }
|
||||
|
@ -1,6 +1,7 @@
|
||||
# Description:
|
||||
# Contains experimental kernels for datasets and iterators.
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "tf_grpc_cc_dependency")
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"tf_cc_test",
|
||||
@ -120,6 +121,25 @@ tf_kernel_library(
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "data_service_ops",
|
||||
srcs = ["data_service_ops.cc"],
|
||||
hdrs = ["data_service_ops.h"],
|
||||
deps = [
|
||||
"//tensorflow/core:experimental_dataset_ops_op_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core/data/service:credentials_factory",
|
||||
"//tensorflow/core/data/service:grpc_util",
|
||||
"//tensorflow/core/data/service:master_cc_grpc_proto",
|
||||
"//tensorflow/core/data/service:master_proto_cc",
|
||||
"//tensorflow/core/kernels/data:dataset_utils",
|
||||
"//tensorflow/core/kernels/data:iterator_ops",
|
||||
tf_grpc_cc_dependency(),
|
||||
],
|
||||
)
|
||||
|
||||
tf_kernel_library(
|
||||
name = "dense_to_sparse_batch_dataset_op",
|
||||
srcs = ["dense_to_sparse_batch_dataset_op.cc"],
|
||||
|
141
tensorflow/core/kernels/data/experimental/data_service_ops.cc
Normal file
141
tensorflow/core/kernels/data/experimental/data_service_ops.cc
Normal file
@ -0,0 +1,141 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/kernels/data/experimental/data_service_ops.h"
|
||||
|
||||
#include "grpcpp/create_channel.h"
|
||||
#include "grpcpp/security/credentials.h"
|
||||
#include "tensorflow/core/data/service/credentials_factory.h"
|
||||
#include "tensorflow/core/data/service/grpc_util.h"
|
||||
#include "tensorflow/core/data/service/master.grpc.pb.h"
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
#include "tensorflow/core/kernels/data/dataset_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
|
||||
RegisterDatasetOp::RegisterDatasetOp(OpKernelConstruction* ctx)
|
||||
: OpKernel(ctx) {
|
||||
int64 external_state_policy_int;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, ctx->GetAttr(kExternalStatePolicy, &external_state_policy_int));
|
||||
external_state_policy_ =
|
||||
SerializationContext::ExternalStatePolicy(external_state_policy_int);
|
||||
}
|
||||
|
||||
void RegisterDatasetOp::Compute(OpKernelContext* ctx) {
|
||||
DatasetBase* dataset;
|
||||
OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
|
||||
|
||||
tstring address;
|
||||
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kAddress, &address));
|
||||
OP_REQUIRES(ctx, !address.empty(),
|
||||
errors::InvalidArgument(kAddress, " must be non-empty."));
|
||||
|
||||
tstring protocol;
|
||||
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kProtocol, &protocol));
|
||||
OP_REQUIRES(ctx, !protocol.empty(),
|
||||
errors::InvalidArgument(kProtocol, " must be non-empty."));
|
||||
|
||||
SerializationContext::Params params;
|
||||
params.external_state_policy = external_state_policy_;
|
||||
SerializationContext serialization_ctx(params);
|
||||
GraphDef graph_def;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, AsGraphDef(ctx, dataset, std::move(serialization_ctx), &graph_def));
|
||||
|
||||
// ::grpc::ChannelArguments args;
|
||||
std::shared_ptr<::grpc::ChannelCredentials> credentials;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, CredentialsFactory::CreateClientCredentials(protocol, &credentials));
|
||||
auto channel = ::grpc::CreateChannel(address, credentials);
|
||||
auto master_stub = MasterService::NewStub(channel);
|
||||
GetOrRegisterDatasetRequest req;
|
||||
*req.mutable_dataset()->mutable_graph() = graph_def;
|
||||
GetOrRegisterDatasetResponse resp;
|
||||
grpc::ClientContext client_ctx;
|
||||
auto status = master_stub->GetOrRegisterDataset(&client_ctx, req, &resp);
|
||||
if (!status.ok()) {
|
||||
ctx->CtxFailure(grpc_util::WrapError("Failed to register dataset", status));
|
||||
return;
|
||||
}
|
||||
Tensor* output;
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape{}, &output));
|
||||
auto output_dataset_id = output->tensor<int64, 0>();
|
||||
output_dataset_id() = resp.dataset_id();
|
||||
}
|
||||
|
||||
BeginEpochOp::BeginEpochOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
|
||||
|
||||
void BeginEpochOp::Compute(OpKernelContext* ctx) {
|
||||
int64 dataset_id;
|
||||
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kDatasetId, &dataset_id));
|
||||
|
||||
tstring address;
|
||||
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kAddress, &address));
|
||||
OP_REQUIRES(ctx, !address.empty(),
|
||||
errors::InvalidArgument(kAddress, " must be non-empty."));
|
||||
|
||||
tstring protocol;
|
||||
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, kProtocol, &protocol));
|
||||
OP_REQUIRES(ctx, !protocol.empty(),
|
||||
errors::InvalidArgument(kProtocol, " must be non-empty."));
|
||||
|
||||
std::shared_ptr<::grpc::ChannelCredentials> credentials;
|
||||
OP_REQUIRES_OK(
|
||||
ctx, CredentialsFactory::CreateClientCredentials(protocol, &credentials));
|
||||
auto channel = ::grpc::CreateChannel(address, credentials);
|
||||
auto master_stub = MasterService::NewStub(channel);
|
||||
BeginEpochRequest req;
|
||||
req.set_dataset_id(dataset_id);
|
||||
BeginEpochResponse resp;
|
||||
grpc::ClientContext client_ctx;
|
||||
auto status = master_stub->BeginEpoch(&client_ctx, req, &resp);
|
||||
if (!status.ok()) {
|
||||
ctx->CtxFailure(grpc_util::WrapError(
|
||||
absl::StrCat("Failed to begin epoch for dataset id ", dataset_id),
|
||||
status));
|
||||
return;
|
||||
}
|
||||
Tensor* output;
|
||||
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape{}, &output));
|
||||
auto output_epoch_id = output->tensor<int64, 0>();
|
||||
output_epoch_id() = resp.epoch_id();
|
||||
}
|
||||
|
||||
Status MakeDataServiceIteratorOp::DoCompute(OpKernelContext* ctx) {
|
||||
DatasetBase* dataset;
|
||||
TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(ctx->input(0), &dataset));
|
||||
|
||||
const Tensor* epoch_id_tensor;
|
||||
TF_RETURN_IF_ERROR(ctx->input(kEpochId, &epoch_id_tensor));
|
||||
int64 epoch_id = epoch_id_tensor->scalar<int64>()();
|
||||
|
||||
IteratorResource* iterator_resource;
|
||||
TF_RETURN_IF_ERROR(
|
||||
LookupResource(ctx, HandleFromInput(ctx, 2), &iterator_resource));
|
||||
|
||||
core::ScopedUnref unref_iterator(iterator_resource);
|
||||
return iterator_resource->SetIteratorFromDataset(ctx, dataset, epoch_id);
|
||||
}
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("RegisterDataset").Device(DEVICE_CPU),
|
||||
RegisterDatasetOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("BeginEpoch").Device(DEVICE_CPU), BeginEpochOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("MakeDataServiceIterator").Device(DEVICE_CPU),
|
||||
MakeDataServiceIteratorOp);
|
||||
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
82
tensorflow/core/kernels/data/experimental/data_service_ops.h
Normal file
82
tensorflow/core/kernels/data/experimental/data_service_ops.h
Normal file
@ -0,0 +1,82 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_DATA_SERVICE_OPS_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_DATA_SERVICE_OPS_H_
|
||||
|
||||
#include "tensorflow/core/framework/dataset.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/kernels/data/iterator_ops.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace data {
|
||||
|
||||
// Registers a dataset with the tf.data service.
|
||||
//
|
||||
// The address and protocol inputs are used to connect to the tf.data master.
|
||||
// The external state policy attribute determines whether to ignore, warn, or
|
||||
// error out when the dataset contains external state.
|
||||
// The op produces a dataset id for identifying the registered dataset.
|
||||
class RegisterDatasetOp : public OpKernel {
|
||||
public:
|
||||
static constexpr const char* const kAddress = "address";
|
||||
static constexpr const char* const kProtocol = "protocol";
|
||||
static constexpr const char* const kExternalStatePolicy =
|
||||
"external_state_policy";
|
||||
|
||||
explicit RegisterDatasetOp(OpKernelConstruction* ctx);
|
||||
|
||||
void Compute(OpKernelContext* ctx) override;
|
||||
|
||||
private:
|
||||
SerializationContext::ExternalStatePolicy external_state_policy_;
|
||||
};
|
||||
|
||||
// Begins a new epoch for a tf.data service dataset.
|
||||
//
|
||||
// The dataset_id input identifies which dataset to start a new epoch for.
|
||||
// The address and protocol inputs are used to connect to the tf.data service
|
||||
// master.
|
||||
// The op produces an epoch id to identify the newly created epoch.
|
||||
class BeginEpochOp : public OpKernel {
|
||||
public:
|
||||
static constexpr const char* const kDatasetId = "dataset_id";
|
||||
static constexpr const char* const kAddress = "address";
|
||||
static constexpr const char* const kProtocol = "protocol";
|
||||
|
||||
explicit BeginEpochOp(OpKernelConstruction* ctx);
|
||||
|
||||
void Compute(OpKernelContext* ctx) override;
|
||||
};
|
||||
|
||||
// Creates a new iterator for iterating over a tf.data service dataset.
|
||||
//
|
||||
// The epoch_id input identifies which epoch to read from. Multiple iterators
|
||||
// may read from the same epoch, causing the elements of the epoch to be split
|
||||
// across all iterators.
|
||||
class MakeDataServiceIteratorOp : public MakeIteratorOp {
|
||||
public:
|
||||
static constexpr const char* const kEpochId = "epoch_id";
|
||||
|
||||
explicit MakeDataServiceIteratorOp(OpKernelConstruction* ctx)
|
||||
: MakeIteratorOp(ctx) {}
|
||||
|
||||
protected:
|
||||
Status DoCompute(OpKernelContext* ctx) override;
|
||||
};
|
||||
|
||||
} // namespace data
|
||||
} // namespace tensorflow
|
||||
#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_DATA_SERVICE_OPS_H_
|
@ -166,7 +166,8 @@ Status IteratorResource::Restore(OpKernelContext* ctx,
|
||||
}
|
||||
|
||||
Status IteratorResource::SetIteratorFromDataset(OpKernelContext* ctx,
|
||||
DatasetBase* dataset) {
|
||||
DatasetBase* dataset,
|
||||
int64 epoch_id) {
|
||||
std::shared_ptr<State> new_state;
|
||||
{
|
||||
tf_shared_lock l(mu_);
|
||||
@ -179,6 +180,7 @@ Status IteratorResource::SetIteratorFromDataset(OpKernelContext* ctx,
|
||||
IteratorContext::Params params(ctx);
|
||||
params.flr = new_state->flr;
|
||||
params.function_handle_cache = new_state->function_handle_cache.get();
|
||||
params.epoch_id = epoch_id;
|
||||
params.resource_mgr = &new_state->resource_mgr;
|
||||
params.thread_factory = unbounded_thread_pool_.get_thread_factory();
|
||||
params.thread_pool = &unbounded_thread_pool_;
|
||||
@ -530,7 +532,8 @@ Status MakeIteratorOp::DoCompute(OpKernelContext* ctx) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
LookupResource(ctx, HandleFromInput(ctx, 1), &iterator_resource));
|
||||
core::ScopedUnref unref_iterator(iterator_resource);
|
||||
return iterator_resource->SetIteratorFromDataset(ctx, dataset);
|
||||
return iterator_resource->SetIteratorFromDataset(
|
||||
ctx, dataset, /*epoch_id=*/IteratorContext::kNoEpochId);
|
||||
}
|
||||
|
||||
void DeleteIteratorOp::Compute(OpKernelContext* ctx) {
|
||||
@ -857,7 +860,8 @@ class OneShotIteratorOp : public AsyncOpKernel {
|
||||
// factory function.
|
||||
DatasetBase* dataset;
|
||||
TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(return_values[0], &dataset));
|
||||
TF_RETURN_IF_ERROR((*iterator)->SetIteratorFromDataset(ctx, dataset));
|
||||
TF_RETURN_IF_ERROR((*iterator)->SetIteratorFromDataset(
|
||||
ctx, dataset, /*epoch_id=*/IteratorContext::kNoEpochId));
|
||||
(*iterator)->Ref();
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -50,14 +50,33 @@ class IteratorResource : public ResourceBase {
|
||||
|
||||
~IteratorResource() override { VLOG(2) << "destructor"; }
|
||||
|
||||
// Gets the next output from the iterator managed by this iterator resource.
|
||||
//
|
||||
// If at least one output remains, that output will be stored in
|
||||
// `*out_tensors` and `false` will be stored in `*end_of_sequence`.
|
||||
//
|
||||
// If no more outputs remain, `true` will be stored in `*end_of_sequence`, and
|
||||
// the content of `*out_tensors` will be undefined.
|
||||
Status GetNext(OpKernelContext* ctx, std::vector<Tensor>* out_tensors,
|
||||
bool* end_of_sequence);
|
||||
|
||||
// Saves a checkpoint of the state of the iterator through the given `writer`.
|
||||
Status Save(SerializationContext* ctx, IteratorStateWriter* writer);
|
||||
|
||||
// Restores the state of the iterator from a checkpoint created by `Save`.
|
||||
Status Restore(OpKernelContext* ctx, IteratorStateReader* reader);
|
||||
|
||||
Status SetIteratorFromDataset(OpKernelContext* ctx, DatasetBase* dataset);
|
||||
// Creates an iterator for `dataset`, and associates the iterator with this
|
||||
// iterator resource.
|
||||
//
|
||||
// The `epoch_id` will be passed through the IteratorContext when creating
|
||||
// the iterator. This id is used by the tf.data service to determine which
|
||||
// epoch to iterate through.
|
||||
//
|
||||
// `SetIteratorFromDataset` should be called before calling `GetNext`, `Save`,
|
||||
// or `Restore`.
|
||||
Status SetIteratorFromDataset(OpKernelContext* ctx, DatasetBase* dataset,
|
||||
int64 epoch_id);
|
||||
|
||||
string DebugString() const override { return "Iterator resource"; }
|
||||
|
||||
|
@ -2797,3 +2797,6 @@ def filegroup_as_file(name, dep, visibility = []):
|
||||
srcs = [name],
|
||||
visibility = visibility,
|
||||
)
|
||||
|
||||
def tf_grpc_cc_dependency():
|
||||
return "//tensorflow:grpc++"
|
||||
|
Loading…
Reference in New Issue
Block a user