[tf.data service] Add master and worker proto definitions.
PiperOrigin-RevId: 301596660 Change-Id: Ibe7e271e345919ab58c1be61abec81b0463f972f
This commit is contained in:
parent
17f7d2ab69
commit
fe84fb19eb
tensorflow/core/data/service
58
tensorflow/core/data/service/BUILD
Normal file
58
tensorflow/core/data/service/BUILD
Normal file
@ -0,0 +1,58 @@
|
||||
load("@com_github_grpc_grpc//bazel:cc_grpc_library.bzl", "cc_grpc_library")
|
||||
load(
|
||||
"//tensorflow/core/platform:build_config.bzl",
|
||||
"tf_additional_all_protos",
|
||||
"tf_proto_library",
|
||||
)
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
|
||||
tf_proto_library(
|
||||
name = "common_proto",
|
||||
srcs = ["common.proto"],
|
||||
cc_api_version = 2,
|
||||
protodeps = tf_additional_all_protos(),
|
||||
)
|
||||
|
||||
tf_proto_library(
|
||||
name = "master_proto",
|
||||
srcs = ["master.proto"],
|
||||
has_services = 1,
|
||||
cc_api_version = 2,
|
||||
protodeps = tf_additional_all_protos() + [
|
||||
":common_proto",
|
||||
],
|
||||
)
|
||||
|
||||
tf_proto_library(
|
||||
name = "worker_proto",
|
||||
srcs = ["worker.proto"],
|
||||
has_services = 1,
|
||||
cc_api_version = 2,
|
||||
protodeps = tf_additional_all_protos() + [
|
||||
":common_proto",
|
||||
],
|
||||
)
|
||||
|
||||
cc_grpc_library(
|
||||
name = "master_cc_grpc_proto",
|
||||
srcs = [":master_proto"],
|
||||
generate_mocks = True,
|
||||
grpc_only = True,
|
||||
deps = [":master_proto_cc"],
|
||||
)
|
||||
|
||||
cc_grpc_library(
|
||||
name = "worker_cc_grpc_proto",
|
||||
srcs = [":worker_proto"],
|
||||
generate_mocks = True,
|
||||
grpc_only = True,
|
||||
deps = [":worker_proto_cc"],
|
||||
)
|
40
tensorflow/core/data/service/common.proto
Normal file
40
tensorflow/core/data/service/common.proto
Normal file
@ -0,0 +1,40 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package tensorflow.data;
|
||||
|
||||
import "tensorflow/core/framework/graph.proto";
|
||||
import "tensorflow/core/framework/tensor_shape.proto";
|
||||
import "tensorflow/core/framework/types.proto";
|
||||
|
||||
message DatasetDef {
|
||||
// We represent datasets as tensorflow GraphDefs which define the operations
|
||||
// needed to create a tf.data dataset.
|
||||
GraphDef graph = 1;
|
||||
}
|
||||
|
||||
message ComponentMetadata {
|
||||
// The dtype of the component tensor.
|
||||
.tensorflow.DataType dtype = 1;
|
||||
// The shape of the component tensor.
|
||||
.tensorflow.TensorShapeProto tensor_shape = 2;
|
||||
// Size of the uncompressed tensor bytes. For tensors serialized as
|
||||
// TensorProtos, this is TensorProto::BytesAllocatedLong(). For raw Tensors,
|
||||
// this is the size of the buffer underlying the Tensor.
|
||||
int64 tensor_size_bytes = 3;
|
||||
}
|
||||
|
||||
message CompressedElement {
|
||||
// Compressed tensor bytes for all components of the element.
|
||||
bytes data = 1;
|
||||
// Metadata for the components of the element.
|
||||
repeated ComponentMetadata component_metadata = 2;
|
||||
}
|
||||
|
||||
message TaskDef {
|
||||
// The dataset to iterate over.
|
||||
// TODO(aaudibert): load the dataset from disk instead of passing it here.
|
||||
DatasetDef dataset = 1;
|
||||
int64 dataset_id = 2;
|
||||
int64 task_id = 3;
|
||||
int64 epoch_id = 4;
|
||||
}
|
73
tensorflow/core/data/service/master.proto
Normal file
73
tensorflow/core/data/service/master.proto
Normal file
@ -0,0 +1,73 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package tensorflow.data;
|
||||
|
||||
import "tensorflow/core/data/service/common.proto";
|
||||
|
||||
message RegisterWorkerRequest {
|
||||
// The address of the registering worker.
|
||||
string worker_address = 1;
|
||||
}
|
||||
|
||||
message RegisterWorkerResponse {
|
||||
// An id for the worker.
|
||||
int64 worker_id = 1;
|
||||
// Tasks to begin processing.
|
||||
repeated TaskDef tasks = 2;
|
||||
}
|
||||
|
||||
message GetOrRegisterDatasetRequest {
|
||||
// The dataset to register.
|
||||
DatasetDef dataset = 1;
|
||||
}
|
||||
|
||||
message GetOrRegisterDatasetResponse {
|
||||
// The id for the registered dataset.
|
||||
int64 dataset_id = 1;
|
||||
}
|
||||
|
||||
message BeginEpochRequest {
|
||||
// The id of the dataset to iterate over.
|
||||
int64 dataset_id = 1;
|
||||
}
|
||||
|
||||
message BeginEpochResponse {
|
||||
// The id for the created epoch.
|
||||
int64 epoch_id = 1;
|
||||
}
|
||||
|
||||
message GetTasksRequest {
|
||||
// The epoch to look up tasks for.
|
||||
int64 epoch_id = 1;
|
||||
}
|
||||
|
||||
message TaskInfo {
|
||||
// The address of the worker processing the task.
|
||||
string worker_address = 1;
|
||||
// The task id.
|
||||
int64 id = 2;
|
||||
}
|
||||
|
||||
message GetTasksResponse {
|
||||
// A list of all tasks for an epoch.
|
||||
repeated TaskInfo task_info = 1;
|
||||
}
|
||||
|
||||
service MasterService {
|
||||
// Registers a worker with the master.
|
||||
rpc RegisterWorker(RegisterWorkerRequest) returns (RegisterWorkerResponse);
|
||||
|
||||
// Registers a dataset with the server, or returns its id if it is already
|
||||
// registered.
|
||||
//
|
||||
// The dataset is constructed in a new graph, so it must not refer to
|
||||
// external resources or variables.
|
||||
rpc GetOrRegisterDataset(GetOrRegisterDatasetRequest)
|
||||
returns (GetOrRegisterDatasetResponse);
|
||||
|
||||
// Begins an epoch over a dataset.
|
||||
rpc BeginEpoch(BeginEpochRequest) returns (BeginEpochResponse);
|
||||
|
||||
// Reports a list of all tasks for an epoch.
|
||||
rpc GetTasks(GetTasksRequest) returns (GetTasksResponse);
|
||||
}
|
31
tensorflow/core/data/service/worker.proto
Normal file
31
tensorflow/core/data/service/worker.proto
Normal file
@ -0,0 +1,31 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package tensorflow.data;
|
||||
|
||||
import "tensorflow/core/data/service/common.proto";
|
||||
|
||||
message ProcessTaskRequest {
|
||||
TaskDef task = 1;
|
||||
}
|
||||
|
||||
message ProcessTaskResponse {}
|
||||
|
||||
message GetElementRequest {
|
||||
// The task to fetch an element from.
|
||||
int64 task_id = 1;
|
||||
}
|
||||
|
||||
message GetElementResponse {
|
||||
// The produced element.
|
||||
CompressedElement compressed_element = 3;
|
||||
// Boolean to indicate whether the iterator has been exhausted.
|
||||
bool end_of_sequence = 2;
|
||||
}
|
||||
|
||||
service WorkerService {
|
||||
// Processes an task for a dataset, making elements available to clients.
|
||||
rpc ProcessTask(ProcessTaskRequest) returns (ProcessTaskResponse);
|
||||
|
||||
// Gets the next dataset element.
|
||||
rpc GetElement(GetElementRequest) returns (GetElementResponse);
|
||||
}
|
Loading…
Reference in New Issue
Block a user