[tf.data service] Export server_lib in public API.

PiperOrigin-RevId: 313669209
Change-Id: Idaf84dc8360a03699e12c7d83511baa4c21d240b
This commit is contained in:
Andrew Audibert 2020-05-28 15:17:20 -07:00 committed by TensorFlower Gardener
parent 7d605fb0e2
commit 4e4a4496b7
14 changed files with 143 additions and 76 deletions

View File

@ -5296,6 +5296,7 @@ py_library(
":variable_scope",
":variables",
"//tensorflow/core:protos_all_py",
"//tensorflow/python/data/experimental/service:server_lib",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/distribute:distribute_coordinator_context",
"//tensorflow/python/distribute:distribute_lib",

View File

@ -1,15 +1,54 @@
load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "tf_py_test")
package(
default_visibility = ["//tensorflow:internal"],
licenses = ["notice"], # Apache 2.0
)
exports_files(["LICENSE"])
tf_python_pybind_extension(
name = "_pywrap_server_lib",
srcs = ["server_lib_wrapper.cc"],
module_name = "_pywrap_server_lib",
deps = [
"//tensorflow/core:lib",
"//tensorflow/core/data/service:server_lib_headers_lib",
"//tensorflow/python:pybind11_lib",
"//tensorflow/python:pybind11_status",
"//third_party/python_runtime:headers",
"@com_github_grpc_grpc//:grpc++_public_hdrs",
"@pybind11",
],
)
py_library(
name = "server_lib",
srcs = ["server_lib.py"],
visibility = [
"//visibility:public",
],
deps = [
":_pywrap_server_lib",
],
)
tf_py_test(
name = "server_lib_test",
srcs = ["server_lib_test.py"],
deps = [
":server_lib",
"//tensorflow/python:platform_test",
],
)
py_library(
name = "service",
srcs = ["__init__.py"],
srcs_version = "PY2AND3",
deps = [
":server_lib",
"//tensorflow/python/data/experimental/ops:data_service_ops",
],
)

View File

@ -19,3 +19,5 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.data.experimental.ops.data_service_ops import distribute
from tensorflow.python.data.experimental.service.server_lib import MasterServer
from tensorflow.python.data.experimental.service.server_lib import WorkerServer

View File

