From fe84fb19eb0df509a5efd2017641ab557e70ee6e Mon Sep 17 00:00:00 2001 From: Andrew Audibert <aaudibert@google.com> Date: Wed, 18 Mar 2020 08:44:12 -0700 Subject: [PATCH] [tf.data service] Add master and worker proto definitions. PiperOrigin-RevId: 301596660 Change-Id: Ibe7e271e345919ab58c1be61abec81b0463f972f --- tensorflow/core/data/service/BUILD | 58 ++++++++++++++++++ tensorflow/core/data/service/common.proto | 40 +++++++++++++ tensorflow/core/data/service/master.proto | 73 +++++++++++++++++++++++ tensorflow/core/data/service/worker.proto | 31 ++++++++++ 4 files changed, 202 insertions(+) create mode 100644 tensorflow/core/data/service/BUILD create mode 100644 tensorflow/core/data/service/common.proto create mode 100644 tensorflow/core/data/service/master.proto create mode 100644 tensorflow/core/data/service/worker.proto diff --git a/tensorflow/core/data/service/BUILD b/tensorflow/core/data/service/BUILD new file mode 100644 index 00000000000..6003362406f --- /dev/null +++ b/tensorflow/core/data/service/BUILD @@ -0,0 +1,58 @@ +load("@com_github_grpc_grpc//bazel:cc_grpc_library.bzl", "cc_grpc_library") +load( + "//tensorflow/core/platform:build_config.bzl", + "tf_additional_all_protos", + "tf_proto_library", +) + +package( + default_visibility = [ + "//tensorflow:internal", + ], + licenses = ["notice"], # Apache 2.0 +) + +exports_files(["LICENSE"]) + +tf_proto_library( + name = "common_proto", + srcs = ["common.proto"], + cc_api_version = 2, + protodeps = tf_additional_all_protos(), +) + +tf_proto_library( + name = "master_proto", + srcs = ["master.proto"], + has_services = 1, + cc_api_version = 2, + protodeps = tf_additional_all_protos() + [ + ":common_proto", + ], +) + +tf_proto_library( + name = "worker_proto", + srcs = ["worker.proto"], + has_services = 1, + cc_api_version = 2, + protodeps = tf_additional_all_protos() + [ + ":common_proto", + ], +) + +cc_grpc_library( + name = "master_cc_grpc_proto", + srcs = [":master_proto"], + generate_mocks = True, + grpc_only = True, + deps = [":master_proto_cc"], +) + +cc_grpc_library( + name = "worker_cc_grpc_proto", + srcs = [":worker_proto"], + generate_mocks = True, + grpc_only = True, + deps = [":worker_proto_cc"], +) diff --git a/tensorflow/core/data/service/common.proto b/tensorflow/core/data/service/common.proto new file mode 100644 index 00000000000..0faaa661e08 --- /dev/null +++ b/tensorflow/core/data/service/common.proto @@ -0,0 +1,40 @@ +syntax = "proto3"; + +package tensorflow.data; + +import "tensorflow/core/framework/graph.proto"; +import "tensorflow/core/framework/tensor_shape.proto"; +import "tensorflow/core/framework/types.proto"; + +message DatasetDef { + // We represent datasets as tensorflow GraphDefs which define the operations + // needed to create a tf.data dataset. + GraphDef graph = 1; +} + +message ComponentMetadata { + // The dtype of the component tensor. + .tensorflow.DataType dtype = 1; + // The shape of the component tensor. + .tensorflow.TensorShapeProto tensor_shape = 2; + // Size of the uncompressed tensor bytes. For tensors serialized as + // TensorProtos, this is TensorProto::BytesAllocatedLong(). For raw Tensors, + // this is the size of the buffer underlying the Tensor. + int64 tensor_size_bytes = 3; +} + +message CompressedElement { + // Compressed tensor bytes for all components of the element. + bytes data = 1; + // Metadata for the components of the element. + repeated ComponentMetadata component_metadata = 2; +} + +message TaskDef { + // The dataset to iterate over. + // TODO(aaudibert): load the dataset from disk instead of passing it here. + DatasetDef dataset = 1; + int64 dataset_id = 2; + int64 task_id = 3; + int64 epoch_id = 4; +} diff --git a/tensorflow/core/data/service/master.proto b/tensorflow/core/data/service/master.proto new file mode 100644 index 00000000000..03be51c79e7 --- /dev/null +++ b/tensorflow/core/data/service/master.proto @@ -0,0 +1,73 @@ +syntax = "proto3"; + +package tensorflow.data; + +import "tensorflow/core/data/service/common.proto"; + +message RegisterWorkerRequest { + // The address of the registering worker. + string worker_address = 1; +} + +message RegisterWorkerResponse { + // An id for the worker. + int64 worker_id = 1; + // Tasks to begin processing. + repeated TaskDef tasks = 2; +} + +message GetOrRegisterDatasetRequest { + // The dataset to register. + DatasetDef dataset = 1; +} + +message GetOrRegisterDatasetResponse { + // The id for the registered dataset. + int64 dataset_id = 1; +} + +message BeginEpochRequest { + // The id of the dataset to iterate over. + int64 dataset_id = 1; +} + +message BeginEpochResponse { + // The id for the created epoch. + int64 epoch_id = 1; +} + +message GetTasksRequest { + // The epoch to look up tasks for. + int64 epoch_id = 1; +} + +message TaskInfo { + // The address of the worker processing the task. + string worker_address = 1; + // The task id. + int64 id = 2; +} + +message GetTasksResponse { + // A list of all tasks for an epoch. + repeated TaskInfo task_info = 1; +} + +service MasterService { + // Registers a worker with the master. + rpc RegisterWorker(RegisterWorkerRequest) returns (RegisterWorkerResponse); + + // Registers a dataset with the server, or returns its id if it is already + // registered. + // + // The dataset is constructed in a new graph, so it must not refer to + // external resources or variables. + rpc GetOrRegisterDataset(GetOrRegisterDatasetRequest) + returns (GetOrRegisterDatasetResponse); + + // Begins an epoch over a dataset. + rpc BeginEpoch(BeginEpochRequest) returns (BeginEpochResponse); + + // Reports a list of all tasks for an epoch. + rpc GetTasks(GetTasksRequest) returns (GetTasksResponse); +} diff --git a/tensorflow/core/data/service/worker.proto b/tensorflow/core/data/service/worker.proto new file mode 100644 index 00000000000..04b8f03474c --- /dev/null +++ b/tensorflow/core/data/service/worker.proto @@ -0,0 +1,31 @@ +syntax = "proto3"; + +package tensorflow.data; + +import "tensorflow/core/data/service/common.proto"; + +message ProcessTaskRequest { + TaskDef task = 1; +} + +message ProcessTaskResponse {} + +message GetElementRequest { + // The task to fetch an element from. + int64 task_id = 1; +} + +message GetElementResponse { + // The produced element. + CompressedElement compressed_element = 3; + // Boolean to indicate whether the iterator has been exhausted. + bool end_of_sequence = 2; +} + +service WorkerService { + // Processes an task for a dataset, making elements available to clients. + rpc ProcessTask(ProcessTaskRequest) returns (ProcessTaskResponse); + + // Gets the next dataset element. + rpc GetElement(GetElementRequest) returns (GetElementResponse); +}