[tf.data service] Add master and worker proto definitions.

PiperOrigin-RevId: 301596660
Change-Id: Ibe7e271e345919ab58c1be61abec81b0463f972f
This commit is contained in:
Andrew Audibert 2020-03-18 08:44:12 -07:00 committed by TensorFlower Gardener
parent 17f7d2ab69
commit fe84fb19eb
4 changed files with 202 additions and 0 deletions
tensorflow/core/data/service

View File

@ -0,0 +1,58 @@
load("@com_github_grpc_grpc//bazel:cc_grpc_library.bzl", "cc_grpc_library")
load(
"//tensorflow/core/platform:build_config.bzl",
"tf_additional_all_protos",
"tf_proto_library",
)
# Package-wide defaults: targets default to the //tensorflow:internal
# visibility group and are declared under a "notice" license.
package(
    default_visibility = ["//tensorflow:internal"],
    licenses = ["notice"],  # Apache 2.0
)
# Export the LICENSE file so that it can be referenced from other packages.
exports_files(
    ["LICENSE"],
)
# Proto library for common.proto. Both :master_proto and :worker_proto list
# this target in their protodeps.
tf_proto_library(
    name = "common_proto",
    srcs = ["common.proto"],
    cc_api_version = 2,
    protodeps = tf_additional_all_protos(),
)
# Proto library for master.proto. `has_services` is set because the file
# declares a service in addition to its messages.
tf_proto_library(
    name = "master_proto",
    srcs = ["master.proto"],
    cc_api_version = 2,
    has_services = 1,
    protodeps = tf_additional_all_protos() + [":common_proto"],
)
# Proto library for worker.proto. `has_services` is set because the file
# declares a service in addition to its messages.
tf_proto_library(
    name = "worker_proto",
    srcs = ["worker.proto"],
    cc_api_version = 2,
    has_services = 1,
    protodeps = tf_additional_all_protos() + [":common_proto"],
)
# C++ gRPC library for the service declared in master.proto, with mock
# generation enabled. The message classes come from :master_proto_cc.
cc_grpc_library(
    name = "master_cc_grpc_proto",
    srcs = [":master_proto"],
    generate_mocks = True,
    grpc_only = True,
    deps = [":master_proto_cc"],
)
# C++ gRPC library for the service declared in worker.proto, with mock
# generation enabled. The message classes come from :worker_proto_cc.
cc_grpc_library(
    name = "worker_cc_grpc_proto",
    srcs = [":worker_proto"],
    generate_mocks = True,
    grpc_only = True,
    deps = [":worker_proto_cc"],
)

View File

@ -0,0 +1,40 @@
syntax = "proto3";
package tensorflow.data;
import "tensorflow/core/framework/graph.proto";
import "tensorflow/core/framework/tensor_shape.proto";
import "tensorflow/core/framework/types.proto";
// Definition of a tf.data dataset.
message DatasetDef {
  // Datasets are represented as TensorFlow GraphDefs. The graph defines the
  // operations needed to create the tf.data dataset.
  GraphDef graph = 1;
}
// Metadata describing one component tensor of an element.
message ComponentMetadata {
  // Dtype of the component tensor.
  .tensorflow.DataType dtype = 1;

  // Shape of the component tensor.
  .tensorflow.TensorShapeProto tensor_shape = 2;

  // Number of uncompressed tensor bytes:
  //  - for tensors serialized as TensorProtos, the value of
  //    TensorProto::BytesAllocatedLong();
  //  - for raw Tensors, the size of the buffer underlying the Tensor.
  int64 tensor_size_bytes = 3;
}
// A dataset element whose component tensors are stored in compressed form.
message CompressedElement {
  // The compressed tensor bytes covering every component of the element.
  bytes data = 1;

  // Per-component metadata for the element.
  repeated ComponentMetadata component_metadata = 2;
}
// Definition of a task: a dataset to iterate over plus identifying ids.
message TaskDef {
  // Dataset that the task iterates over.
  // TODO(aaudibert): load the dataset from disk instead of passing it here.
  DatasetDef dataset = 1;

  // Id of the dataset.
  // NOTE(review): presumably the id returned by GetOrRegisterDataset --
  // confirm.
  int64 dataset_id = 2;

  // Id of this task.
  // NOTE(review): presumably the id used in GetElementRequest.task_id and
  // TaskInfo.id -- confirm.
  int64 task_id = 3;

  // Id of the epoch.
  // NOTE(review): presumably the id returned by BeginEpoch -- confirm.
  int64 epoch_id = 4;
}

