[tf.data service] Add master and worker proto definitions.
PiperOrigin-RevId: 301596660 Change-Id: Ibe7e271e345919ab58c1be61abec81b0463f972f
This commit is contained in:
parent
17f7d2ab69
commit
fe84fb19eb
tensorflow/core/data/service
58
tensorflow/core/data/service/BUILD
Normal file
58
tensorflow/core/data/service/BUILD
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
load("@com_github_grpc_grpc//bazel:cc_grpc_library.bzl", "cc_grpc_library")
|
||||||
|
load(
|
||||||
|
"//tensorflow/core/platform:build_config.bzl",
|
||||||
|
"tf_additional_all_protos",
|
||||||
|
"tf_proto_library",
|
||||||
|
)
|
||||||
|
|
||||||
|
package(
|
||||||
|
default_visibility = [
|
||||||
|
"//tensorflow:internal",
|
||||||
|
],
|
||||||
|
licenses = ["notice"], # Apache 2.0
|
||||||
|
)
|
||||||
|
|
||||||
|
exports_files(["LICENSE"])
|
||||||
|
|
||||||
|
tf_proto_library(
|
||||||
|
name = "common_proto",
|
||||||
|
srcs = ["common.proto"],
|
||||||
|
cc_api_version = 2,
|
||||||
|
protodeps = tf_additional_all_protos(),
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_proto_library(
|
||||||
|
name = "master_proto",
|
||||||
|
srcs = ["master.proto"],
|
||||||
|
has_services = 1,
|
||||||
|
cc_api_version = 2,
|
||||||
|
protodeps = tf_additional_all_protos() + [
|
||||||
|
":common_proto",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_proto_library(
|
||||||
|
name = "worker_proto",
|
||||||
|
srcs = ["worker.proto"],
|
||||||
|
has_services = 1,
|
||||||
|
cc_api_version = 2,
|
||||||
|
protodeps = tf_additional_all_protos() + [
|
||||||
|
":common_proto",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_grpc_library(
|
||||||
|
name = "master_cc_grpc_proto",
|
||||||
|
srcs = [":master_proto"],
|
||||||
|
generate_mocks = True,
|
||||||
|
grpc_only = True,
|
||||||
|
deps = [":master_proto_cc"],
|
||||||
|
)
|
||||||
|
|
||||||
|
cc_grpc_library(
|
||||||
|
name = "worker_cc_grpc_proto",
|
||||||
|
srcs = [":worker_proto"],
|
||||||
|
generate_mocks = True,
|
||||||
|
grpc_only = True,
|
||||||
|
deps = [":worker_proto_cc"],
|
||||||
|
)
|
40
tensorflow/core/data/service/common.proto
Normal file
40
tensorflow/core/data/service/common.proto
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
package tensorflow.data;
|
||||||
|
|
||||||
|
import "tensorflow/core/framework/graph.proto";
|
||||||
|
import "tensorflow/core/framework/tensor_shape.proto";
|
||||||
|
import "tensorflow/core/framework/types.proto";
|
||||||
|
|
||||||
|
message DatasetDef {
|
||||||
|
// We represent datasets as tensorflow GraphDefs which define the operations
|
||||||
|
// needed to create a tf.data dataset.
|
||||||
|
GraphDef graph = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message ComponentMetadata {
|
||||||
|
// The dtype of the component tensor.
|
||||||
|
.tensorflow.DataType dtype = 1;
|
||||||
|
// The shape of the component tensor.
|
||||||
|
.tensorflow.TensorShapeProto tensor_shape = 2;
|
||||||
|
// Size of the uncompressed tensor bytes. For tensors serialized as
|
||||||
|
// TensorProtos, this is TensorProto::BytesAllocatedLong(). For raw Tensors,
|
||||||
|
// this is the size of the buffer underlying the Tensor.
|
||||||
|
int64 tensor_size_bytes = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
message CompressedElement {
|
||||||
|
// Compressed tensor bytes for all components of the element.
|
||||||
|
bytes data = 1;
|
||||||
|
// Metadata for the components of the element.
|
||||||
|
repeated ComponentMetadata component_metadata = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message TaskDef {
|
||||||
|
// The dataset to iterate over.
|
||||||
|
// TODO(aaudibert): load the dataset from disk instead of passing it here.
|
||||||
|
DatasetDef dataset = 1;
|
||||||
|
int64 dataset_id = 2;
|
||||||
|
int64 task_id = 3;
|
||||||
|
int64 epoch_id = 4;
|
||||||
|
}
|
73
tensorflow/core/data/service/master.proto
Normal file
73
tensorflow/core/data/service/master.proto
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
package tensorflow.data;
|
||||||
|
|
||||||
|
import "tensorflow/core/data/service/common.proto";
|
||||||
|
|
||||||
|
message RegisterWorkerRequest {
|
||||||
|
// The address of the registering worker.
|
||||||
|
string worker_address = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message RegisterWorkerResponse {
|
||||||
|
// An id for the worker.
|
||||||
|
int64 worker_id = 1;
|
||||||
|
// Tasks to begin processing.
|
||||||
|
repeated TaskDef tasks = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message GetOrRegisterDatasetRequest {
|
||||||
|
// The dataset to register.
|
||||||
|
DatasetDef dataset = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message GetOrRegisterDatasetResponse {
|
||||||
|
// The id for the registered dataset.
|
||||||
|
int64 dataset_id = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message BeginEpochRequest {
|
||||||
|
// The id of the dataset to iterate over.
|
||||||
|
int64 dataset_id = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message BeginEpochResponse {
|
||||||
|
// The id for the created epoch.
|
||||||
|
int64 epoch_id = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message GetTasksRequest {
|
||||||
|
// The epoch to look up tasks for.
|
||||||
|
int64 epoch_id = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message TaskInfo {
|
||||||
|
// The address of the worker processing the task.
|
||||||
|
string worker_address = 1;
|
||||||
|
// The task id.
|
||||||
|
int64 id = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message GetTasksResponse {
|
||||||
|
// A list of all tasks for an epoch.
|
||||||
|
repeated TaskInfo task_info = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
service MasterService {
|
||||||
|
// Registers a worker with the master.
|
||||||
|
rpc RegisterWorker(RegisterWorkerRequest) returns (RegisterWorkerResponse);
|
||||||
|
|
||||||
|
// Registers a dataset with the server, or returns its id if it is already
|
||||||
|
// registered.
|
||||||
|
//
|
||||||
|
// The dataset is constructed in a new graph, so it must not refer to
|
||||||
|
// external resources or variables.
|
||||||
|
rpc GetOrRegisterDataset(GetOrRegisterDatasetRequest)
|
||||||
|
returns (GetOrRegisterDatasetResponse);
|
||||||
|
|
||||||
|
// Begins an epoch over a dataset.
|
||||||
|
rpc BeginEpoch(BeginEpochRequest) returns (BeginEpochResponse);
|
||||||
|
|
||||||
|
// Reports a list of all tasks for an epoch.
|
||||||
|
rpc GetTasks(GetTasksRequest) returns (GetTasksResponse);
|
||||||
|
}
|
31
tensorflow/core/data/service/worker.proto
Normal file
31
tensorflow/core/data/service/worker.proto
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
syntax = "proto3";
|
||||||
|
|
||||||
|
package tensorflow.data;
|
||||||
|
|
||||||
|
import "tensorflow/core/data/service/common.proto";
|
||||||
|
|
||||||
|
message ProcessTaskRequest {
|
||||||
|
TaskDef task = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message ProcessTaskResponse {}
|
||||||
|
|
||||||
|
message GetElementRequest {
|
||||||
|
// The task to fetch an element from.
|
||||||
|
int64 task_id = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message GetElementResponse {
|
||||||
|
// The produced element.
|
||||||
|
CompressedElement compressed_element = 3;
|
||||||
|
// Boolean to indicate whether the iterator has been exhausted.
|
||||||
|
bool end_of_sequence = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
service WorkerService {
|
||||||
|
// Processes an task for a dataset, making elements available to clients.
|
||||||
|
rpc ProcessTask(ProcessTaskRequest) returns (ProcessTaskResponse);
|
||||||
|
|
||||||
|
// Gets the next dataset element.
|
||||||
|
rpc GetElement(GetElementRequest) returns (GetElementResponse);
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user