@ -20,9 +20,11 @@ from __future__ import print_function
# pylint: disable=invalid-import-order,g-bad-import-order, unused-import
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.data.service import _pywrap_server_lib
from tensorflow.python.data.experimental.service import _pywrap_server_lib
from tensorflow.python.util.tf_export import tf_export
@tf_export("data.experimental.service.MasterServer", v1=[])
class MasterServer(object):
"""An in-process tf.data service master server.
@ -30,21 +32,22 @@ class MasterServer(object):
`tf.data.experimental.service.WorkerServer`s. When the workers start, they
register themselves with the master.
```
master_server = tf.data.experimental.service.MasterServer(port=5050)
worker_server = tf.data.experimental.service.WorkerServer(
port=0, master_address="localhost:5050")
dataset = tf.data.Dataset.range(10)
dataset = dataset.apply(tf.data.experimental.service.distribute(
processing_mode="parallel_epochs", service="grpc://localhost:5050"))
```
>>> master = tf.data.experimental.service.MasterServer(port=0)
>>> master_address = master.target.split("://")[1]
>>> worker = tf.data.experimental.service.WorkerServer(
... port=0, master_address=master_address)
>>> dataset = tf.data.Dataset.range(10)
>>> dataset = dataset.apply(tf.data.experimental.service.distribute(
... processing_mode="parallel_epochs", service=master.target))
>>> print(list(dataset.as_numpy_iterator()))
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
When starting a dedicated tf.data master process, use join() to block
indefinitely after starting up the server.
```
master_server = tf.data.experimental.service.MasterServer(port=5050)
master_server.join()
master = tf.data.experimental.service.MasterServer(port=5050)
master.join()
```
"""
@ -72,6 +75,9 @@ class MasterServer(object):
def start(self):
"""Starts this server.
>>> master = tf.data.experimental.service.MasterServer(port=0, start=False)
>>> master.start()
Raises:
tf.errors.OpError: Or one of its subclasses if an error occurs while
starting the server.
@ -84,8 +90,8 @@ class MasterServer(object):
This is useful when starting a dedicated master process.
```
master_server = tf.data.experimental.service.MasterServer(port=5050)
master_server.join()
master = tf.data.experimental.service.MasterServer(port=5050)
master.join()
```
Raises:
@ -94,6 +100,21 @@ class MasterServer(object):
"""
self._server.join()
@property
def target(self):
"""Returns a target that can be used to connect to the server.
>>> master = tf.data.experimental.service.MasterServer(port=0)
>>> dataset = tf.data.Dataset.range(10)
>>> dataset = dataset.apply(tf.data.experimental.service.distribute(
... processing_mode="parallel_epochs", service=master.target))
The returned string will be in the form protocol://address, e.g.
"grpc://localhost:5050".
"""
return "{0}://localhost:{1}".format(self._protocol,
self._server.bound_port())
def _stop(self):
"""Stops the server.
@ -119,6 +140,7 @@ class MasterServer(object):
return self._server.num_workers()
@tf_export("data.experimental.service.WorkerServer", v1=[])
class WorkerServer(object):
"""An in-process tf.data service worker server.
@ -127,22 +149,23 @@ class WorkerServer(object):
RPC. A worker is associated with a single
`tf.data.experimental.service.MasterServer`.
```
master_server = tf.data.experimental.service.MasterServer(port=5050)
worker_server = tf.data.experimental.service.WorkerServer(
port=0, master_address="localhost:5050")
dataset = tf.data.Dataset.range(10)
dataset = dataset.apply(tf.data.experimental.service.distribute(
processing_mode="parallel_epochs", service="grpc://localhost:5050"))
```
>>> master = tf.data.experimental.service.MasterServer(port=0)
>>> master_address = master.target.split("://")[1]
>>> worker = tf.data.experimental.service.WorkerServer(
... port=0, master_address=master_address)
>>> dataset = tf.data.Dataset.range(10)
>>> dataset = dataset.apply(tf.data.experimental.service.distribute(
... processing_mode="parallel_epochs", service=master.target))
>>> print(list(dataset.as_numpy_iterator()))
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
When starting a dedicated tf.data worker process, use join() to block
indefinitely after starting up the server.
```
worker_server = tf.data.experimental.service.WorkerServer(
port=5050, master_address="grpc://localhost:5050")
worker_server.join()
worker = tf.data.experimental.service.WorkerServer(
port=5051, master_address="grpc://localhost:5050")
worker.join()
```
"""
@ -198,7 +221,7 @@ class WorkerServer(object):
```
worker_server = tf.data.experimental.service.WorkerServer(
port=5050, master_address="grpc://localhost:5050")
port=5051, master_address="grpc://localhost:5050")
worker_server.join()
```

View File

@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.data.service import server_lib
from tensorflow.python.data.experimental.service import server_lib
from tensorflow.python.platform import test

View File

@ -1,7 +1,7 @@
# Tests of TensorFlow kernels written using the Python API.
load("//tensorflow:tensorflow.bzl", "tf_py_test")
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
load("//tensorflow:tensorflow.bzl", "tf_py_test") # buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "cuda_py_test") # buildifier: disable=same-origin-load
package(
default_visibility = ["//tensorflow:internal"],
@ -92,8 +92,8 @@ tf_py_test(
"//tensorflow/python:platform_test",
"//tensorflow/python/data",
"//tensorflow/python/data/experimental/ops:testing",
"//tensorflow/python/data/experimental/service:server_lib",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/service:server_lib",
],
)

View File

@ -23,9 +23,9 @@ from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import data_service_ops
from tensorflow.python.data.experimental.ops import distribute_options
from tensorflow.python.data.experimental.service import server_lib
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.service import server_lib
from tensorflow.python.eager import def_function
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes

View File

@ -1,44 +0,0 @@
load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "tf_py_test")
package(
default_visibility = ["//tensorflow:internal"],
licenses = ["notice"], # Apache 2.0
)
tf_python_pybind_extension(
name = "_pywrap_server_lib",
srcs = ["server_lib_wrapper.cc"],
module_name = "_pywrap_server_lib",
deps = [
"//tensorflow/core:lib",
"//tensorflow/core/data/service:server_lib_headers_lib",
"//tensorflow/python:pybind11_lib",
"//tensorflow/python:pybind11_status",
"//third_party/python_runtime:headers",
"@com_github_grpc_grpc//:grpc++_public_hdrs",
"@pybind11",
],
)
py_library(
name = "server_lib",
srcs = ["server_lib.py"],
visibility = [
"//visibility:public",
],
deps = [
":_pywrap_server_lib",
],
)
tf_py_test(
name = "server_lib_test",
srcs = ["server_lib_test.py"],
deps = [
":server_lib",
"//tensorflow/python:platform_test",
],
)

View File

@ -0,0 +1,21 @@
path: "tensorflow.data.experimental.service.MasterServer"
tf_class {
is_instance: "<class \'tensorflow.python.data.experimental.service.server_lib.MasterServer\'>"
is_instance: "<type \'object\'>"
member {
name: "target"
mtype: "<type \'property\'>"
}
member_method {
name: "__init__"
argspec: "args=[\'self\', \'port\', \'protocol\', \'start\'], varargs=None, keywords=None, defaults=[\'None\', \'True\'], "
}
member_method {
name: "join"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "start"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -0,0 +1,17 @@
path: "tensorflow.data.experimental.service.WorkerServer"
tf_class {
is_instance: "<class \'tensorflow.python.data.experimental.service.server_lib.WorkerServer\'>"
is_instance: "<type \'object\'>"
member_method {
name: "__init__"
argspec: "args=[\'self\', \'port\', \'master_address\', \'worker_address\', \'protocol\', \'start\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\'], "
}
member_method {
name: "join"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
member_method {
name: "start"
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
}
}

View File

@ -1,5 +1,13 @@
path: "tensorflow.data.experimental.service"
tf_module {
member {
name: "MasterServer"
mtype: "<type \'type\'>"
}
member {
name: "WorkerServer"
mtype: "<type \'type\'>"
}
member_method {
name: "distribute"
argspec: "args=[\'processing_mode\', \'service\', \'job_name\', \'max_outstanding_requests\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "

View File

@ -108,7 +108,6 @@ COMMON_PIP_DEPS = [
"//tensorflow/python/autograph/pyct/common_transformers:common_transformers",
"//tensorflow/python/compiler:compiler",
"//tensorflow/python:cond_v2",
"//tensorflow/python/data/service:server_lib",
"//tensorflow/python:distributed_framework_test_lib",
"//tensorflow/python/distribute:distribute_test_lib_pip",
"//tensorflow/python:loss_scale",
@ -121,6 +120,7 @@ COMMON_PIP_DEPS = [
"//tensorflow/python/data/experimental/kernel_tests:reader_dataset_ops_test_base",
"//tensorflow/python/data/experimental/kernel_tests:stats_dataset_test_base",
"//tensorflow/python/data/experimental/ops:testing",
"//tensorflow/python/data/experimental/service:server_lib",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/debug:debug_pip",
"//tensorflow/python/distribute:combinations",