View File

@ -0,0 +1,73 @@
syntax = "proto3";
package tensorflow.data;
import "tensorflow/core/data/service/common.proto";
// Request message for MasterService.RegisterWorker.
message RegisterWorkerRequest {
  // Address of the worker that is registering.
  string worker_address = 1;
}
// Response message for MasterService.RegisterWorker.
message RegisterWorkerResponse {
  // Id for the worker.
  int64 worker_id = 1;

  // Tasks that the worker should begin processing.
  repeated TaskDef tasks = 2;
}
// Request message for MasterService.GetOrRegisterDataset.
message GetOrRegisterDatasetRequest {
  // Dataset to be registered.
  DatasetDef dataset = 1;
}
// Response message for MasterService.GetOrRegisterDataset.
message GetOrRegisterDatasetResponse {
  // Id of the registered dataset.
  int64 dataset_id = 1;
}
// Request message for MasterService.BeginEpoch.
message BeginEpochRequest {
  // Id of the dataset to iterate over.
  int64 dataset_id = 1;
}
// Response message for MasterService.BeginEpoch.
message BeginEpochResponse {
  // Id of the epoch that was created.
  int64 epoch_id = 1;
}
// Request message for MasterService.GetTasks.
message GetTasksRequest {
  // Epoch whose tasks should be looked up.
  int64 epoch_id = 1;
}
// Describes one task: its id and the worker that is processing it.
message TaskInfo {
  // Address of the worker processing the task.
  string worker_address = 1;

  // Id of the task.
  int64 id = 2;
}
// Response message for MasterService.GetTasks.
message GetTasksResponse {
  // All tasks for the requested epoch.
  repeated TaskInfo task_info = 1;
}
// RPCs served by the master: worker registration, dataset registration,
// epoch creation and task lookup.
service MasterService {
  // Registers a worker with the master.
  rpc RegisterWorker(RegisterWorkerRequest) returns (RegisterWorkerResponse);

  // Registers a dataset with the server. If the dataset is already
  // registered, its id is returned instead.
  //
  // The dataset is constructed in a new graph, so it must not refer to
  // external resources or variables.
  rpc GetOrRegisterDataset(GetOrRegisterDatasetRequest)
      returns (GetOrRegisterDatasetResponse);

  // Begins an epoch over a dataset.
  rpc BeginEpoch(BeginEpochRequest) returns (BeginEpochResponse);

  // Reports a list of all tasks for an epoch.
  rpc GetTasks(GetTasksRequest) returns (GetTasksResponse);
}

View File

@ -0,0 +1,31 @@
syntax = "proto3";
package tensorflow.data;
import "tensorflow/core/data/service/common.proto";
// Request message for WorkerService.ProcessTask.
message ProcessTaskRequest {
  // The task to process.
  TaskDef task = 1;
}
// Response message for WorkerService.ProcessTask. It currently declares no
// fields.
message ProcessTaskResponse {
}
// Request message for WorkerService.GetElement.
message GetElementRequest {
  // Id of the task to fetch an element from.
  int64 task_id = 1;
}
// Response message for WorkerService.GetElement.
//
// NOTE(review): field number 1 is not used by this message. If it belonged to
// a field that was removed, consider adding `reserved 1;` -- confirm the
// history before doing so.
message GetElementResponse {
  // The element that was produced.
  CompressedElement compressed_element = 3;

  // Indicates whether the iterator has been exhausted.
  bool end_of_sequence = 2;
}
// RPCs served by a worker: task processing and element retrieval.
service WorkerService {
  // Processes a task for a dataset, making elements available to clients.
  rpc ProcessTask(ProcessTaskRequest) returns (ProcessTaskResponse);

  // Gets the next dataset element.
  rpc GetElement(GetElementRequest) returns (GetElementResponse);
}