From fe84fb19eb0df509a5efd2017641ab557e70ee6e Mon Sep 17 00:00:00 2001
From: Andrew Audibert <aaudibert@google.com>
Date: Wed, 18 Mar 2020 08:44:12 -0700
Subject: [PATCH] [tf.data service] Add master and worker proto definitions.

PiperOrigin-RevId: 301596660
Change-Id: Ibe7e271e345919ab58c1be61abec81b0463f972f
---
 tensorflow/core/data/service/BUILD        | 58 ++++++++++++++++++
 tensorflow/core/data/service/common.proto | 40 +++++++++++++
 tensorflow/core/data/service/master.proto | 73 +++++++++++++++++++++++
 tensorflow/core/data/service/worker.proto | 31 ++++++++++
 4 files changed, 202 insertions(+)
 create mode 100644 tensorflow/core/data/service/BUILD
 create mode 100644 tensorflow/core/data/service/common.proto
 create mode 100644 tensorflow/core/data/service/master.proto
 create mode 100644 tensorflow/core/data/service/worker.proto

diff --git a/tensorflow/core/data/service/BUILD b/tensorflow/core/data/service/BUILD
new file mode 100644
index 00000000000..6003362406f
--- /dev/null
+++ b/tensorflow/core/data/service/BUILD
@@ -0,0 +1,58 @@
+load("@com_github_grpc_grpc//bazel:cc_grpc_library.bzl", "cc_grpc_library")
+load(
+    "//tensorflow/core/platform:build_config.bzl",
+    "tf_additional_all_protos",
+    "tf_proto_library",
+)
+
+package(
+    default_visibility = [
+        "//tensorflow:internal",
+    ],
+    licenses = ["notice"],  # Apache 2.0
+)
+
+exports_files(["LICENSE"])
+
+tf_proto_library(
+    name = "common_proto",
+    srcs = ["common.proto"],
+    cc_api_version = 2,
+    protodeps = tf_additional_all_protos(),
+)
+
+tf_proto_library(
+    name = "master_proto",
+    srcs = ["master.proto"],
+    has_services = 1,
+    cc_api_version = 2,
+    protodeps = tf_additional_all_protos() + [
+        ":common_proto",
+    ],
+)
+
+tf_proto_library(
+    name = "worker_proto",
+    srcs = ["worker.proto"],
+    has_services = 1,
+    cc_api_version = 2,
+    protodeps = tf_additional_all_protos() + [
+        ":common_proto",
+    ],
+)
+
+cc_grpc_library(
+    name = "master_cc_grpc_proto",
+    srcs = [":master_proto"],
+    generate_mocks = True,
+    grpc_only = True,
+    deps = [":master_proto_cc"],
+)
+
+cc_grpc_library(
+    name = "worker_cc_grpc_proto",
+    srcs = [":worker_proto"],
+    generate_mocks = True,
+    grpc_only = True,
+    deps = [":worker_proto_cc"],
+)
diff --git a/tensorflow/core/data/service/common.proto b/tensorflow/core/data/service/common.proto
new file mode 100644
index 00000000000..0faaa661e08
--- /dev/null
+++ b/tensorflow/core/data/service/common.proto
@@ -0,0 +1,40 @@
+syntax = "proto3";
+
+package tensorflow.data;
+
+import "tensorflow/core/framework/graph.proto";
+import "tensorflow/core/framework/tensor_shape.proto";
+import "tensorflow/core/framework/types.proto";
+
+message DatasetDef {
+  // We represent datasets as tensorflow GraphDefs which define the operations
+  // needed to create a tf.data dataset.
+  GraphDef graph = 1;
+}
+
+message ComponentMetadata {
+  // The dtype of the component tensor.
+  .tensorflow.DataType dtype = 1;
+  // The shape of the component tensor.
+  .tensorflow.TensorShapeProto tensor_shape = 2;
+  // Size of the uncompressed tensor bytes. For tensors serialized as
+  // TensorProtos, this is TensorProto::BytesAllocatedLong(). For raw Tensors,
+  // this is the size of the buffer underlying the Tensor.
+  int64 tensor_size_bytes = 3;
+}
+
+message CompressedElement {
+  // Compressed tensor bytes for all components of the element.
+  bytes data = 1;
+  // Metadata for the components of the element.
+  repeated ComponentMetadata component_metadata = 2;
+}
+
+message TaskDef {
+  // The dataset to iterate over.
+  // TODO(aaudibert): load the dataset from disk instead of passing it here.
+  DatasetDef dataset = 1;
+  int64 dataset_id = 2;
+  int64 task_id = 3;
+  int64 epoch_id = 4;
+}
diff --git a/tensorflow/core/data/service/master.proto b/tensorflow/core/data/service/master.proto
new file mode 100644
index 00000000000..03be51c79e7
--- /dev/null
+++ b/tensorflow/core/data/service/master.proto
@@ -0,0 +1,73 @@
+syntax = "proto3";
+
+package tensorflow.data;
+
+import "tensorflow/core/data/service/common.proto";
+
+message RegisterWorkerRequest {
+  // The address of the registering worker.
+  string worker_address = 1;
+}
+
+message RegisterWorkerResponse {
+  // An id for the worker.
+  int64 worker_id = 1;
+  // Tasks to begin processing.
+  repeated TaskDef tasks = 2;
+}
+
+message GetOrRegisterDatasetRequest {
+  // The dataset to register.
+  DatasetDef dataset = 1;
+}
+
+message GetOrRegisterDatasetResponse {
+  // The id for the registered dataset.
+  int64 dataset_id = 1;
+}
+
+message BeginEpochRequest {
+  // The id of the dataset to iterate over.
+  int64 dataset_id = 1;
+}
+
+message BeginEpochResponse {
+  // The id for the created epoch.
+  int64 epoch_id = 1;
+}
+
+message GetTasksRequest {
+  // The epoch to look up tasks for.
+  int64 epoch_id = 1;
+}
+
+message TaskInfo {
+  // The address of the worker processing the task.
+  string worker_address = 1;
+  // The task id.
+  int64 id = 2;
+}
+
+message GetTasksResponse {
+  // A list of all tasks for an epoch.
+  repeated TaskInfo task_info = 1;
+}
+
+service MasterService {
+  // Registers a worker with the master.
+  rpc RegisterWorker(RegisterWorkerRequest) returns (RegisterWorkerResponse);
+
+  // Registers a dataset with the server, or returns its id if it is already
+  // registered.
+  //
+  // The dataset is constructed in a new graph, so it must not refer to
+  // external resources or variables.
+  rpc GetOrRegisterDataset(GetOrRegisterDatasetRequest)
+      returns (GetOrRegisterDatasetResponse);
+
+  // Begins an epoch over a dataset.
+  rpc BeginEpoch(BeginEpochRequest) returns (BeginEpochResponse);
+
+  // Reports a list of all tasks for an epoch.
+  rpc GetTasks(GetTasksRequest) returns (GetTasksResponse);
+}
diff --git a/tensorflow/core/data/service/worker.proto b/tensorflow/core/data/service/worker.proto
new file mode 100644
index 00000000000..04b8f03474c
--- /dev/null
+++ b/tensorflow/core/data/service/worker.proto
@@ -0,0 +1,31 @@
+syntax = "proto3";
+
+package tensorflow.data;
+
+import "tensorflow/core/data/service/common.proto";
+
+message ProcessTaskRequest {
+  TaskDef task = 1;
+}
+
+message ProcessTaskResponse {}
+
+message GetElementRequest {
+  // The task to fetch an element from.
+  int64 task_id = 1;
+}
+
+message GetElementResponse {
+  // The produced element.
+  CompressedElement compressed_element = 3;
+  // Boolean to indicate whether the iterator has been exhausted.
+  bool end_of_sequence = 2;
+}
+
+service WorkerService {
+  // Processes an task for a dataset, making elements available to clients.
+  rpc ProcessTask(ProcessTaskRequest) returns (ProcessTaskResponse);
+
+  // Gets the next dataset element.
+  rpc GetElement(GetElementRequest) returns (GetElementResponse);
+}