[tf.data service] Export server_lib in public API.
PiperOrigin-RevId: 313669209 Change-Id: Idaf84dc8360a03699e12c7d83511baa4c21d240b
This commit is contained in:
parent
7d605fb0e2
commit
4e4a4496b7
@ -5296,6 +5296,7 @@ py_library(
|
|||||||
":variable_scope",
|
":variable_scope",
|
||||||
":variables",
|
":variables",
|
||||||
"//tensorflow/core:protos_all_py",
|
"//tensorflow/core:protos_all_py",
|
||||||
|
"//tensorflow/python/data/experimental/service:server_lib",
|
||||||
"//tensorflow/python/data/ops:dataset_ops",
|
"//tensorflow/python/data/ops:dataset_ops",
|
||||||
"//tensorflow/python/distribute:distribute_coordinator_context",
|
"//tensorflow/python/distribute:distribute_coordinator_context",
|
||||||
"//tensorflow/python/distribute:distribute_lib",
|
"//tensorflow/python/distribute:distribute_lib",
|
||||||
|
@ -1,15 +1,54 @@
|
|||||||
|
load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
|
||||||
|
|
||||||
|
# buildifier: disable=same-origin-load
|
||||||
|
load("//tensorflow:tensorflow.bzl", "tf_py_test")
|
||||||
|
|
||||||
package(
|
package(
|
||||||
default_visibility = ["//tensorflow:internal"],
|
default_visibility = ["//tensorflow:internal"],
|
||||||
licenses = ["notice"], # Apache 2.0
|
licenses = ["notice"], # Apache 2.0
|
||||||
)
|
)
|
||||||
|
|
||||||
exports_files(["LICENSE"])
|
tf_python_pybind_extension(
|
||||||
|
name = "_pywrap_server_lib",
|
||||||
|
srcs = ["server_lib_wrapper.cc"],
|
||||||
|
module_name = "_pywrap_server_lib",
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
"//tensorflow/core/data/service:server_lib_headers_lib",
|
||||||
|
"//tensorflow/python:pybind11_lib",
|
||||||
|
"//tensorflow/python:pybind11_status",
|
||||||
|
"//third_party/python_runtime:headers",
|
||||||
|
"@com_github_grpc_grpc//:grpc++_public_hdrs",
|
||||||
|
"@pybind11",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name = "server_lib",
|
||||||
|
srcs = ["server_lib.py"],
|
||||||
|
visibility = [
|
||||||
|
"//visibility:public",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":_pywrap_server_lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tf_py_test(
|
||||||
|
name = "server_lib_test",
|
||||||
|
srcs = ["server_lib_test.py"],
|
||||||
|
deps = [
|
||||||
|
":server_lib",
|
||||||
|
"//tensorflow/python:platform_test",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name = "service",
|
name = "service",
|
||||||
srcs = ["__init__.py"],
|
srcs = ["__init__.py"],
|
||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
deps = [
|
deps = [
|
||||||
|
":server_lib",
|
||||||
"//tensorflow/python/data/experimental/ops:data_service_ops",
|
"//tensorflow/python/data/experimental/ops:data_service_ops",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -19,3 +19,5 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.python.data.experimental.ops.data_service_ops import distribute
|
from tensorflow.python.data.experimental.ops.data_service_ops import distribute
|
||||||
|
from tensorflow.python.data.experimental.service.server_lib import MasterServer
|
||||||
|
from tensorflow.python.data.experimental.service.server_lib import WorkerServer
|
||||||
|
@ -20,9 +20,11 @@ from __future__ import print_function
|
|||||||
|
|
||||||
# pylint: disable=invalid-import-order,g-bad-import-order, unused-import
|
# pylint: disable=invalid-import-order,g-bad-import-order, unused-import
|
||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import pywrap_tensorflow
|
||||||
from tensorflow.python.data.service import _pywrap_server_lib
|
from tensorflow.python.data.experimental.service import _pywrap_server_lib
|
||||||
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export("data.experimental.service.MasterServer", v1=[])
|
||||||
class MasterServer(object):
|
class MasterServer(object):
|
||||||
"""An in-process tf.data service master server.
|
"""An in-process tf.data service master server.
|
||||||
|
|
||||||
@ -30,21 +32,22 @@ class MasterServer(object):
|
|||||||
`tf.data.experimental.service.WorkerServer`s. When the workers start, they
|
`tf.data.experimental.service.WorkerServer`s. When the workers start, they
|
||||||
register themselves with the master.
|
register themselves with the master.
|
||||||
|
|
||||||
```
|
>>> master = tf.data.experimental.service.MasterServer(port=0)
|
||||||
master_server = tf.data.experimental.service.MasterServer(port=5050)
|
>>> master_address = master.target.split("://")[1]
|
||||||
worker_server = tf.data.experimental.service.WorkerServer(
|
>>> worker = tf.data.experimental.service.WorkerServer(
|
||||||
port=0, master_address="localhost:5050")
|
... port=0, master_address=master_address)
|
||||||
dataset = tf.data.Dataset.range(10)
|
>>> dataset = tf.data.Dataset.range(10)
|
||||||
dataset = dataset.apply(tf.data.experimental.service.distribute(
|
>>> dataset = dataset.apply(tf.data.experimental.service.distribute(
|
||||||
processing_mode="parallel_epochs", service="grpc://localhost:5050"))
|
... processing_mode="parallel_epochs", service=master.target))
|
||||||
```
|
>>> print(list(dataset.as_numpy_iterator()))
|
||||||
|
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||||
|
|
||||||
When starting a dedicated tf.data master process, use join() to block
|
When starting a dedicated tf.data master process, use join() to block
|
||||||
indefinitely after starting up the server.
|
indefinitely after starting up the server.
|
||||||
|
|
||||||
```
|
```
|
||||||
master_server = tf.data.experimental.service.MasterServer(port=5050)
|
master = tf.data.experimental.service.MasterServer(port=5050)
|
||||||
master_server.join()
|
master.join()
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -72,6 +75,9 @@ class MasterServer(object):
|
|||||||
def start(self):
|
def start(self):
|
||||||
"""Starts this server.
|
"""Starts this server.
|
||||||
|
|
||||||
|
>>> master = tf.data.experimental.service.MasterServer(port=0, start=False)
|
||||||
|
>>> master.start()
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
tf.errors.OpError: Or one of its subclasses if an error occurs while
|
tf.errors.OpError: Or one of its subclasses if an error occurs while
|
||||||
starting the server.
|
starting the server.
|
||||||
@ -84,8 +90,8 @@ class MasterServer(object):
|
|||||||
This is useful when starting a dedicated master process.
|
This is useful when starting a dedicated master process.
|
||||||
|
|
||||||
```
|
```
|
||||||
master_server = tf.data.experimental.service.MasterServer(port=5050)
|
master = tf.data.experimental.service.MasterServer(port=5050)
|
||||||
master_server.join()
|
master.join()
|
||||||
```
|
```
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
@ -94,6 +100,21 @@ class MasterServer(object):
|
|||||||
"""
|
"""
|
||||||
self._server.join()
|
self._server.join()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def target(self):
|
||||||
|
"""Returns a target that can be used to connect to the server.
|
||||||
|
|
||||||
|
>>> master = tf.data.experimental.service.MasterServer(port=0)
|
||||||
|
>>> dataset = tf.data.Dataset.range(10)
|
||||||
|
>>> dataset = dataset.apply(tf.data.experimental.service.distribute(
|
||||||
|
... processing_mode="parallel_epochs", service=master.target))
|
||||||
|
|
||||||
|
The returned string will be in the form protocol://address, e.g.
|
||||||
|
"grpc://localhost:5050".
|
||||||
|
"""
|
||||||
|
return "{0}://localhost:{1}".format(self._protocol,
|
||||||
|
self._server.bound_port())
|
||||||
|
|
||||||
def _stop(self):
|
def _stop(self):
|
||||||
"""Stops the server.
|
"""Stops the server.
|
||||||
|
|
||||||
@ -119,6 +140,7 @@ class MasterServer(object):
|
|||||||
return self._server.num_workers()
|
return self._server.num_workers()
|
||||||
|
|
||||||
|
|
||||||
|
@tf_export("data.experimental.service.WorkerServer", v1=[])
|
||||||
class WorkerServer(object):
|
class WorkerServer(object):
|
||||||
"""An in-process tf.data service worker server.
|
"""An in-process tf.data service worker server.
|
||||||
|
|
||||||
@ -127,22 +149,23 @@ class WorkerServer(object):
|
|||||||
RPC. A worker is associated with a single
|
RPC. A worker is associated with a single
|
||||||
`tf.data.experimental.service.MasterServer`.
|
`tf.data.experimental.service.MasterServer`.
|
||||||
|
|
||||||
```
|
>>> master = tf.data.experimental.service.MasterServer(port=0)
|
||||||
master_server = tf.data.experimental.service.MasterServer(port=5050)
|
>>> master_address = master.target.split("://")[1]
|
||||||
worker_server = tf.data.experimental.service.WorkerServer(
|
>>> worker = tf.data.experimental.service.WorkerServer(
|
||||||
port=0, master_address="localhost:5050")
|
... port=0, master_address=master_address)
|
||||||
dataset = tf.data.Dataset.range(10)
|
>>> dataset = tf.data.Dataset.range(10)
|
||||||
dataset = dataset.apply(tf.data.experimental.service.distribute(
|
>>> dataset = dataset.apply(tf.data.experimental.service.distribute(
|
||||||
processing_mode="parallel_epochs", service="grpc://localhost:5050"))
|
... processing_mode="parallel_epochs", service=master.target))
|
||||||
```
|
>>> print(list(dataset.as_numpy_iterator()))
|
||||||
|
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
|
||||||
|
|
||||||
When starting a dedicated tf.data worker process, use join() to block
|
When starting a dedicated tf.data worker process, use join() to block
|
||||||
indefinitely after starting up the server.
|
indefinitely after starting up the server.
|
||||||
|
|
||||||
```
|
```
|
||||||
worker_server = tf.data.experimental.service.WorkerServer(
|
worker = tf.data.experimental.service.WorkerServer(
|
||||||
port=5050, master_address="grpc://localhost:5050")
|
port=5051, master_address="grpc://localhost:5050")
|
||||||
worker_server.join()
|
worker.join()
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -198,7 +221,7 @@ class WorkerServer(object):
|
|||||||
|
|
||||||
```
|
```
|
||||||
worker_server = tf.data.experimental.service.WorkerServer(
|
worker_server = tf.data.experimental.service.WorkerServer(
|
||||||
port=5050, master_address="grpc://localhost:5050")
|
port=5051, master_address="grpc://localhost:5050")
|
||||||
worker_server.join()
|
worker_server.join()
|
||||||
```
|
```
|
||||||
|
|
@ -18,7 +18,7 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.python.data.service import server_lib
|
from tensorflow.python.data.experimental.service import server_lib
|
||||||
|
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
@ -1,7 +1,7 @@
|
|||||||
# Tests of TensorFlow kernels written using the Python API.
|
# Tests of TensorFlow kernels written using the Python API.
|
||||||
|
|
||||||
load("//tensorflow:tensorflow.bzl", "tf_py_test")
|
load("//tensorflow:tensorflow.bzl", "tf_py_test") # buildifier: disable=same-origin-load
|
||||||
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
|
load("//tensorflow:tensorflow.bzl", "cuda_py_test") # buildifier: disable=same-origin-load
|
||||||
|
|
||||||
package(
|
package(
|
||||||
default_visibility = ["//tensorflow:internal"],
|
default_visibility = ["//tensorflow:internal"],
|
||||||
@ -92,8 +92,8 @@ tf_py_test(
|
|||||||
"//tensorflow/python:platform_test",
|
"//tensorflow/python:platform_test",
|
||||||
"//tensorflow/python/data",
|
"//tensorflow/python/data",
|
||||||
"//tensorflow/python/data/experimental/ops:testing",
|
"//tensorflow/python/data/experimental/ops:testing",
|
||||||
|
"//tensorflow/python/data/experimental/service:server_lib",
|
||||||
"//tensorflow/python/data/kernel_tests:test_base",
|
"//tensorflow/python/data/kernel_tests:test_base",
|
||||||
"//tensorflow/python/data/service:server_lib",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -23,9 +23,9 @@ from absl.testing import parameterized
|
|||||||
|
|
||||||
from tensorflow.python.data.experimental.ops import data_service_ops
|
from tensorflow.python.data.experimental.ops import data_service_ops
|
||||||
from tensorflow.python.data.experimental.ops import distribute_options
|
from tensorflow.python.data.experimental.ops import distribute_options
|
||||||
|
from tensorflow.python.data.experimental.service import server_lib
|
||||||
from tensorflow.python.data.kernel_tests import test_base
|
from tensorflow.python.data.kernel_tests import test_base
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.data.service import server_lib
|
|
||||||
from tensorflow.python.eager import def_function
|
from tensorflow.python.eager import def_function
|
||||||
from tensorflow.python.framework import combinations
|
from tensorflow.python.framework import combinations
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
|
@ -1,44 +0,0 @@
|
|||||||
load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
|
|
||||||
|
|
||||||
# buildifier: disable=same-origin-load
|
|
||||||
load("//tensorflow:tensorflow.bzl", "tf_py_test")
|
|
||||||
|
|
||||||
package(
|
|
||||||
default_visibility = ["//tensorflow:internal"],
|
|
||||||
licenses = ["notice"], # Apache 2.0
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_python_pybind_extension(
|
|
||||||
name = "_pywrap_server_lib",
|
|
||||||
srcs = ["server_lib_wrapper.cc"],
|
|
||||||
module_name = "_pywrap_server_lib",
|
|
||||||
deps = [
|
|
||||||
"//tensorflow/core:lib",
|
|
||||||
"//tensorflow/core/data/service:server_lib_headers_lib",
|
|
||||||
"//tensorflow/python:pybind11_lib",
|
|
||||||
"//tensorflow/python:pybind11_status",
|
|
||||||
"//third_party/python_runtime:headers",
|
|
||||||
"@com_github_grpc_grpc//:grpc++_public_hdrs",
|
|
||||||
"@pybind11",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
py_library(
|
|
||||||
name = "server_lib",
|
|
||||||
srcs = ["server_lib.py"],
|
|
||||||
visibility = [
|
|
||||||
"//visibility:public",
|
|
||||||
],
|
|
||||||
deps = [
|
|
||||||
":_pywrap_server_lib",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
tf_py_test(
|
|
||||||
name = "server_lib_test",
|
|
||||||
srcs = ["server_lib_test.py"],
|
|
||||||
deps = [
|
|
||||||
":server_lib",
|
|
||||||
"//tensorflow/python:platform_test",
|
|
||||||
],
|
|
||||||
)
|
|
@ -0,0 +1,21 @@
|
|||||||
|
path: "tensorflow.data.experimental.service.MasterServer"
|
||||||
|
tf_class {
|
||||||
|
is_instance: "<class \'tensorflow.python.data.experimental.service.server_lib.MasterServer\'>"
|
||||||
|
is_instance: "<type \'object\'>"
|
||||||
|
member {
|
||||||
|
name: "target"
|
||||||
|
mtype: "<type \'property\'>"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "__init__"
|
||||||
|
argspec: "args=[\'self\', \'port\', \'protocol\', \'start\'], varargs=None, keywords=None, defaults=[\'None\', \'True\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "join"
|
||||||
|
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "start"
|
||||||
|
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
|
}
|
@ -0,0 +1,17 @@
|
|||||||
|
path: "tensorflow.data.experimental.service.WorkerServer"
|
||||||
|
tf_class {
|
||||||
|
is_instance: "<class \'tensorflow.python.data.experimental.service.server_lib.WorkerServer\'>"
|
||||||
|
is_instance: "<type \'object\'>"
|
||||||
|
member_method {
|
||||||
|
name: "__init__"
|
||||||
|
argspec: "args=[\'self\', \'port\', \'master_address\', \'worker_address\', \'protocol\', \'start\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'True\'], "
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "join"
|
||||||
|
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
|
member_method {
|
||||||
|
name: "start"
|
||||||
|
argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
|
||||||
|
}
|
||||||
|
}
|
@ -1,5 +1,13 @@
|
|||||||
path: "tensorflow.data.experimental.service"
|
path: "tensorflow.data.experimental.service"
|
||||||
tf_module {
|
tf_module {
|
||||||
|
member {
|
||||||
|
name: "MasterServer"
|
||||||
|
mtype: "<type \'type\'>"
|
||||||
|
}
|
||||||
|
member {
|
||||||
|
name: "WorkerServer"
|
||||||
|
mtype: "<type \'type\'>"
|
||||||
|
}
|
||||||
member_method {
|
member_method {
|
||||||
name: "distribute"
|
name: "distribute"
|
||||||
argspec: "args=[\'processing_mode\', \'service\', \'job_name\', \'max_outstanding_requests\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
argspec: "args=[\'processing_mode\', \'service\', \'job_name\', \'max_outstanding_requests\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
|
||||||
|
@ -108,7 +108,6 @@ COMMON_PIP_DEPS = [
|
|||||||
"//tensorflow/python/autograph/pyct/common_transformers:common_transformers",
|
"//tensorflow/python/autograph/pyct/common_transformers:common_transformers",
|
||||||
"//tensorflow/python/compiler:compiler",
|
"//tensorflow/python/compiler:compiler",
|
||||||
"//tensorflow/python:cond_v2",
|
"//tensorflow/python:cond_v2",
|
||||||
"//tensorflow/python/data/service:server_lib",
|
|
||||||
"//tensorflow/python:distributed_framework_test_lib",
|
"//tensorflow/python:distributed_framework_test_lib",
|
||||||
"//tensorflow/python/distribute:distribute_test_lib_pip",
|
"//tensorflow/python/distribute:distribute_test_lib_pip",
|
||||||
"//tensorflow/python:loss_scale",
|
"//tensorflow/python:loss_scale",
|
||||||
@ -121,6 +120,7 @@ COMMON_PIP_DEPS = [
|
|||||||
"//tensorflow/python/data/experimental/kernel_tests:reader_dataset_ops_test_base",
|
"//tensorflow/python/data/experimental/kernel_tests:reader_dataset_ops_test_base",
|
||||||
"//tensorflow/python/data/experimental/kernel_tests:stats_dataset_test_base",
|
"//tensorflow/python/data/experimental/kernel_tests:stats_dataset_test_base",
|
||||||
"//tensorflow/python/data/experimental/ops:testing",
|
"//tensorflow/python/data/experimental/ops:testing",
|
||||||
|
"//tensorflow/python/data/experimental/service:server_lib",
|
||||||
"//tensorflow/python/data/kernel_tests:test_base",
|
"//tensorflow/python/data/kernel_tests:test_base",
|
||||||
"//tensorflow/python/debug:debug_pip",
|
"//tensorflow/python/debug:debug_pip",
|
||||||
"//tensorflow/python/distribute:combinations",
|
"//tensorflow/python/distribute:combinations",
|
||||||
|
Loading…
Reference in New Issue
Block a user